GenerativeAIService.swift 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import Foundation
  /// Sends `GenerativeAIRequest`s to the backend over HTTP and decodes the responses.
  ///
  /// Requests are always sent as JSON `POST`s. The service supports both a single-response call
  /// (`loadRequest(request:)`) and a streaming call (`loadRequestStream(request:)`) that reads the
  /// response line by line and decodes every line prefixed with `data:`.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, *)
  struct GenerativeAIService {
    /// The language of the SDK in the format `gl-<language>/<version>`.
    static let languageTag = "gl-swift/5"

    /// The Firebase SDK version in the format `fire/<version>`.
    static let firebaseVersionTag = "fire/\(FirebaseVersion())"

    /// Domain of the `NSError`s created by this service.
    private static let errorDomain = "com.google.generative-ai"

    /// The project ID; interpolated into the console links logged by `logRPCError(_:)`.
    private let projectID: String

    /// Gives permission to talk to the backend.
    private let apiKey: String

    /// Optional App Check provider; when present its token is attached to every request.
    private let appCheck: AppCheckInterop?

    /// Optional Auth provider; when present and a token is returned, it is attached to every
    /// request.
    private let auth: AuthInterop?

    /// The session used to perform all network calls.
    private let urlSession: URLSession

    init(projectID: String, apiKey: String, appCheck: AppCheckInterop?, auth: AuthInterop?,
         urlSession: URLSession) {
      self.projectID = projectID
      self.apiKey = apiKey
      self.appCheck = appCheck
      self.auth = auth
      self.urlSession = urlSession
    }

    /// Performs `request` and decodes the complete response body.
    ///
    /// - Parameter request: The request to send.
    /// - Returns: The decoded `T.Response`.
    /// - Throws: Any error from building or sending the request, the error produced by
    ///   `parseError(responseData:)` when the status code is not 200, or a decoding error when
    ///   the body cannot be decoded as `T.Response`.
    func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
      let urlRequest = try await self.urlRequest(request: request)

      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (data, rawResponse) = try await urlSession.data(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Any status other than 200 is treated as a failure and its body parsed as an error.
      guard response.statusCode == 200 else {
        Logging.network.error("[FirebaseVertexAI] The server responded with an error: \(response)")
        if let responseString = String(data: data, encoding: .utf8) {
          Logging.default.error("[FirebaseVertexAI] Response payload: \(responseString)")
        }
        throw parseError(responseData: data)
      }

      return try parseResponse(T.Response.self, from: data)
    }

    /// Performs `request` and yields one decoded `T.Response` per `data:` line of the response.
    ///
    /// The stream finishes normally once the response body ends. It finishes with an error when
    /// the request cannot be built or sent, the status code is not 200, reading the body fails, a
    /// `data:` line cannot be decoded, or the body contained lines that were not prefixed with
    /// `data:` (those lines are concatenated and parsed as an error payload at the end).
    ///
    /// - Parameter request: The request to send.
    /// - Returns: A stream of decoded responses.
    @available(macOS 12.0, *)
    func loadRequestStream<T: GenerativeAIRequest>(request: T)
      -> AsyncThrowingStream<T.Response, Error> {
      return AsyncThrowingStream { continuation in
        let task = Task {
          // Every throwing step lives inside this single `do` so that *any* failure, including
          // one raised while iterating the response bytes, finishes the continuation. An error
          // escaping the task body would be discarded and leave the consumer waiting forever.
          do {
            let urlRequest = try await self.urlRequest(request: request)

            #if DEBUG
              printCURLCommand(from: urlRequest)
            #endif

            let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)
            let response = try httpResponse(urlResponse: rawResponse)

            // Any status other than 200 is a failure; collect the body to build the error.
            guard response.statusCode == 200 else {
              Logging.network
                .error("[FirebaseVertexAI] The server responded with an error: \(response)")
              var responseBody = ""
              for try await line in stream.lines {
                responseBody += line + "\n"
              }
              Logging.default.error("[FirebaseVertexAI] Response payload: \(responseBody)")
              continuation.finish(throwing: parseError(responseBody: responseBody))
              return
            }

            // Received lines that are not server-sent events (SSE); these are not prefixed with
            // "data:".
            var extraLines = ""
            for try await line in stream.lines {
              Logging.network.debug("[FirebaseVertexAI] Stream response: \(line)")
              guard line.hasPrefix("data:") else {
                extraLines += line
                continue
              }
              // Drop the 5-character `data:` prefix; the remainder is the JSON payload.
              let jsonText = String(line.dropFirst(5))
              let data = try jsonData(jsonText: jsonText)
              let content = try parseResponse(T.Response.self, from: data)
              continuation.yield(content)
            }

            if !extraLines.isEmpty {
              continuation.finish(throwing: parseError(responseBody: extraLines))
              return
            }
            continuation.finish()
          } catch {
            continuation.finish(throwing: error)
          }
        }

        // Stop the network work when the consumer stops listening. Cancelling a task that has
        // already completed has no effect.
        continuation.onTermination = { _ in task.cancel() }
      }
    }

    // MARK: - Private Helpers

    /// Builds the `POST` `URLRequest` for `request`.
    ///
    /// Sets the API key, API client and content type headers, attaches the App Check and Auth
    /// tokens when the respective providers are available, JSON-encodes `request` as the body
    /// using snake-case keys, and applies the request's timeout if one is set.
    ///
    /// - Throws: Errors from fetching the Auth token or from encoding `request`.
    private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
      var urlRequest = URLRequest(url: request.url)
      urlRequest.httpMethod = "POST"
      urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
      urlRequest.setValue(
        "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
        forHTTPHeaderField: "x-goog-api-client"
      )
      urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

      if let appCheck {
        let tokenResult = await appCheck.getToken(forcingRefresh: false)
        // The token from the result is attached even when the result also carries an error;
        // the error is only logged and does not fail the request.
        urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
        if let error = tokenResult.error {
          Logging.default
            .debug("[FirebaseVertexAI] Failed to fetch AppCheck token. Error: \(error)")
        }
      }

      if let auth, let authToken = try await auth.getToken(forcingRefresh: false) {
        urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
      }

      let encoder = JSONEncoder()
      encoder.keyEncodingStrategy = .convertToSnakeCase
      urlRequest.httpBody = try encoder.encode(request)

      if let timeoutInterval = request.options.timeout {
        urlRequest.timeoutInterval = timeoutInterval
      }

      return urlRequest
    }

    /// Casts `urlResponse` to an `HTTPURLResponse`.
    ///
    /// - Throws: An `NSError` (code `-1`) when the response is not an HTTP response.
    private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
      // The status code is only available on HTTP responses, so anything else is an error.
      guard let response = urlResponse as? HTTPURLResponse else {
        Logging.default
          .error(
            "[FirebaseVertexAI] Response wasn't an HTTP response, internal error \(urlResponse)"
          )
        throw NSError(
          domain: GenerativeAIService.errorDomain,
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
        )
      }
      return response
    }

    /// Converts `jsonText` to UTF-8 encoded `Data`.
    ///
    /// - Throws: An `NSError` (code `-1`) when the text cannot be encoded as UTF-8.
    private func jsonData(jsonText: String) throws -> Data {
      guard let data = jsonText.data(using: .utf8) else {
        throw NSError(
          domain: GenerativeAIService.errorDomain,
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Could not parse response as UTF8."]
        )
      }
      return data
    }

    /// Builds the error to surface for an error response given as text.
    ///
    /// - Returns: The result of `parseError(responseData:)` for the UTF-8 bytes of
    ///   `responseBody`, or the conversion error if the text cannot be encoded.
    private func parseError(responseBody: String) -> Error {
      do {
        let data = try jsonData(jsonText: responseBody)
        return parseError(responseData: data)
      } catch {
        return error
      }
    }

    /// Builds the error to surface for an error response given as raw bytes.
    ///
    /// - Returns: The decoded `RPCError` (after logging it via `logRPCError(_:)`), or the
    ///   decoding error when the payload is not a recognizable `RPCError`.
    private func parseError(responseData: Data) -> Error {
      do {
        let rpcError = try JSONDecoder().decode(RPCError.self, from: responseData)
        logRPCError(rpcError)
        return rpcError
      } catch {
        // TODO: Return an error about an unrecognized error payload with the response body
        return error
      }
    }

    // Log specific RPC errors that cannot be mitigated or handled by user code.
    // These errors do not produce specific GenerateContentError or CountTokensError cases.
    private func logRPCError(_ error: RPCError) {
      if error.isFirebaseMLServiceDisabledError() {
        Logging.default.error("""
        The Vertex AI for Firebase SDK requires the Firebase ML API `firebaseml.googleapis.com` to \
        be enabled for your project. Get started in the Firebase Console \
        (https://console.firebase.google.com/project/\(projectID)/genai/vertex) or verify that the \
        API is enabled in the Google Cloud Console \
        (https://console.developers.google.com/apis/api/firebaseml.googleapis.com/overview?project=\
        \(projectID)).
        """)
      }
    }

    /// Decodes `data` as `type`, logging the payload and the error when decoding fails.
    ///
    /// - Throws: The decoding error, unchanged.
    private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
      do {
        return try JSONDecoder().decode(type, from: data)
      } catch {
        if let json = String(data: data, encoding: .utf8) {
          Logging.network.error("[FirebaseVertexAI] JSON response: \(json)")
        }
        Logging.default.error("[FirebaseVertexAI] Error decoding server JSON: \(error)")
        throw error
      }
    }

    #if DEBUG
      /// Builds a `curl` command line equivalent to `request`, for debugging.
      ///
      /// Headers are emitted sorted by name so the output is stable between runs, and every
      /// single-quoted argument has embedded single quotes escaped.
      ///
      /// - Returns: The command, or an empty string when the request has no URL or no UTF-8
      ///   decodable body.
      private func cURLCommand(from request: URLRequest) -> String {
        guard let url = request.url else { return "" }
        guard let body = request.httpBody,
              let jsonStr = String(bytes: body, encoding: .utf8) else { return "" }

        // Wraps `text` in single quotes, closing and reopening the quote around each embedded
        // single quote so the shell sees the text literally.
        func singleQuoted(_ text: String) -> String {
          return "'" + text.replacingOccurrences(of: "'", with: "'\\''") + "'"
        }

        var components = ["curl"]
        let headers = request.allHTTPHeaderFields ?? [:]
        for (key, value) in headers.sorted(by: { $0.key < $1.key }) {
          components.append("-H " + singleQuoted("\(key): \(value)"))
        }
        components.append(singleQuoted(url.absoluteString))
        components.append("-d " + singleQuoted(jsonStr))
        return components.joined(separator: " ")
      }

      /// Logs the `curl` equivalent of `request` at debug level, marked as private data.
      private func printCURLCommand(from request: URLRequest) {
        let command = cURLCommand(from: request)
        Logging.verbose.debug("""
        [FirebaseVertexAI] Creating request with the equivalent cURL command:
        ----- cURL command -----
        \(command, privacy: .private)
        ------------------------
        """)
      }
    #endif // DEBUG
  }