GenerativeModel.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import FirebaseAppCheckInterop
import FirebaseAuthInterop
import Foundation
  17. /// A type that represents a remote multimodal model (like Gemini), with the ability to generate
  18. /// content based on various input types.
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. public final class GenerativeModel: Sendable {
  21. /// The resource name of the model in the backend; has the format "models/model-name".
  22. let modelResourceName: String
  23. /// Configuration for the backend API used by this model.
  24. let apiConfig: APIConfig
  25. /// The backing service responsible for sending and receiving model requests to the backend.
  26. let generativeAIService: GenerativeAIService
  27. /// Configuration parameters used for the MultiModalModel.
  28. let generationConfig: GenerationConfig?
  29. /// The safety settings to be used for prompts.
  30. let safetySettings: [SafetySetting]?
  31. /// A list of tools the model may use to generate the next response.
  32. let tools: [Tool]?
  33. /// Tool configuration for any `Tool` specified in the request.
  34. let toolConfig: ToolConfig?
  35. /// Instructions that direct the model to behave a certain way.
  36. let systemInstruction: ModelContent?
  37. /// Configuration parameters for sending requests to the backend.
  38. let requestOptions: RequestOptions
  39. /// Initializes a new remote model with the given parameters.
  40. ///
  41. /// - Parameters:
  42. /// - name: The name of the model to use, for example `"gemini-1.0-pro"`.
  43. /// - firebaseInfo: Firebase data used by the SDK, including project ID and API key.
  44. /// - apiConfig: Configuration for the backend API used by this model.
  45. /// - generationConfig: The content generation parameters your model should use.
  46. /// - safetySettings: A value describing what types of harmful content your model should allow.
  47. /// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
  48. /// - toolConfig: Tool configuration for any `Tool` specified in the request.
  49. /// - systemInstruction: Instructions that direct the model to behave a certain way; currently
  50. /// only text content is supported.
  51. /// - requestOptions: Configuration parameters for sending requests to the backend.
  52. /// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
  53. init(name: String,
  54. firebaseInfo: FirebaseInfo,
  55. apiConfig: APIConfig,
  56. generationConfig: GenerationConfig? = nil,
  57. safetySettings: [SafetySetting]? = nil,
  58. tools: [Tool]?,
  59. toolConfig: ToolConfig? = nil,
  60. systemInstruction: ModelContent? = nil,
  61. requestOptions: RequestOptions,
  62. urlSession: URLSession = .shared) {
  63. modelResourceName = name
  64. self.apiConfig = apiConfig
  65. generativeAIService = GenerativeAIService(
  66. firebaseInfo: firebaseInfo,
  67. apiConfig: apiConfig,
  68. urlSession: urlSession
  69. )
  70. self.generationConfig = generationConfig
  71. self.safetySettings = safetySettings
  72. self.tools = tools
  73. self.toolConfig = toolConfig
  74. self.systemInstruction = systemInstruction.map {
  75. // The `role` defaults to "user" but is ignored in system instructions. However, it is
  76. // erroneously counted towards the prompt and total token count in `countTokens` when using
  77. // the Developer API backend; set to `nil` to avoid token count discrepancies between
  78. // `countTokens` and `generateContent`.
  79. ModelContent(role: nil, parts: $0.parts)
  80. }
  81. self.requestOptions = requestOptions
  82. if VertexLog.additionalLoggingEnabled() {
  83. VertexLog.debug(code: .verboseLoggingEnabled, "Verbose logging enabled.")
  84. } else {
  85. VertexLog.info(code: .verboseLoggingDisabled, """
  86. [FirebaseVertexAI] To enable additional logging, add \
  87. `\(VertexLog.enableArgumentKey)` as a launch argument in Xcode.
  88. """)
  89. }
  90. VertexLog.debug(code: .generativeModelInitialized, "Model \(name) initialized.")
  91. }
  92. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  93. /// representable as one or more ``Part``s.
  94. ///
  95. /// Since ``Part``s do not specify a role, this method is intended for generating content from
  96. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  97. /// or "direct" prompts. For
  98. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  99. /// prompts, see `generateContent(_ content: [ModelContent])`.
  100. ///
  101. /// - Parameters:
  102. /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  103. /// conforming types).
  104. /// - Returns: The content generated by the model.
  105. /// - Throws: A ``GenerateContentError`` if the request failed.
  106. public func generateContent(_ parts: any PartsRepresentable...)
  107. async throws -> GenerateContentResponse {
  108. return try await generateContent([ModelContent(parts: parts)])
  109. }
  110. /// Generates new content from input content given to the model as a prompt.
  111. ///
  112. /// - Parameter content: The input(s) given to the model as a prompt.
  113. /// - Returns: The generated content response from the model.
  114. /// - Throws: A ``GenerateContentError`` if the request failed.
  115. public func generateContent(_ content: [ModelContent]) async throws
  116. -> GenerateContentResponse {
  117. try content.throwIfError()
  118. let response: GenerateContentResponse
  119. let generateContentRequest = GenerateContentRequest(
  120. model: modelResourceName,
  121. contents: content,
  122. generationConfig: generationConfig,
  123. safetySettings: safetySettings,
  124. tools: tools,
  125. toolConfig: toolConfig,
  126. systemInstruction: systemInstruction,
  127. apiMethod: .generateContent,
  128. options: requestOptions
  129. )
  130. do {
  131. response = try await generativeAIService.loadRequest(request: generateContentRequest)
  132. } catch {
  133. throw GenerativeModel.generateContentError(from: error)
  134. }
  135. // Check the prompt feedback to see if the prompt was blocked.
  136. if response.promptFeedback?.blockReason != nil {
  137. throw GenerateContentError.promptBlocked(response: response)
  138. }
  139. // Check to see if an error should be thrown for stop reason.
  140. if let reason = response.candidates.first?.finishReason, reason != .stop {
  141. throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
  142. }
  143. return response
  144. }
  145. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  146. /// representable as one or more ``Part``s.
  147. ///
  148. /// Since ``Part``s do not specify a role, this method is intended for generating content from
  149. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  150. /// or "direct" prompts. For
  151. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  152. /// prompts, see `generateContentStream(_ content: @autoclosure () throws -> [ModelContent])`.
  153. ///
  154. /// - Parameters:
  155. /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  156. /// conforming types).
  157. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  158. /// error if an error occurred.
  159. @available(macOS 12.0, *)
  160. public func generateContentStream(_ parts: any PartsRepresentable...) throws
  161. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  162. return try generateContentStream([ModelContent(parts: parts)])
  163. }
  164. /// Generates new content from input content given to the model as a prompt.
  165. ///
  166. /// - Parameter content: The input(s) given to the model as a prompt.
  167. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  168. /// error if an error occurred.
  169. @available(macOS 12.0, *)
  170. public func generateContentStream(_ content: [ModelContent]) throws
  171. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  172. try content.throwIfError()
  173. let generateContentRequest = GenerateContentRequest(
  174. model: modelResourceName,
  175. contents: content,
  176. generationConfig: generationConfig,
  177. safetySettings: safetySettings,
  178. tools: tools,
  179. toolConfig: toolConfig,
  180. systemInstruction: systemInstruction,
  181. apiMethod: .streamGenerateContent,
  182. options: requestOptions
  183. )
  184. return AsyncThrowingStream { continuation in
  185. let responseStream = generativeAIService.loadRequestStream(request: generateContentRequest)
  186. Task {
  187. do {
  188. for try await response in responseStream {
  189. // Check the prompt feedback to see if the prompt was blocked.
  190. if response.promptFeedback?.blockReason != nil {
  191. throw GenerateContentError.promptBlocked(response: response)
  192. }
  193. // If the stream ended early unexpectedly, throw an error.
  194. if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
  195. throw GenerateContentError.responseStoppedEarly(
  196. reason: finishReason,
  197. response: response
  198. )
  199. }
  200. continuation.yield(response)
  201. }
  202. continuation.finish()
  203. } catch {
  204. continuation.finish(throwing: GenerativeModel.generateContentError(from: error))
  205. return
  206. }
  207. }
  208. }
  209. }
  210. /// Creates a new chat conversation using this model with the provided history.
  211. public func startChat(history: [ModelContent] = []) -> Chat {
  212. return Chat(model: self, history: history)
  213. }
  214. /// Runs the model's tokenizer on String and/or image inputs that are representable as one or more
  215. /// ``Part``s.
  216. ///
  217. /// Since ``Part``s do not specify a role, this method is intended for tokenizing
  218. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  219. /// or "direct" prompts. For
  220. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  221. /// input, see `countTokens(_ content: @autoclosure () throws -> [ModelContent])`.
  222. ///
  223. /// - Parameters:
  224. /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  225. /// conforming types).
  226. /// - Returns: The results of running the model's tokenizer on the input; contains
  227. /// ``CountTokensResponse/totalTokens``.
  228. public func countTokens(_ parts: any PartsRepresentable...) async throws -> CountTokensResponse {
  229. return try await countTokens([ModelContent(parts: parts)])
  230. }
  231. /// Runs the model's tokenizer on the input content and returns the token count.
  232. ///
  233. /// - Parameter content: The input given to the model as a prompt.
  234. /// - Returns: The results of running the model's tokenizer on the input; contains
  235. /// ``CountTokensResponse/totalTokens``.
  236. public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
  237. let requestContent = switch apiConfig.service {
  238. case .vertexAI:
  239. content
  240. case .developer:
  241. // The `role` defaults to "user" but is ignored in `countTokens`. However, it is erroneously
  242. // erroneously counted towards the prompt and total token count when using the Developer API
  243. // backend; set to `nil` to avoid token count discrepancies between `countTokens` and
  244. // `generateContent` and the two backend APIs.
  245. content.map { ModelContent(role: nil, parts: $0.parts) }
  246. }
  247. let generateContentRequest = GenerateContentRequest(
  248. model: modelResourceName,
  249. contents: requestContent,
  250. generationConfig: generationConfig,
  251. safetySettings: safetySettings,
  252. tools: tools,
  253. toolConfig: toolConfig,
  254. systemInstruction: systemInstruction,
  255. apiMethod: .countTokens,
  256. options: requestOptions
  257. )
  258. let countTokensRequest = CountTokensRequest(generateContentRequest: generateContentRequest)
  259. return try await generativeAIService.loadRequest(request: countTokensRequest)
  260. }
  261. /// Returns a `GenerateContentError` (for public consumption) from an internal error.
  262. ///
  263. /// If `error` is already a `GenerateContentError` the error is returned unchanged.
  264. private static func generateContentError(from error: Error) -> GenerateContentError {
  265. if let error = error as? GenerateContentError {
  266. return error
  267. }
  268. return GenerateContentError.internalError(underlying: error)
  269. }
  270. }