// GenerativeModel.swift
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAppCheckInterop
import FirebaseAuthInterop
import Foundation
  17. /// A type that represents a remote multimodal model (like Gemini), with the ability to generate
  18. /// content based on various input types.
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. public final class GenerativeModel: Sendable {
  21. /// Model name prefix to identify Gemini models.
  22. static let geminiModelNamePrefix = "gemini-"
  23. /// Model name prefix to identify Gemma models.
  24. static let gemmaModelNamePrefix = "gemma-"
  25. /// The name of the model, for example "gemini-2.0-flash".
  26. let modelName: String
  27. /// The model resource name corresponding with `modelName` in the backend.
  28. let modelResourceName: String
  29. /// Configuration for the backend API used by this model.
  30. let apiConfig: APIConfig
  31. /// The backing service responsible for sending and receiving model requests to the backend.
  32. let generativeAIService: GenerativeAIService
  33. /// Configuration parameters used for the MultiModalModel.
  34. let generationConfig: GenerationConfig?
  35. /// The safety settings to be used for prompts.
  36. let safetySettings: [SafetySetting]?
  37. /// A list of tools the model may use to generate the next response.
  38. let tools: [Tool]?
  39. /// Tool configuration for any `Tool` specified in the request.
  40. let toolConfig: ToolConfig?
  41. /// Instructions that direct the model to behave a certain way.
  42. let systemInstruction: ModelContent?
  43. /// Configuration parameters for sending requests to the backend.
  44. let requestOptions: RequestOptions
  45. /// Initializes a new remote model with the given parameters.
  46. ///
  47. /// - Parameters:
  48. /// - modelName: The name of the model, for example "gemini-2.0-flash".
  49. /// - modelResourceName: The model resource name corresponding with `modelName` in the backend.
  50. /// The form depends on the backend and will be one of:
  51. /// - Vertex AI via Firebase AI SDK:
  52. /// `"projects/{projectID}/locations/{locationID}/publishers/google/models/{modelName}"`
  53. /// - Developer API via Firebase AI SDK: `"projects/{projectID}/models/{modelName}"`
  54. /// - Developer API via Generative Language: `"models/{modelName}"`
  55. /// - firebaseInfo: Firebase data used by the SDK, including project ID and API key.
  56. /// - apiConfig: Configuration for the backend API used by this model.
  57. /// - generationConfig: The content generation parameters your model should use.
  58. /// - safetySettings: A value describing what types of harmful content your model should allow.
  59. /// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
  60. /// - toolConfig: Tool configuration for any `Tool` specified in the request.
  61. /// - systemInstruction: Instructions that direct the model to behave a certain way; currently
  62. /// only text content is supported.
  63. /// - requestOptions: Configuration parameters for sending requests to the backend.
  64. /// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
  65. init(modelName: String,
  66. modelResourceName: String,
  67. firebaseInfo: FirebaseInfo,
  68. apiConfig: APIConfig,
  69. generationConfig: GenerationConfig? = nil,
  70. safetySettings: [SafetySetting]? = nil,
  71. tools: [Tool]?,
  72. toolConfig: ToolConfig? = nil,
  73. systemInstruction: ModelContent? = nil,
  74. requestOptions: RequestOptions,
  75. urlSession: URLSession = GenAIURLSession.default) {
  76. self.modelName = modelName
  77. self.modelResourceName = modelResourceName
  78. self.apiConfig = apiConfig
  79. generativeAIService = GenerativeAIService(
  80. firebaseInfo: firebaseInfo,
  81. urlSession: urlSession
  82. )
  83. self.generationConfig = generationConfig
  84. self.safetySettings = safetySettings
  85. self.tools = tools
  86. self.toolConfig = toolConfig
  87. self.systemInstruction = systemInstruction.map {
  88. // The `role` defaults to "user" but is ignored in system instructions. However, it is
  89. // erroneously counted towards the prompt and total token count in `countTokens` when using
  90. // the Developer API backend; set to `nil` to avoid token count discrepancies between
  91. // `countTokens` and `generateContent`.
  92. ModelContent(role: nil, parts: $0.parts)
  93. }
  94. self.requestOptions = requestOptions
  95. if AILog.additionalLoggingEnabled() {
  96. AILog.debug(code: .verboseLoggingEnabled, "Verbose logging enabled.")
  97. } else {
  98. AILog.info(code: .verboseLoggingDisabled, """
  99. [FirebaseVertexAI] To enable additional logging, add \
  100. `\(AILog.enableArgumentKey)` as a launch argument in Xcode.
  101. """)
  102. }
  103. AILog.debug(code: .generativeModelInitialized, "Model \(modelResourceName) initialized.")
  104. }
  105. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  106. /// representable as one or more ``Part``s.
  107. ///
  108. /// Since ``Part``s do not specify a role, this method is intended for generating content from
  109. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  110. /// or "direct" prompts. For
  111. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  112. /// prompts, see `generateContent(_ content: [ModelContent])`.
  113. ///
  114. /// - Parameters:
  115. /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  116. /// conforming types).
  117. /// - Returns: The content generated by the model.
  118. /// - Throws: A ``GenerateContentError`` if the request failed.
  119. public func generateContent(_ parts: any PartsRepresentable...)
  120. async throws -> GenerateContentResponse {
  121. return try await generateContent([ModelContent(parts: parts)])
  122. }
  123. public func generateContent(@ModelContentBuilder _ contentBuilder: ()
  124. -> ModelContent) async throws -> GenerateContentResponse {
  125. let content = contentBuilder()
  126. return try await generateContent([content])
  127. }
  128. /// Generates new content from input content given to the model as a prompt.
  129. ///
  130. /// - Parameter content: The input(s) given to the model as a prompt.
  131. /// - Returns: The generated content response from the model.
  132. /// - Throws: A ``GenerateContentError`` if the request failed.
  133. public func generateContent(_ content: [ModelContent]) async throws
  134. -> GenerateContentResponse {
  135. try content.throwIfError()
  136. let response: GenerateContentResponse
  137. let generateContentRequest = GenerateContentRequest(
  138. model: modelResourceName,
  139. contents: content,
  140. generationConfig: generationConfig,
  141. safetySettings: safetySettings,
  142. tools: tools,
  143. toolConfig: toolConfig,
  144. systemInstruction: systemInstruction,
  145. apiConfig: apiConfig,
  146. apiMethod: .generateContent,
  147. options: requestOptions
  148. )
  149. do {
  150. response = try await generativeAIService.loadRequest(request: generateContentRequest)
  151. } catch {
  152. throw GenerativeModel.generateContentError(from: error)
  153. }
  154. // Check the prompt feedback to see if the prompt was blocked.
  155. if response.promptFeedback?.blockReason != nil {
  156. throw GenerateContentError.promptBlocked(response: response)
  157. }
  158. // Check to see if an error should be thrown for stop reason.
  159. if let reason = response.candidates.first?.finishReason, reason != .stop {
  160. throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
  161. }
  162. return response
  163. }
  164. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  165. /// representable as one or more ``Part``s.
  166. ///
  167. /// Since ``Part``s do not specify a role, this method is intended for generating content from
  168. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  169. /// or "direct" prompts. For
  170. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  171. /// prompts, see `generateContentStream(_ content: @autoclosure () throws -> [ModelContent])`.
  172. ///
  173. /// - Parameters:
  174. /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  175. /// conforming types).
  176. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  177. /// error if an error occurred.
  178. @available(macOS 12.0, *)
  179. public func generateContentStream(_ parts: any PartsRepresentable...) throws
  180. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  181. return try generateContentStream([ModelContent(parts: parts)])
  182. }
  183. /// Generates new content from input content given to the model as a prompt.
  184. ///
  185. /// - Parameter content: The input(s) given to the model as a prompt.
  186. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  187. /// error if an error occurred.
  188. @available(macOS 12.0, *)
  189. public func generateContentStream(_ content: [ModelContent]) throws
  190. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  191. try content.throwIfError()
  192. let generateContentRequest = GenerateContentRequest(
  193. model: modelResourceName,
  194. contents: content,
  195. generationConfig: generationConfig,
  196. safetySettings: safetySettings,
  197. tools: tools,
  198. toolConfig: toolConfig,
  199. systemInstruction: systemInstruction,
  200. apiConfig: apiConfig,
  201. apiMethod: .streamGenerateContent,
  202. options: requestOptions
  203. )
  204. return AsyncThrowingStream { continuation in
  205. let responseStream = generativeAIService.loadRequestStream(request: generateContentRequest)
  206. Task {
  207. do {
  208. for try await response in responseStream {
  209. // Check the prompt feedback to see if the prompt was blocked.
  210. if response.promptFeedback?.blockReason != nil {
  211. throw GenerateContentError.promptBlocked(response: response)
  212. }
  213. // If the stream ended early unexpectedly, throw an error.
  214. if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
  215. throw GenerateContentError.responseStoppedEarly(
  216. reason: finishReason,
  217. response: response
  218. )
  219. }
  220. continuation.yield(response)
  221. }
  222. continuation.finish()
  223. } catch {
  224. continuation.finish(throwing: GenerativeModel.generateContentError(from: error))
  225. return
  226. }
  227. }
  228. }
  229. }
  230. /// Creates a new chat conversation using this model with the provided history.
  231. public func startChat(history: [ModelContent] = []) -> Chat {
  232. return Chat(model: self, history: history)
  233. }
  234. /// Runs the model's tokenizer on String and/or image inputs that are representable as one or more
  235. /// ``Part``s.
  236. ///
  237. /// Since ``Part``s do not specify a role, this method is intended for tokenizing
  238. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  239. /// or "direct" prompts. For
  240. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  241. /// input, see `countTokens(_ content: @autoclosure () throws -> [ModelContent])`.
  242. ///
  243. /// - Parameters:
  244. /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  245. /// conforming types).
  246. /// - Returns: The results of running the model's tokenizer on the input; contains
  247. /// ``CountTokensResponse/totalTokens``.
  248. public func countTokens(_ parts: any PartsRepresentable...) async throws -> CountTokensResponse {
  249. return try await countTokens([ModelContent(parts: parts)])
  250. }
  251. /// Runs the model's tokenizer on the input content and returns the token count.
  252. ///
  253. /// - Parameter content: The input given to the model as a prompt.
  254. /// - Returns: The results of running the model's tokenizer on the input; contains
  255. /// ``CountTokensResponse/totalTokens``.
  256. public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
  257. let requestContent = switch apiConfig.service {
  258. case .vertexAI:
  259. content
  260. case .googleAI:
  261. // The `role` defaults to "user" but is ignored in `countTokens`. However, it is erroneously
  262. // erroneously counted towards the prompt and total token count when using the Developer API
  263. // backend; set to `nil` to avoid token count discrepancies between `countTokens` and
  264. // `generateContent` and the two backend APIs.
  265. content.map { ModelContent(role: nil, parts: $0.parts) }
  266. }
  267. // When using the Developer API via the Firebase backend, the model name of the
  268. // `GenerateContentRequest` nested in the `CountTokensRequest` must be of the form
  269. // "models/model-name". This field is unaltered by the Firebase backend before forwarding the
  270. // request to the Generative Language backend, which expects the form "models/model-name".
  271. let generateContentRequestModelResourceName = switch apiConfig.service {
  272. case .vertexAI, .googleAI(endpoint: .googleAIBypassProxy):
  273. modelResourceName
  274. case .googleAI(endpoint: .firebaseProxyProd),
  275. .googleAI(endpoint: .firebaseProxyStaging):
  276. "models/\(modelName)"
  277. }
  278. let generateContentRequest = GenerateContentRequest(
  279. model: generateContentRequestModelResourceName,
  280. contents: requestContent,
  281. generationConfig: generationConfig,
  282. safetySettings: safetySettings,
  283. tools: tools,
  284. toolConfig: toolConfig,
  285. systemInstruction: systemInstruction,
  286. apiConfig: apiConfig,
  287. apiMethod: .countTokens,
  288. options: requestOptions
  289. )
  290. let countTokensRequest = CountTokensRequest(
  291. modelResourceName: modelResourceName, generateContentRequest: generateContentRequest
  292. )
  293. return try await generativeAIService.loadRequest(request: countTokensRequest)
  294. }
  295. /// Returns a `GenerateContentError` (for public consumption) from an internal error.
  296. ///
  297. /// If `error` is already a `GenerateContentError` the error is returned unchanged.
  298. private static func generateContentError(from error: Error) -> GenerateContentError {
  299. if let error = error as? GenerateContentError {
  300. return error
  301. }
  302. return GenerateContentError.internalError(underlying: error)
  303. }
  304. }