// GenerativeModel.swift
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import Foundation
  17. /// A type that represents a remote multimodal model (like Gemini), with the ability to generate
  18. /// content based on various input types.
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. public final class GenerativeModel {
  21. /// The resource name of the model in the backend; has the format "models/model-name".
  22. let modelResourceName: String
  23. /// The backing service responsible for sending and receiving model requests to the backend.
  24. let generativeAIService: GenerativeAIService
  25. let location: String
  26. /// Configuration parameters used for the MultiModalModel.
  27. let generationConfig: GenerationConfig?
  28. /// The safety settings to be used for prompts.
  29. let safetySettings: [SafetySetting]?
  30. /// A list of tools the model may use to generate the next response.
  31. let tools: [Tool]?
  32. /// Tool configuration for any `Tool` specified in the request.
  33. let toolConfig: ToolConfig?
  34. /// Instructions that direct the model to behave a certain way.
  35. let systemInstruction: ModelContent?
  36. /// Configuration parameters for sending requests to the backend.
  37. let requestOptions: RequestOptions
  38. /// Initializes a new remote model with the given parameters.
  39. ///
  40. /// - Parameters:
  41. /// - name: The name of the model to use, for example `"gemini-1.0-pro"`.
  42. /// - projectID: The project ID from the Firebase console.
  43. /// - apiKey: The API key for your project.
  44. /// - generationConfig: The content generation parameters your model should use.
  45. /// - safetySettings: A value describing what types of harmful content your model should allow.
  46. /// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
  47. /// - toolConfig: Tool configuration for any `Tool` specified in the request.
  48. /// - systemInstruction: Instructions that direct the model to behave a certain way; currently
  49. /// only text content is supported.
  50. /// - requestOptions: Configuration parameters for sending requests to the backend.
  51. /// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
  52. init(name: String,
  53. projectID: String,
  54. apiKey: String,
  55. location: String = "us-central1",
  56. generationConfig: GenerationConfig? = nil,
  57. safetySettings: [SafetySetting]? = nil,
  58. tools: [Tool]?,
  59. toolConfig: ToolConfig? = nil,
  60. systemInstruction: ModelContent? = nil,
  61. requestOptions: RequestOptions,
  62. appCheck: AppCheckInterop?,
  63. auth: AuthInterop?,
  64. urlSession: URLSession = .shared) {
  65. modelResourceName = name
  66. generativeAIService = GenerativeAIService(
  67. projectID: projectID,
  68. apiKey: apiKey,
  69. appCheck: appCheck,
  70. auth: auth,
  71. urlSession: urlSession
  72. )
  73. self.location = location
  74. self.generationConfig = generationConfig
  75. self.safetySettings = safetySettings
  76. self.tools = tools
  77. self.toolConfig = toolConfig
  78. self.systemInstruction = systemInstruction
  79. self.requestOptions = requestOptions
  80. if VertexLog.additionalLoggingEnabled() {
  81. VertexLog.debug(code: .verboseLoggingEnabled, "Verbose logging enabled.")
  82. } else {
  83. VertexLog.info(code: .verboseLoggingDisabled, """
  84. [FirebaseVertexAI] To enable additional logging, add \
  85. `\(VertexLog.enableArgumentKey)` as a launch argument in Xcode.
  86. """)
  87. }
  88. VertexLog.debug(code: .generativeModelInitialized, "Model \(name) initialized.")
  89. }
  90. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  91. /// representable as one or more ``Part``s.
  92. ///
  93. /// Since ``Part``s do not specify a role, this method is intended for generating content from
  94. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  95. /// or "direct" prompts. For
  96. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  97. /// prompts, see `generateContent(_ content: [ModelContent])`.
  98. ///
  99. /// - Parameters:
  100. /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  101. /// conforming types).
  102. /// - Returns: The content generated by the model.
  103. /// - Throws: A ``GenerateContentError`` if the request failed.
  104. public func generateContent(_ parts: any PartsRepresentable...)
  105. async throws -> GenerateContentResponse {
  106. return try await generateContent([ModelContent(parts: parts)])
  107. }
  108. /// Generates new content from input content given to the model as a prompt.
  109. ///
  110. /// - Parameter content: The input(s) given to the model as a prompt.
  111. /// - Returns: The generated content response from the model.
  112. /// - Throws: A ``GenerateContentError`` if the request failed.
  113. public func generateContent(_ content: [ModelContent]) async throws
  114. -> GenerateContentResponse {
  115. try content.throwIfError()
  116. let response: GenerateContentResponse
  117. let generateContentRequest = GenerateContentRequest(model: modelResourceName,
  118. location: location,
  119. contents: content,
  120. generationConfig: generationConfig,
  121. safetySettings: safetySettings,
  122. tools: tools,
  123. toolConfig: toolConfig,
  124. systemInstruction: systemInstruction,
  125. isStreaming: false,
  126. options: requestOptions)
  127. do {
  128. response = try await generativeAIService.loadRequest(request: generateContentRequest)
  129. } catch {
  130. throw GenerativeModel.generateContentError(from: error)
  131. }
  132. // Check the prompt feedback to see if the prompt was blocked.
  133. if response.promptFeedback?.blockReason != nil {
  134. throw GenerateContentError.promptBlocked(response: response)
  135. }
  136. // Check to see if an error should be thrown for stop reason.
  137. if let reason = response.candidates.first?.finishReason, reason != .stop {
  138. throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
  139. }
  140. return response
  141. }
  142. /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  143. /// representable as one or more ``Part``s.
  144. ///
  145. /// Since ``Part``s do not specify a role, this method is intended for generating content from
  146. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  147. /// or "direct" prompts. For
  148. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  149. /// prompts, see `generateContentStream(_ content: @autoclosure () throws -> [ModelContent])`.
  150. ///
  151. /// - Parameters:
  152. /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  153. /// conforming types).
  154. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  155. /// error if an error occurred.
  156. @available(macOS 12.0, *)
  157. public func generateContentStream(_ parts: any PartsRepresentable...) throws
  158. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  159. return try generateContentStream([ModelContent(parts: parts)])
  160. }
  161. /// Generates new content from input content given to the model as a prompt.
  162. ///
  163. /// - Parameter content: The input(s) given to the model as a prompt.
  164. /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  165. /// error if an error occurred.
  166. @available(macOS 12.0, *)
  167. public func generateContentStream(_ content: [ModelContent]) throws
  168. -> AsyncThrowingStream<GenerateContentResponse, Error> {
  169. try content.throwIfError()
  170. let generateContentRequest = GenerateContentRequest(model: modelResourceName,
  171. location: location,
  172. contents: content,
  173. generationConfig: generationConfig,
  174. safetySettings: safetySettings,
  175. tools: tools,
  176. toolConfig: toolConfig,
  177. systemInstruction: systemInstruction,
  178. isStreaming: true,
  179. options: requestOptions)
  180. var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
  181. .makeAsyncIterator()
  182. return AsyncThrowingStream {
  183. let response: GenerateContentResponse?
  184. do {
  185. response = try await responseIterator.next()
  186. } catch {
  187. throw GenerativeModel.generateContentError(from: error)
  188. }
  189. // The responseIterator will return `nil` when it's done.
  190. guard let response = response else {
  191. // This is the end of the stream! Signal it by sending `nil`.
  192. return nil
  193. }
  194. // Check the prompt feedback to see if the prompt was blocked.
  195. if response.promptFeedback?.blockReason != nil {
  196. throw GenerateContentError.promptBlocked(response: response)
  197. }
  198. // If the stream ended early unexpectedly, throw an error.
  199. if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
  200. throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
  201. } else {
  202. // Response was valid content, pass it along and continue.
  203. return response
  204. }
  205. }
  206. }
  207. /// Creates a new chat conversation using this model with the provided history.
  208. public func startChat(history: [ModelContent] = []) -> Chat {
  209. return Chat(model: self, history: history)
  210. }
  211. /// Runs the model's tokenizer on String and/or image inputs that are representable as one or more
  212. /// ``Part``s.
  213. ///
  214. /// Since ``Part``s do not specify a role, this method is intended for tokenizing
  215. /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  216. /// or "direct" prompts. For
  217. /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  218. /// input, see `countTokens(_ content: @autoclosure () throws -> [ModelContent])`.
  219. ///
  220. /// - Parameters:
  221. /// - parts: The input(s) given to the model as a prompt (see ``PartsRepresentable`` for
  222. /// conforming types).
  223. /// - Returns: The results of running the model's tokenizer on the input; contains
  224. /// ``CountTokensResponse/totalTokens``.
  225. public func countTokens(_ parts: any PartsRepresentable...) async throws -> CountTokensResponse {
  226. return try await countTokens([ModelContent(parts: parts)])
  227. }
  228. /// Runs the model's tokenizer on the input content and returns the token count.
  229. ///
  230. /// - Parameter content: The input given to the model as a prompt.
  231. /// - Returns: The results of running the model's tokenizer on the input; contains
  232. /// ``CountTokensResponse/totalTokens``.
  233. public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
  234. let countTokensRequest = CountTokensRequest(
  235. model: modelResourceName,
  236. location: location,
  237. contents: content,
  238. systemInstruction: systemInstruction,
  239. tools: tools,
  240. generationConfig: generationConfig,
  241. options: requestOptions
  242. )
  243. return try await generativeAIService.loadRequest(request: countTokensRequest)
  244. }
  245. /// Returns a `GenerateContentError` (for public consumption) from an internal error.
  246. ///
  247. /// If `error` is already a `GenerateContentError` the error is returned unchanged.
  248. private static func generateContentError(from error: Error) -> GenerateContentError {
  249. if let error = error as? GenerateContentError {
  250. return error
  251. }
  252. return GenerateContentError.internalError(underlying: error)
  253. }
  254. }