// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAppCheckInterop
import FirebaseAuthInterop
import FirebaseCore
import Foundation
import os.log
  /// Sends `GenerativeAIRequest`s to the backend over a `URLSession` and decodes the responses,
  /// either as a single value (`loadRequest`) or as a stream of values (`loadRequestStream`).
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  struct GenerativeAIService {
    /// The language of the SDK in the format `gl-<language>/<version>`.
    static let languageTag = "gl-swift/5"

    /// The Firebase SDK version in the format `fire/<version>`.
    static let firebaseVersionTag = "fire/\(FirebaseVersion())"

    /// Project identifier; interpolated into the Firebase Console URLs logged by `logRPCError(_:)`.
    private let projectID: String

    /// Gives permission to talk to the backend.
    private let apiKey: String

    /// Optional App Check provider; when present its token is sent as `X-Firebase-AppCheck`.
    private let appCheck: AppCheckInterop?

    /// Optional Auth provider; when it yields a token it is sent in the `Authorization` header.
    private let auth: AuthInterop?

    /// Session used for every network call made by this service.
    private let urlSession: URLSession

    init(projectID: String, apiKey: String, appCheck: AppCheckInterop?, auth: AuthInterop?,
         urlSession: URLSession) {
      self.projectID = projectID
      self.apiKey = apiKey
      self.appCheck = appCheck
      self.auth = auth
      self.urlSession = urlSession
    }

    /// Performs `request` and decodes the complete response body.
    ///
    /// - Parameter request: The request to encode and send.
    /// - Returns: The decoded `T.Response`.
    /// - Throws: Any error from building or sending the request, the error produced by
    ///   `parseError(responseData:)` when the status code is not 200, or a decoding error.
    func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
      // `self.` disambiguates the helper method from the local constant of the same name.
      let urlRequest = try await self.urlRequest(request: request)

      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (data, rawResponse) = try await urlSession.data(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Verify the status code is 200
      guard response.statusCode == 200 else {
        VertexLog.error(
          code: .loadRequestResponseError,
          "The server responded with an error: \(response)"
        )
        if let responseString = String(data: data, encoding: .utf8) {
          VertexLog.error(
            code: .loadRequestResponseErrorPayload,
            "Response payload: \(responseString)"
          )
        }
        throw parseError(responseData: data)
      }

      return try parseResponse(T.Response.self, from: data)
    }

    /// Performs `request` and yields one decoded `T.Response` per `data:`-prefixed line of the
    /// response body.
    ///
    /// The stream finishes with an error when the request cannot be built or sent, when the
    /// status code is not 200, when a `data:` line cannot be decoded, when the byte stream itself
    /// fails, or when the body contained lines that were not prefixed with `data:`.
    ///
    /// - Parameter request: The request to encode and send.
    /// - Returns: A stream of decoded responses.
    @available(macOS 12.0, *)
    func loadRequestStream<T: GenerativeAIRequest>(request: T)
      -> AsyncThrowingStream<T.Response, Error> {
      return AsyncThrowingStream { continuation in
        let task = Task {
          // Every throwing call lives inside this single `do` so that the continuation is always
          // finished. An error escaping the `Task` closure would be stored in the (discarded)
          // task and the consumer of the stream would wait forever.
          do {
            let urlRequest = try await self.urlRequest(request: request)

            #if DEBUG
              printCURLCommand(from: urlRequest)
            #endif

            let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)

            // Verify the response is an HTTP response.
            let response = try httpResponse(urlResponse: rawResponse)

            // Verify the status code is 200
            guard response.statusCode == 200 else {
              VertexLog.error(
                code: .loadRequestStreamResponseError,
                "The server responded with an error: \(response)"
              )
              var responseBody = ""
              for try await line in stream.lines {
                responseBody += line + "\n"
              }
              VertexLog.error(
                code: .loadRequestStreamResponseErrorPayload,
                "Response payload: \(responseBody)"
              )
              continuation.finish(throwing: parseError(responseBody: responseBody))
              return
            }

            // Received lines that are not server-sent events (SSE); these are not prefixed with
            // "data:"
            var extraLines = ""

            for try await line in stream.lines {
              VertexLog.debug(code: .loadRequestStreamResponseLine, "Stream response: \(line)")

              if line.hasPrefix("data:") {
                // We can assume 5 characters since it's utf-8 encoded, removing `data:`.
                let jsonText = String(line.dropFirst(5))
                let data = try jsonData(jsonText: jsonText)

                // Handle the content.
                let content = try parseResponse(T.Response.self, from: data)
                continuation.yield(content)
              } else {
                extraLines += line
              }
            }

            if !extraLines.isEmpty {
              continuation.finish(throwing: parseError(responseBody: extraLines))
              return
            }

            continuation.finish()
          } catch {
            continuation.finish(throwing: error)
          }
        }

        // Stop the network work once the stream terminates (e.g. the consumer stops iterating).
        continuation.onTermination = { _ in task.cancel() }
      }
    }

    // MARK: - Private Helpers

    /// Builds the `POST` request for `request`: API key, client tags and content type headers,
    /// optional App Check and Auth tokens, the snake_case JSON-encoded body and the timeout taken
    /// from `request.options`.
    ///
    /// - Throws: Errors from fetching the Auth token or from encoding the request body.
    private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
      var urlRequest = URLRequest(url: request.url)
      urlRequest.httpMethod = "POST"
      urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
      urlRequest.setValue(
        "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
        forHTTPHeaderField: "x-goog-api-client"
      )
      urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

      if let appCheck {
        let tokenResult = await appCheck.getToken(forcingRefresh: false)
        // The token from the result is attached even when the result also carries an error;
        // the error is only logged.
        urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
        if let error = tokenResult.error {
          VertexLog.error(
            code: .appCheckTokenFetchFailed,
            "Failed to fetch AppCheck token. Error: \(error)"
          )
        }
      }

      if let auth, let authToken = try await auth.getToken(forcingRefresh: false) {
        urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
      }

      let encoder = JSONEncoder()
      encoder.keyEncodingStrategy = .convertToSnakeCase
      urlRequest.httpBody = try encoder.encode(request)
      urlRequest.timeoutInterval = request.options.timeout

      return urlRequest
    }

    /// Casts `urlResponse` to `HTTPURLResponse`.
    ///
    /// - Throws: An `NSError` in the `com.google.generative-ai` domain when the response is not
    ///   an HTTP response.
    private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
      // Verify the response is an HTTP response.
      guard let response = urlResponse as? HTTPURLResponse else {
        VertexLog.error(
          code: .generativeAIServiceNonHTTPResponse,
          "Response wasn't an HTTP response, internal error \(urlResponse)"
        )
        throw NSError(
          domain: "com.google.generative-ai",
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
        )
      }

      return response
    }

    /// Converts `jsonText` to UTF-8 encoded `Data`.
    ///
    /// - Throws: An `NSError` in the `com.google.generative-ai` domain when the conversion fails.
    private func jsonData(jsonText: String) throws -> Data {
      guard let data = jsonText.data(using: .utf8) else {
        let error = NSError(
          domain: "com.google.generative-ai",
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Could not parse response as UTF8."]
        )
        throw error
      }

      return data
    }

    /// Turns a textual error payload into an `Error`; returns the conversion error itself when
    /// the text cannot be converted to data.
    private func parseError(responseBody: String) -> Error {
      do {
        let data = try jsonData(jsonText: responseBody)
        return parseError(responseData: data)
      } catch {
        return error
      }
    }

    /// Decodes `responseData` as an `RPCError`, logs it via `logRPCError(_:)` and returns it;
    /// returns the decoding error when the payload is not an `RPCError`.
    private func parseError(responseData: Data) -> Error {
      do {
        let rpcError = try JSONDecoder().decode(RPCError.self, from: responseData)
        logRPCError(rpcError)
        return rpcError
      } catch {
        // TODO: Return an error about an unrecognized error payload with the response body
        return error
      }
    }

    // Log specific RPC errors that cannot be mitigated or handled by user code.
    // These errors do not produce specific GenerateContentError or CountTokensError cases.
    private func logRPCError(_ error: RPCError) {
      // TODO(andrewheard): Remove this check after the Vertex AI in Firebase API launch.
      if error.isFirebaseMLServiceDisabledError() {
        VertexLog.error(code: .vertexAIInFirebaseAPIDisabled, """
        The Vertex AI in Firebase SDK requires the Firebase ML API (`firebaseml.googleapis.com`) to \
        be enabled in your Firebase project. Enable this API by visiting the Firebase Console at
        https://console.firebase.google.com/project/\(projectID)/genai/ and clicking "Get started". \
        If you enabled this API recently, wait a few minutes for the action to propagate to our \
        systems and then retry.
        """)
      }

      if error.isVertexAIInFirebaseServiceDisabledError() {
        VertexLog.error(code: .vertexAIInFirebaseAPIDisabled, """
        The Vertex AI in Firebase SDK requires the Vertex AI in Firebase API \
        (`firebasevertexai.googleapis.com`) to be enabled in your Firebase project. Enable this API \
        by visiting the Firebase Console at
        https://console.firebase.google.com/project/\(projectID)/genai/ and clicking "Get started". \
        If you enabled this API recently, wait a few minutes for the action to propagate to our \
        systems and then retry.
        """)
      }
    }

    /// Decodes `data` as `type` with a default `JSONDecoder`.
    ///
    /// - Throws: The decoding error, after logging the raw payload (when it is valid UTF-8) and
    ///   the error itself.
    private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
      do {
        return try JSONDecoder().decode(type, from: data)
      } catch {
        if let json = String(data: data, encoding: .utf8) {
          VertexLog.error(code: .loadRequestParseResponseFailedJSON, "JSON response: \(json)")
        }
        VertexLog.error(
          code: .loadRequestParseResponseFailedJSONError,
          "Error decoding server JSON: \(error)"
        )
        throw error
      }
    }

    #if DEBUG
      /// Renders `request` as an equivalent `curl` invocation (headers, URL and JSON body).
      ///
      /// Returns an empty string when the request has no URL or no UTF-8 decodable body.
      private func cURLCommand(from request: URLRequest) -> String {
        var returnValue = "curl "
        if let allHeaders = request.allHTTPHeaderFields {
          for (key, value) in allHeaders {
            returnValue += "-H '\(key): \(value)' "
          }
        }

        guard let url = request.url else { return "" }
        returnValue += "'\(url.absoluteString)' "

        guard let body = request.httpBody,
              let jsonStr = String(bytes: body, encoding: .utf8) else { return "" }
        // Close the quoted string, emit an escaped quote, then reopen it.
        let escapedJSON = jsonStr.replacingOccurrences(of: "'", with: "'\\''")
        returnValue += "-d '\(escapedJSON)'"

        return returnValue
      }

      /// Logs the cURL equivalent of `request` at debug level; does nothing unless additional
      /// logging is enabled. The command itself is logged with `.private` privacy.
      private func printCURLCommand(from request: URLRequest) {
        guard VertexLog.additionalLoggingEnabled() else {
          return
        }
        let command = cURLCommand(from: request)
        os_log(.debug, log: VertexLog.logObject, """
        \(VertexLog.service) Creating request with the equivalent cURL command:
        ----- cURL command -----
        \(command, privacy: .private)
        ------------------------
        """)
      }
    #endif // DEBUG
  }