// GenerativeModel.swift
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAppCheckInterop
import Foundation
  16. /// A type that represents a remote multimodal model (like Gemini), with the ability to generate
  17. /// content based on various input types.
  18. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
  19. public final class GenerativeModel {
  20. // The prefix for a model resource in the Gemini API.
  21. private static let modelResourcePrefix = "models/"
  22. /// The resource name of the model in the backend; has the format "models/model-name".
  23. let modelResourceName: String
  24. /// The backing service responsible for sending and receiving model requests to the backend.
  25. let generativeAIService: GenerativeAIService
  26. /// Configuration parameters used for the MultiModalModel.
  27. let generationConfig: GenerationConfig?
  28. /// The safety settings to be used for prompts.
  29. let safetySettings: [SafetySetting]?
  30. /// A list of tools the model may use to generate the next response.
  31. let tools: [Tool]?
  32. /// Tool configuration for any `Tool` specified in the request.
  33. let toolConfig: ToolConfig?
  34. /// Instructions that direct the model to behave a certain way.
  35. let systemInstruction: ModelContent?
  36. /// Configuration parameters for sending requests to the backend.
  37. let requestOptions: RequestOptions
  38. /// Initializes a new remote model with the given parameters.
  39. ///
  40. /// - Parameters:
  41. /// - name: The name of the model to use, for example `"gemini-1.0-pro"`; see
  42. /// [Gemini models](https://ai.google.dev/models/gemini) for a list of supported model names.
  43. /// - apiKey: The API key for your project.
  44. /// - generationConfig: The content generation parameters your model should use.
  45. /// - safetySettings: A value describing what types of harmful content your model should allow.
  46. /// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
  47. /// - toolConfig: Tool configuration for any `Tool` specified in the request.
  48. /// - systemInstruction: Instructions that direct the model to behave a certain way; currently
  49. /// only text content is supported.
  50. /// - requestOptions: Configuration parameters for sending requests to the backend.
  51. /// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
  52. init(name: String,
  53. apiKey: String,
  54. generationConfig: GenerationConfig? = nil,
  55. safetySettings: [SafetySetting]? = nil,
  56. tools: [Tool]?,
  57. toolConfig: ToolConfig? = nil,
  58. systemInstruction: ModelContent? = nil,
  59. requestOptions: RequestOptions,
  60. appCheck: AppCheckInterop?,
  61. urlSession: URLSession = .shared) {
  62. modelResourceName = GenerativeModel.modelResourceName(name: name)
  63. generativeAIService = GenerativeAIService(
  64. apiKey: apiKey,
  65. appCheck: appCheck,
  66. urlSession: urlSession
  67. )
  68. self.generationConfig = generationConfig
  69. self.safetySettings = safetySettings
  70. self.tools = tools
  71. self.toolConfig = toolConfig
  72. self.systemInstruction = systemInstruction
  73. self.requestOptions = requestOptions
  74. Logging.default.info("""
  75. [GoogleGenerativeAI] Model \(
  76. name,
  77. privacy: .public
  78. ) initialized. To enable additional logging, add \
  79. `\(Logging.enableArgumentKey, privacy: .public)` as a launch argument in Xcode.
  80. """)
  81. Logging.verbose.debug("[GoogleGenerativeAI] Verbose logging enabled.")
  82. }
  83. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  84. /// representable as one or more ``ModelContent/Part``s.
  85. ///
  86. /// Since ``ModelContent/Part``s do not specify a role, this method is intended for generating
  87. /// content from
  88. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  89. /// or "direct" prompts. For
  90. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  91. /// prompts, see ``generateContent(_:)``.
  92. ///
  93. /// - Parameter content: The input(s) given to the model as a prompt (see
  94. /// ``ThrowingPartsRepresentable``
  95. /// for conforming types).
  96. /// - Returns: The content generated by the model.
  97. /// - Throws: A ``GenerateContentError`` if the request failed.
  98. public func generateContent(_ parts: any ThrowingPartsRepresentable...)
  99. async throws -> GenerateContentResponse {
  100. return try await generateContent([ModelContent(parts: parts)])
  101. }
  102. /// Generates new content from input content given to the model as a prompt.
  103. ///
  104. /// - Parameter content: The input(s) given to the model as a prompt.
  105. /// - Returns: The generated content response from the model.
  106. /// - Throws: A ``GenerateContentError`` if the request failed.
  107. public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws
  108. -> GenerateContentResponse {
  109. let response: GenerateContentResponse
  110. do {
  111. let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
  112. contents: content(),
  113. generationConfig: generationConfig,
  114. safetySettings: safetySettings,
  115. tools: tools,
  116. toolConfig: toolConfig,
  117. systemInstruction: systemInstruction,
  118. isStreaming: false,
  119. options: requestOptions)
  120. response = try await generativeAIService.loadRequest(request: generateContentRequest)
  121. } catch {
  122. if let imageError = error as? ImageConversionError {
  123. throw GenerateContentError.promptImageContentError(underlying: imageError)
  124. }
  125. throw GenerativeModel.generateContentError(from: error)
  126. }
  127. // Check the prompt feedback to see if the prompt was blocked.
  128. if response.promptFeedback?.blockReason != nil {
  129. throw GenerateContentError.promptBlocked(response: response)
  130. }
  131. // Check to see if an error should be thrown for stop reason.
  132. if let reason = response.candidates.first?.finishReason, reason != .stop {
  133. throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
  134. }
  135. return response
  136. }
  137. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  138. /// representable as one or more ``ModelContent/Part``s.
  139. ///
  140. /// Since ``ModelContent/Part``s do not specify a role, this method is intended for generating
  141. /// content from
  142. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  143. /// or "direct" prompts. For
  144. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  145. /// prompts, see ``generateContent(_:)``.
  146. ///
  147. /// - Parameter content: The input(s) given to the model as a prompt (see
  148. /// ``ThrowingPartsRepresentable``
  149. /// for conforming types).
  150. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  151. /// error if an error occurred.
  152. @available(macOS 12.0, *)
  153. public func generateContentStream(_ parts: any ThrowingPartsRepresentable...)
  154. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  155. return try generateContentStream([ModelContent(parts: parts)])
  156. }
  157. /// Generates new content from input content given to the model as a prompt.
  158. ///
  159. /// - Parameter content: The input(s) given to the model as a prompt.
  160. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  161. /// error if an error occurred.
  162. @available(macOS 12.0, *)
  163. public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent])
  164. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  165. let evaluatedContent: [ModelContent]
  166. do {
  167. evaluatedContent = try content()
  168. } catch let underlying {
  169. return AsyncThrowingStream { continuation in
  170. let error: Error
  171. if let contentError = underlying as? ImageConversionError {
  172. error = GenerateContentError.promptImageContentError(underlying: contentError)
  173. } else {
  174. error = GenerateContentError.internalError(underlying: underlying)
  175. }
  176. continuation.finish(throwing: error)
  177. }
  178. }
  179. let generateContentRequest = GenerateContentRequest(model: modelResourceName,
  180. contents: evaluatedContent,
  181. generationConfig: generationConfig,
  182. safetySettings: safetySettings,
  183. tools: tools,
  184. toolConfig: toolConfig,
  185. systemInstruction: systemInstruction,
  186. isStreaming: true,
  187. options: requestOptions)
  188. var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
  189. .makeAsyncIterator()
  190. return AsyncThrowingStream {
  191. let response: GenerateContentResponse?
  192. do {
  193. response = try await responseIterator.next()
  194. } catch {
  195. throw GenerativeModel.generateContentError(from: error)
  196. }
  197. // The responseIterator will return `nil` when it's done.
  198. guard let response = response else {
  199. // This is the end of the stream! Signal it by sending `nil`.
  200. return nil
  201. }
  202. // Check the prompt feedback to see if the prompt was blocked.
  203. if response.promptFeedback?.blockReason != nil {
  204. throw GenerateContentError.promptBlocked(response: response)
  205. }
  206. // If the stream ended early unexpectedly, throw an error.
  207. if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
  208. throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
  209. } else {
  210. // Response was valid content, pass it along and continue.
  211. return response
  212. }
  213. }
  214. }
  215. /// Creates a new chat conversation using this model with the provided history.
  216. public func startChat(history: [ModelContent] = []) -> Chat {
  217. return Chat(model: self, history: history)
  218. }
  219. /// Runs the model's tokenizer on String and/or image inputs that are representable as one or more
  220. /// ``ModelContent/Part``s.
  221. ///
  222. /// Since ``ModelContent/Part``s do not specify a role, this method is intended for tokenizing
  223. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  224. /// or "direct" prompts. For
  225. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  226. /// input, see ``countTokens(_:)``.
  227. ///
  228. /// - Parameter content: The input(s) given to the model as a prompt (see
  229. /// ``ThrowingPartsRepresentable``
  230. /// for conforming types).
  231. /// - Returns: The results of running the model's tokenizer on the input; contains
  232. /// ``CountTokensResponse/totalTokens``.
  233. /// - Throws: A ``CountTokensError`` if the tokenization request failed.
  234. public func countTokens(_ parts: any ThrowingPartsRepresentable...) async throws
  235. -> CountTokensResponse {
  236. return try await countTokens([ModelContent(parts: parts)])
  237. }
  238. /// Runs the model's tokenizer on the input content and returns the token count.
  239. ///
  240. /// - Parameter content: The input given to the model as a prompt.
  241. /// - Returns: The results of running the model's tokenizer on the input; contains
  242. /// ``CountTokensResponse/totalTokens``.
  243. /// - Throws: A ``CountTokensError`` if the tokenization request failed or the input content was
  244. /// invalid.
  245. public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
  246. -> CountTokensResponse {
  247. do {
  248. let countTokensRequest = try CountTokensRequest(
  249. model: modelResourceName,
  250. contents: content(),
  251. options: requestOptions
  252. )
  253. return try await generativeAIService.loadRequest(request: countTokensRequest)
  254. } catch {
  255. throw CountTokensError.internalError(underlying: error)
  256. }
  257. }
  258. /// Returns a model resource name of the form "models/model-name" based on `name`.
  259. private static func modelResourceName(name: String) -> String {
  260. if name.contains("/") {
  261. return name
  262. } else {
  263. return modelResourcePrefix + name
  264. }
  265. }
  266. /// Returns a `GenerateContentError` (for public consumption) from an internal error.
  267. ///
  268. /// If `error` is already a `GenerateContentError` the error is returned unchanged.
  269. private static func generateContentError(from error: Error) -> GenerateContentError {
  270. if let error = error as? GenerateContentError {
  271. return error
  272. }
  273. return GenerateContentError.internalError(underlying: error)
  274. }
  275. }
  276. /// See ``GenerativeModel/countTokens(_:)``.
  277. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
  278. public enum CountTokensError: Error {
  279. case internalError(underlying: Error)
  280. }