// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAppCheckInterop
import FirebaseAuthInterop
import Foundation
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
/// content based on various input types.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public final class GenerativeModel: Sendable {
  /// Model name prefix to identify Gemini models.
  // NOTE(review): not referenced within this file — presumably used elsewhere in the SDK to
  // detect Gemini models; verify before removing.
  static let geminiModelNamePrefix = "gemini-"

  /// The name of the model, for example "gemini-2.0-flash".
  let modelName: String

  /// The model resource name corresponding with `modelName` in the backend.
  let modelResourceName: String

  /// Configuration for the backend API used by this model.
  let apiConfig: APIConfig

  /// The backing service responsible for sending and receiving model requests to the backend.
  let generativeAIService: GenerativeAIService

  /// Configuration parameters used for the MultiModalModel.
  let generationConfig: GenerationConfig?

  /// The safety settings to be used for prompts.
  let safetySettings: [SafetySetting]?

  /// A list of tools the model may use to generate the next response.
  let tools: [Tool]?

  /// Tool configuration for any `Tool` specified in the request.
  let toolConfig: ToolConfig?

  /// Instructions that direct the model to behave a certain way.
  ///
  /// Stored with a `nil` role; see the note in `init` for why the role is stripped.
  let systemInstruction: ModelContent?

  /// Configuration parameters for sending requests to the backend.
  let requestOptions: RequestOptions

  /// Initializes a new remote model with the given parameters.
  ///
  /// - Parameters:
  ///   - modelName: The name of the model, for example "gemini-2.0-flash".
  ///   - modelResourceName: The model resource name corresponding with `modelName` in the backend.
  ///     The form depends on the backend and will be one of:
  ///     - Vertex AI via Vertex AI in Firebase:
  ///       `"projects/{projectID}/locations/{locationID}/publishers/google/models/{modelName}"`
  ///     - Developer API via Vertex AI in Firebase: `"projects/{projectID}/models/{modelName}"`
  ///     - Developer API via Generative Language: `"models/{modelName}"`
  ///   - firebaseInfo: Firebase data used by the SDK, including project ID and API key.
  ///   - apiConfig: Configuration for the backend API used by this model.
  ///   - generationConfig: The content generation parameters your model should use.
  ///   - safetySettings: A value describing what types of harmful content your model should allow.
  ///   - tools: A list of ``Tool`` objects that the model may use to generate the next response.
  ///   - toolConfig: Tool configuration for any `Tool` specified in the request.
  ///   - systemInstruction: Instructions that direct the model to behave a certain way; currently
  ///     only text content is supported.
  ///   - requestOptions: Configuration parameters for sending requests to the backend.
  ///   - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
  init(modelName: String,
       modelResourceName: String,
       firebaseInfo: FirebaseInfo,
       apiConfig: APIConfig,
       generationConfig: GenerationConfig? = nil,
       safetySettings: [SafetySetting]? = nil,
       tools: [Tool]?,
       toolConfig: ToolConfig? = nil,
       systemInstruction: ModelContent? = nil,
       requestOptions: RequestOptions,
       urlSession: URLSession = GenAIURLSession.default) {
    self.modelName = modelName
    self.modelResourceName = modelResourceName
    self.apiConfig = apiConfig
    generativeAIService = GenerativeAIService(
      firebaseInfo: firebaseInfo,
      urlSession: urlSession
    )
    self.generationConfig = generationConfig
    self.safetySettings = safetySettings
    self.tools = tools
    self.toolConfig = toolConfig
    self.systemInstruction = systemInstruction.map {
      // The `role` defaults to "user" but is ignored in system instructions. However, it is
      // erroneously counted towards the prompt and total token count in `countTokens` when using
      // the Developer API backend; set to `nil` to avoid token count discrepancies between
      // `countTokens` and `generateContent`.
      ModelContent(role: nil, parts: $0.parts)
    }
    self.requestOptions = requestOptions

    if AILog.additionalLoggingEnabled() {
      AILog.debug(code: .verboseLoggingEnabled, "Verbose logging enabled.")
    } else {
      AILog.info(code: .verboseLoggingDisabled, """
      [FirebaseVertexAI] To enable additional logging, add \
      `\(AILog.enableArgumentKey)` as a launch argument in Xcode.
      """)
    }
    AILog.debug(code: .generativeModelInitialized, "Model \(modelResourceName) initialized.")
  }

  /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  /// representable as one or more ``Part``s.
  ///
  /// Since ``Part``s do not specify a role, this method is intended for generating content from
  /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  /// or "direct" prompts. For
  /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  /// prompts, see `generateContent(_ content: [ModelContent])`.
  ///
  /// - Parameters:
  ///   - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  ///     conforming types).
  /// - Returns: The content generated by the model.
  /// - Throws: A ``GenerateContentError`` if the request failed.
  public func generateContent(_ parts: any PartsRepresentable...)
    async throws -> GenerateContentResponse {
    // Wrap the variadic parts in a single, role-less `ModelContent` and delegate.
    return try await generateContent([ModelContent(parts: parts)])
  }

  /// Generates new content from input content given to the model as a prompt.
  ///
  /// - Parameter content: The input(s) given to the model as a prompt.
  /// - Returns: The generated content response from the model.
  /// - Throws: A ``GenerateContentError`` if the request failed (including when the prompt was
  ///   blocked or the response stopped early for a reason other than `.stop`).
  public func generateContent(_ content: [ModelContent]) async throws
    -> GenerateContentResponse {
    // Validate the input content before constructing the request.
    try content.throwIfError()
    let response: GenerateContentResponse
    let generateContentRequest = GenerateContentRequest(
      model: modelResourceName,
      contents: content,
      generationConfig: generationConfig,
      safetySettings: safetySettings,
      tools: tools,
      toolConfig: toolConfig,
      systemInstruction: systemInstruction,
      apiConfig: apiConfig,
      apiMethod: .generateContent,
      options: requestOptions
    )
    do {
      response = try await generativeAIService.loadRequest(request: generateContentRequest)
    } catch {
      // Normalize any internal error into a public `GenerateContentError`.
      throw GenerativeModel.generateContentError(from: error)
    }

    // Check the prompt feedback to see if the prompt was blocked.
    if response.promptFeedback?.blockReason != nil {
      throw GenerateContentError.promptBlocked(response: response)
    }

    // Check to see if an error should be thrown for stop reason.
    if let reason = response.candidates.first?.finishReason, reason != .stop {
      throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
    }

    return response
  }

  /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  /// representable as one or more ``Part``s.
  ///
  /// Since ``Part``s do not specify a role, this method is intended for generating content from
  /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  /// or "direct" prompts. For
  /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  /// prompts, see `generateContentStream(_ content: @autoclosure () throws -> [ModelContent])`.
  ///
  /// - Parameters:
  ///   - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  ///     conforming types).
  /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  ///   error if an error occurred.
  // NOTE(review): this `@available` is redundant with the class-level availability (which already
  // requires macOS 12.0) — presumably retained for historical/source-compatibility reasons.
  @available(macOS 12.0, *)
  public func generateContentStream(_ parts: any PartsRepresentable...) throws
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    return try generateContentStream([ModelContent(parts: parts)])
  }

  /// Generates new content from input content given to the model as a prompt.
  ///
  /// - Parameter content: The input(s) given to the model as a prompt.
  /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  ///   error if an error occurred.
  /// - Throws: A ``GenerateContentError`` if `content` fails validation before the request is
  ///   sent; errors occurring during streaming are delivered through the returned stream.
  @available(macOS 12.0, *)
  public func generateContentStream(_ content: [ModelContent]) throws
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    // Validate the input content before constructing the request.
    try content.throwIfError()
    let generateContentRequest = GenerateContentRequest(
      model: modelResourceName,
      contents: content,
      generationConfig: generationConfig,
      safetySettings: safetySettings,
      tools: tools,
      toolConfig: toolConfig,
      systemInstruction: systemInstruction,
      apiConfig: apiConfig,
      apiMethod: .streamGenerateContent,
      options: requestOptions
    )

    return AsyncThrowingStream { continuation in
      let responseStream = generativeAIService.loadRequestStream(request: generateContentRequest)
      // NOTE(review): this unstructured `Task` is not wired to `continuation.onTermination`, so
      // cancelling the consumer may not cancel the underlying request — confirm whether
      // `loadRequestStream` cooperatively checks for cancellation.
      Task {
        do {
          for try await response in responseStream {
            // Check the prompt feedback to see if the prompt was blocked.
            if response.promptFeedback?.blockReason != nil {
              throw GenerateContentError.promptBlocked(response: response)
            }

            // If the stream ended early unexpectedly, throw an error.
            if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
              throw GenerateContentError.responseStoppedEarly(
                reason: finishReason,
                response: response
              )
            }

            continuation.yield(response)
          }
          continuation.finish()
        } catch {
          // Normalize any internal error into a public `GenerateContentError` before finishing.
          continuation.finish(throwing: GenerativeModel.generateContentError(from: error))
          return
        }
      }
    }
  }

  /// Creates a new chat conversation using this model with the provided history.
  ///
  /// - Parameter history: Prior conversation turns to seed the chat with; defaults to empty.
  /// - Returns: A new ``Chat`` instance backed by this model.
  public func startChat(history: [ModelContent] = []) -> Chat {
    return Chat(model: self, history: history)
  }

  /// Runs the model's tokenizer on String and/or image inputs that are representable as one or more
  /// ``Part``s.
  ///
  /// Since ``Part``s do not specify a role, this method is intended for tokenizing
  /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  /// or "direct" prompts. For
  /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  /// input, see `countTokens(_ content: @autoclosure () throws -> [ModelContent])`.
  ///
  /// - Parameters:
  ///   - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  ///     conforming types).
  /// - Returns: The results of running the model's tokenizer on the input; contains
  ///   ``CountTokensResponse/totalTokens``.
  public func countTokens(_ parts: any PartsRepresentable...) async throws -> CountTokensResponse {
    return try await countTokens([ModelContent(parts: parts)])
  }

  /// Runs the model's tokenizer on the input content and returns the token count.
  ///
  /// - Parameter content: The input given to the model as a prompt.
  /// - Returns: The results of running the model's tokenizer on the input; contains
  ///   ``CountTokensResponse/totalTokens``.
  public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
    let requestContent = switch apiConfig.service {
    case .vertexAI:
      content
    case .developer:
      // The `role` defaults to "user" but is ignored in `countTokens`. However, it is erroneously
      // counted towards the prompt and total token count when using the Developer API backend;
      // set to `nil` to avoid token count discrepancies between `countTokens` and
      // `generateContent` and the two backend APIs.
      content.map { ModelContent(role: nil, parts: $0.parts) }
    }

    // When using the Developer API via the Firebase backend, the model name of the
    // `GenerateContentRequest` nested in the `CountTokensRequest` must be of the form
    // "models/model-name". This field is unaltered by the Firebase backend before forwarding the
    // request to the Generative Language backend, which expects the form "models/model-name".
    let generateContentRequestModelResourceName = switch apiConfig.service {
    case .vertexAI, .developer(endpoint: .generativeLanguage):
      modelResourceName
    case .developer(endpoint: .firebaseVertexAIProd),
         .developer(endpoint: .firebaseVertexAIStaging):
      "models/\(modelName)"
    }

    let generateContentRequest = GenerateContentRequest(
      model: generateContentRequestModelResourceName,
      contents: requestContent,
      generationConfig: generationConfig,
      safetySettings: safetySettings,
      tools: tools,
      toolConfig: toolConfig,
      systemInstruction: systemInstruction,
      apiConfig: apiConfig,
      apiMethod: .countTokens,
      options: requestOptions
    )
    let countTokensRequest = CountTokensRequest(
      modelResourceName: modelResourceName, generateContentRequest: generateContentRequest
    )
    return try await generativeAIService.loadRequest(request: countTokensRequest)
  }

  /// Returns a `GenerateContentError` (for public consumption) from an internal error.
  ///
  /// If `error` is already a `GenerateContentError` the error is returned unchanged.
  private static func generateContentError(from error: Error) -> GenerateContentError {
    if let error = error as? GenerateContentError {
      return error
    }
    return GenerateContentError.internalError(underlying: error)
  }
}