// GenerativeModel.swift
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import FirebaseAppCheckInterop
import FirebaseAuthInterop
import Foundation
  17. /// A type that represents a remote multimodal model (like Gemini), with the ability to generate
  18. /// content based on various input types.
  19. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
  20. public final class GenerativeModel {
  21. // The prefix for a model resource in the Gemini API.
  22. private static let modelResourcePrefix = "models/"
  23. /// The resource name of the model in the backend; has the format "models/model-name".
  24. let modelResourceName: String
  25. /// The backing service responsible for sending and receiving model requests to the backend.
  26. let generativeAIService: GenerativeAIService
  27. /// Configuration parameters used for the MultiModalModel.
  28. let generationConfig: GenerationConfig?
  29. /// The safety settings to be used for prompts.
  30. let safetySettings: [SafetySetting]?
  31. /// A list of tools the model may use to generate the next response.
  32. let tools: [Tool]?
  33. /// Tool configuration for any `Tool` specified in the request.
  34. let toolConfig: ToolConfig?
  35. /// Instructions that direct the model to behave a certain way.
  36. let systemInstruction: ModelContent?
  37. /// Configuration parameters for sending requests to the backend.
  38. let requestOptions: RequestOptions
  39. /// Initializes a new remote model with the given parameters.
  40. ///
  41. /// - Parameters:
  42. /// - name: The name of the model to use, for example `"gemini-1.0-pro"`.
  43. /// - projectID: The project ID from the Firebase console.
  44. /// - apiKey: The API key for your project.
  45. /// - generationConfig: The content generation parameters your model should use.
  46. /// - safetySettings: A value describing what types of harmful content your model should allow.
  47. /// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
  48. /// - toolConfig: Tool configuration for any `Tool` specified in the request.
  49. /// - systemInstruction: Instructions that direct the model to behave a certain way; currently
  50. /// only text content is supported.
  51. /// - requestOptions: Configuration parameters for sending requests to the backend.
  52. /// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
  53. init(name: String,
  54. projectID: String,
  55. apiKey: String,
  56. generationConfig: GenerationConfig? = nil,
  57. safetySettings: [SafetySetting]? = nil,
  58. tools: [Tool]?,
  59. toolConfig: ToolConfig? = nil,
  60. systemInstruction: ModelContent? = nil,
  61. requestOptions: RequestOptions,
  62. appCheck: AppCheckInterop?,
  63. auth: AuthInterop?,
  64. urlSession: URLSession = .shared) {
  65. modelResourceName = GenerativeModel.modelResourceName(name: name)
  66. generativeAIService = GenerativeAIService(
  67. projectID: projectID,
  68. apiKey: apiKey,
  69. appCheck: appCheck,
  70. auth: auth,
  71. urlSession: urlSession
  72. )
  73. self.generationConfig = generationConfig
  74. self.safetySettings = safetySettings
  75. self.tools = tools
  76. self.toolConfig = toolConfig
  77. self.systemInstruction = systemInstruction
  78. self.requestOptions = requestOptions
  79. if Logging.additionalLoggingEnabled() {
  80. if ProcessInfo.processInfo.arguments.contains(Logging.migrationEnableArgumentKey) {
  81. Logging.verbose.debug("""
  82. [FirebaseVertexAI] Verbose logging enabled with the \
  83. \(Logging.migrationEnableArgumentKey, privacy: .public) launch argument; please migrate to \
  84. the \(Logging.enableArgumentKey, privacy: .public) argument to ensure future compatibility.
  85. """)
  86. } else {
  87. Logging.verbose.debug("[FirebaseVertexAI] Verbose logging enabled.")
  88. }
  89. } else {
  90. Logging.default.info("""
  91. [FirebaseVertexAI] To enable additional logging, add \
  92. `\(Logging.enableArgumentKey, privacy: .public)` as a launch argument in Xcode.
  93. """)
  94. }
  95. Logging.default.debug("[FirebaseVertexAI] Model \(name, privacy: .public) initialized.")
  96. }
  97. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  98. /// representable as one or more ``ModelContent/Part``s.
  99. ///
  100. /// Since ``ModelContent/Part``s do not specify a role, this method is intended for generating
  101. /// content from
  102. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  103. /// or "direct" prompts. For
  104. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  105. /// prompts, see `generateContent(_ content: @autoclosure () throws -> [ModelContent])`.
  106. ///
  107. /// - Parameter content: The input(s) given to the model as a prompt (see
  108. /// ``ThrowingPartsRepresentable``
  109. /// for conforming types).
  110. /// - Returns: The content generated by the model.
  111. /// - Throws: A ``GenerateContentError`` if the request failed.
  112. public func generateContent(_ parts: any ThrowingPartsRepresentable...)
  113. async throws -> GenerateContentResponse {
  114. return try await generateContent([ModelContent(parts: parts)])
  115. }
  116. /// Generates new content from input content given to the model as a prompt.
  117. ///
  118. /// - Parameter content: The input(s) given to the model as a prompt.
  119. /// - Returns: The generated content response from the model.
  120. /// - Throws: A ``GenerateContentError`` if the request failed.
  121. public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws
  122. -> GenerateContentResponse {
  123. let response: GenerateContentResponse
  124. do {
  125. let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
  126. contents: content(),
  127. generationConfig: generationConfig,
  128. safetySettings: safetySettings,
  129. tools: tools,
  130. toolConfig: toolConfig,
  131. systemInstruction: systemInstruction,
  132. isStreaming: false,
  133. options: requestOptions)
  134. response = try await generativeAIService.loadRequest(request: generateContentRequest)
  135. } catch {
  136. if let imageError = error as? ImageConversionError {
  137. throw GenerateContentError.promptImageContentError(underlying: imageError)
  138. }
  139. throw GenerativeModel.generateContentError(from: error)
  140. }
  141. // Check the prompt feedback to see if the prompt was blocked.
  142. if response.promptFeedback?.blockReason != nil {
  143. throw GenerateContentError.promptBlocked(response: response)
  144. }
  145. // Check to see if an error should be thrown for stop reason.
  146. if let reason = response.candidates.first?.finishReason, reason != .stop {
  147. throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
  148. }
  149. return response
  150. }
  151. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  152. /// representable as one or more ``ModelContent/Part``s.
  153. ///
  154. /// Since ``ModelContent/Part``s do not specify a role, this method is intended for generating
  155. /// content from
  156. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  157. /// or "direct" prompts. For
  158. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  159. /// prompts, see `generateContentStream(_ content: @autoclosure () throws -> [ModelContent])`.
  160. ///
  161. /// - Parameter content: The input(s) given to the model as a prompt (see
  162. /// ``ThrowingPartsRepresentable``
  163. /// for conforming types).
  164. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  165. /// error if an error occurred.
  166. @available(macOS 12.0, *)
  167. public func generateContentStream(_ parts: any ThrowingPartsRepresentable...)
  168. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  169. return try generateContentStream([ModelContent(parts: parts)])
  170. }
  171. /// Generates new content from input content given to the model as a prompt.
  172. ///
  173. /// - Parameter content: The input(s) given to the model as a prompt.
  174. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  175. /// error if an error occurred.
  176. @available(macOS 12.0, *)
  177. public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent])
  178. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  179. let evaluatedContent: [ModelContent]
  180. do {
  181. evaluatedContent = try content()
  182. } catch let underlying {
  183. return AsyncThrowingStream { continuation in
  184. let error: Error
  185. if let contentError = underlying as? ImageConversionError {
  186. error = GenerateContentError.promptImageContentError(underlying: contentError)
  187. } else {
  188. error = GenerateContentError.internalError(underlying: underlying)
  189. }
  190. continuation.finish(throwing: error)
  191. }
  192. }
  193. let generateContentRequest = GenerateContentRequest(model: modelResourceName,
  194. contents: evaluatedContent,
  195. generationConfig: generationConfig,
  196. safetySettings: safetySettings,
  197. tools: tools,
  198. toolConfig: toolConfig,
  199. systemInstruction: systemInstruction,
  200. isStreaming: true,
  201. options: requestOptions)
  202. var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
  203. .makeAsyncIterator()
  204. return AsyncThrowingStream {
  205. let response: GenerateContentResponse?
  206. do {
  207. response = try await responseIterator.next()
  208. } catch {
  209. throw GenerativeModel.generateContentError(from: error)
  210. }
  211. // The responseIterator will return `nil` when it's done.
  212. guard let response = response else {
  213. // This is the end of the stream! Signal it by sending `nil`.
  214. return nil
  215. }
  216. // Check the prompt feedback to see if the prompt was blocked.
  217. if response.promptFeedback?.blockReason != nil {
  218. throw GenerateContentError.promptBlocked(response: response)
  219. }
  220. // If the stream ended early unexpectedly, throw an error.
  221. if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
  222. throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
  223. } else {
  224. // Response was valid content, pass it along and continue.
  225. return response
  226. }
  227. }
  228. }
  229. /// Creates a new chat conversation using this model with the provided history.
  230. public func startChat(history: [ModelContent] = []) -> Chat {
  231. return Chat(model: self, history: history)
  232. }
  233. /// Runs the model's tokenizer on String and/or image inputs that are representable as one or more
  234. /// ``ModelContent/Part``s.
  235. ///
  236. /// Since ``ModelContent/Part``s do not specify a role, this method is intended for tokenizing
  237. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  238. /// or "direct" prompts. For
  239. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  240. /// input, see `countTokens(_ content: @autoclosure () throws -> [ModelContent])`.
  241. ///
  242. /// - Parameter content: The input(s) given to the model as a prompt (see
  243. /// ``ThrowingPartsRepresentable``
  244. /// for conforming types).
  245. /// - Returns: The results of running the model's tokenizer on the input; contains
  246. /// ``CountTokensResponse/totalTokens``.
  247. /// - Throws: A ``CountTokensError`` if the tokenization request failed.
  248. public func countTokens(_ parts: any ThrowingPartsRepresentable...) async throws
  249. -> CountTokensResponse {
  250. return try await countTokens([ModelContent(parts: parts)])
  251. }
  252. /// Runs the model's tokenizer on the input content and returns the token count.
  253. ///
  254. /// - Parameter content: The input given to the model as a prompt.
  255. /// - Returns: The results of running the model's tokenizer on the input; contains
  256. /// ``CountTokensResponse/totalTokens``.
  257. /// - Throws: A ``CountTokensError`` if the tokenization request failed or the input content was
  258. /// invalid.
  259. public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
  260. -> CountTokensResponse {
  261. do {
  262. let countTokensRequest = try CountTokensRequest(
  263. model: modelResourceName,
  264. contents: content(),
  265. options: requestOptions
  266. )
  267. return try await generativeAIService.loadRequest(request: countTokensRequest)
  268. } catch {
  269. throw CountTokensError.internalError(underlying: error)
  270. }
  271. }
  272. /// Returns a model resource name of the form "models/model-name" based on `name`.
  273. private static func modelResourceName(name: String) -> String {
  274. if name.contains("/") {
  275. return name
  276. } else {
  277. return modelResourcePrefix + name
  278. }
  279. }
  280. /// Returns a `GenerateContentError` (for public consumption) from an internal error.
  281. ///
  282. /// If `error` is already a `GenerateContentError` the error is returned unchanged.
  283. private static func generateContentError(from error: Error) -> GenerateContentError {
  284. if let error = error as? GenerateContentError {
  285. return error
  286. }
  287. return GenerateContentError.internalError(underlying: error)
  288. }
  289. }
  290. /// An error thrown in `GenerativeModel.countTokens(_:)`.
  291. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
  292. public enum CountTokensError: Error {
  293. case internalError(underlying: Error)
  294. }