GenerativeAIService.swift 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import Foundation
  18. import os.log
  /// Sends `GenerativeAIRequest`s to the backend over HTTP and decodes their responses, either as
  /// a single JSON payload (`loadRequest`) or as a stream of server-sent events
  /// (`loadRequestStream`).
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  struct GenerativeAIService {
    /// The language of the SDK in the format `gl-<language>/<version>`.
    static let languageTag = "gl-swift/5"

    /// The Firebase SDK version in the format `fire/<version>`.
    static let firebaseVersionTag = "fire/\(FirebaseVersion())"

    private let firebaseInfo: FirebaseInfo

    private let urlSession: URLSession

    /// When `true`, requests carry a limited-use App Check token (from the provider's optional
    /// `getLimitedUseToken`) instead of the token from `getToken(forcingRefresh: false)`.
    private let useLimitedUseAppCheckTokens: Bool

    init(firebaseInfo: FirebaseInfo, urlSession: URLSession, useLimitedUseAppCheckTokens: Bool) {
      self.firebaseInfo = firebaseInfo
      self.urlSession = urlSession
      self.useLimitedUseAppCheckTokens = useLimitedUseAppCheckTokens
    }

    /// Sends `request` and decodes the complete response body as `T.Response`.
    ///
    /// - Parameter request: Encoded as the JSON body of a `POST` to `request.url`.
    /// - Returns: The decoded response.
    /// - Throws: Errors from building the request (App Check / Auth token retrieval, JSON
    ///   encoding), transport errors from `URLSession`, `URLError(.badServerResponse)` for a
    ///   non-HTTP response, the parsed `BackendError` (or a decoding error) for a non-200 status,
    ///   and decoding errors for a malformed success payload.
    func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
      let urlRequest = try await self.urlRequest(request: request)
      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (data, rawResponse) = try await urlSession.data(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Only a 200 status is treated as success; anything else is logged and surfaced as the
      // backend's error payload.
      guard response.statusCode == 200 else {
        AILog.error(
          code: .loadRequestResponseError,
          "The server responded with an error: \(response)"
        )
        if let responseString = String(data: data, encoding: .utf8) {
          AILog.error(
            code: .loadRequestResponseErrorPayload,
            "Response payload: \(responseString)"
          )
        }
        throw parseError(responseData: data)
      }

      return try parseResponse(T.Response.self, from: data)
    }

    /// Sends `request` and returns a stream that yields one decoded `T.Response` per
    /// server-sent event (`data:` line) in the response body.
    ///
    /// The stream always terminates: it finishes normally when the body ends cleanly, and it
    /// finishes with an error on any failure, including errors raised while reading lines. If
    /// the consumer stops iterating, the underlying network task is cancelled.
    @available(macOS 12.0, *)
    func loadRequestStream<T: GenerativeAIRequest>(request: T)
      -> AsyncThrowingStream<T.Response, Error> where T: Sendable {
      return AsyncThrowingStream { continuation in
        let task = Task {
          // A single catch guarantees `continuation` is finished on every failure path;
          // otherwise errors thrown by `for try await` would be swallowed by the Task and the
          // consumer would wait forever.
          do {
            try await streamResponses(for: request, into: continuation)
            continuation.finish()
          } catch {
            continuation.finish(throwing: error)
          }
        }
        // Stop the network work when the consumer stops iterating or the stream is cancelled.
        continuation.onTermination = { _ in
          task.cancel()
        }
      }
    }

    // MARK: - Private Helpers

    /// Performs the streaming request for `request` and yields each decoded server-sent event
    /// into `continuation`.
    ///
    /// Returns normally when the body ends and every line was a `data:` event. Throws on any
    /// failure so that the caller can finish the stream with that error.
    private func streamResponses<T: GenerativeAIRequest>(
      for request: T,
      into continuation: AsyncThrowingStream<T.Response, Error>.Continuation
    ) async throws {
      let urlRequest = try await self.urlRequest(request: request)
      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Verify the status code is 200; otherwise read the whole body as an error payload.
      guard response.statusCode == 200 else {
        AILog.error(
          code: .loadRequestStreamResponseError,
          "The server responded with an error: \(response)"
        )
        var responseBody = ""
        for try await line in stream.lines {
          responseBody += line + "\n"
        }
        AILog.error(
          code: .loadRequestStreamResponseErrorPayload,
          "Response payload: \(responseBody)"
        )
        throw parseError(responseBody: responseBody)
      }

      // Lines that are not server-sent events (SSE), i.e. not prefixed with "data:", are
      // collected here and parsed as an error payload once the body ends.
      var extraLines = ""
      for try await line in stream.lines {
        AILog.debug(code: .loadRequestStreamResponseLine, "Stream response: \(line)")
        guard line.hasPrefix("data:") else {
          extraLines += line
          continue
        }
        // "data:" is five ASCII characters, so dropping five `Character`s removes exactly it.
        let jsonText = String(line.dropFirst(5))
        let data = try jsonData(jsonText: jsonText)
        let content = try parseResponse(T.Response.self, from: data)
        continuation.yield(content)
      }

      if !extraLines.isEmpty {
        throw parseError(responseBody: extraLines)
      }
    }

    /// Builds the `POST` request for `request`. It attaches the API key, the client tags and the
    /// JSON content type. When configured, it also attaches App Check, Firebase Auth and app
    /// identification headers.
    ///
    /// - Throws: Errors from fetching the App Check or Auth token, or from JSON-encoding
    ///   `request`.
    private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
      var urlRequest = URLRequest(url: request.url)
      urlRequest.httpMethod = "POST"
      urlRequest.setValue(firebaseInfo.apiKey, forHTTPHeaderField: "x-goog-api-key")
      urlRequest.setValue(
        "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
        forHTTPHeaderField: "x-goog-api-client"
      )
      urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

      if let appCheck = firebaseInfo.appCheck {
        let tokenResult = try await fetchAppCheckToken(appCheck: appCheck)
        // The returned token is attached even when `error` is set; the error is only logged.
        urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
        if let error = tokenResult.error {
          AILog.error(
            code: .appCheckTokenFetchFailed,
            "Failed to fetch AppCheck token. Error: \(error)"
          )
        }
      }

      if let auth = firebaseInfo.auth,
         let authToken = try await auth.getToken(forcingRefresh: false) {
        urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
      }

      // App identification headers are only sent when data collection is enabled for the app.
      if firebaseInfo.app.isDataCollectionDefaultEnabled {
        urlRequest.setValue(firebaseInfo.firebaseAppID, forHTTPHeaderField: "X-Firebase-AppId")
        if let appVersion = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String {
          urlRequest.setValue(appVersion, forHTTPHeaderField: "X-Firebase-AppVersion")
        }
      }

      urlRequest.httpBody = try JSONEncoder().encode(request)
      urlRequest.timeoutInterval = request.options.timeout
      return urlRequest
    }

    /// Returns the App Check token to attach to a request.
    ///
    /// When `useLimitedUseAppCheckTokens` is `false`, this is the token from
    /// `getToken(forcingRefresh: false)`. Otherwise it is a limited-use token.
    ///
    /// - Throws: An `NSError` when limited-use tokens are required but the App Check provider
    ///   does not implement `getLimitedUseToken`.
    private func fetchAppCheckToken(appCheck: AppCheckInterop) async throws
      -> FIRAppCheckTokenResultInterop {
      guard useLimitedUseAppCheckTokens else {
        return await appCheck.getToken(forcingRefresh: false)
      }
      if let token = await getLimitedUseAppCheckToken(appCheck: appCheck) {
        return token
      }
      let errorMessage =
        "The provided App Check token provider doesn't implement getLimitedUseToken(), but requireLimitedUseTokens was enabled."
      // Reported as a thrown error in every build configuration so callers can handle the
      // misconfiguration instead of crashing.
      throw NSError(
        domain: "\(Constants.baseErrorDomain).\(Self.self)",
        code: AILog.MessageCode.appCheckTokenFetchFailed.rawValue,
        userInfo: [NSLocalizedDescriptionKey: errorMessage]
      )
    }

    /// Requests a limited-use App Check token.
    ///
    /// Returns `nil` when limited-use tokens are disabled or the provider does not implement the
    /// optional `getLimitedUseToken` method.
    private func getLimitedUseAppCheckToken(appCheck: AppCheckInterop) async
      -> FIRAppCheckTokenResultInterop? {
      // At the moment, `await` doesn’t get along with Objective-C’s optional protocol methods.
      await withCheckedContinuation { (continuation: CheckedContinuation<
        FIRAppCheckTokenResultInterop?,
        Never
      >) in
        guard
          useLimitedUseAppCheckTokens,
          // `getLimitedUseToken(completion:)` is an optional protocol method. Optional binding
          // is performed to make sure `continuation` is called even if the method’s not implemented.
          let limitedUseTokenClosure = appCheck.getLimitedUseToken
        else {
          return continuation.resume(returning: nil)
        }
        limitedUseTokenClosure { tokenResult in
          // The placeholder token should be used in the case of App Check error.
          continuation.resume(returning: tokenResult)
        }
      }
    }

    /// Casts `urlResponse` to `HTTPURLResponse`.
    ///
    /// - Throws: `URLError(.badServerResponse)` if the response is not an HTTP response.
    private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
      // The following condition should always be true: "Whenever you make HTTP URL load requests, any
      // response objects you get back from the URLSession, NSURLConnection, or NSURLDownload class
      // are instances of the HTTPURLResponse class."
      guard let response = urlResponse as? HTTPURLResponse else {
        AILog.error(
          code: .generativeAIServiceNonHTTPResponse,
          "Response wasn't an HTTP response, internal error \(urlResponse)"
        )
        throw URLError(
          .badServerResponse,
          userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
        )
      }
      return response
    }

    /// Encodes `jsonText` as UTF-8 bytes for decoding.
    ///
    /// - Throws: `DecodingError.dataCorrupted` if the text cannot be encoded as UTF-8.
    private func jsonData(jsonText: String) throws -> Data {
      guard let data = jsonText.data(using: .utf8) else {
        throw DecodingError.dataCorrupted(DecodingError.Context(
          codingPath: [],
          debugDescription: "Could not parse response as UTF8."
        ))
      }
      return data
    }

    /// Converts a textual error body into the error to surface to callers.
    private func parseError(responseBody: String) -> Error {
      do {
        let data = try jsonData(jsonText: responseBody)
        return parseError(responseData: data)
      } catch {
        return error
      }
    }

    /// Decodes an error body as `BackendError`, logging known unrecoverable RPC errors.
    ///
    /// Returns the decoding error itself if the body is not a recognized `BackendError`.
    private func parseError(responseData: Data) -> Error {
      do {
        let rpcError = try JSONDecoder().decode(BackendError.self, from: responseData)
        logRPCError(rpcError)
        return rpcError
      } catch {
        // TODO: Return an error about an unrecognized error payload with the response body
        return error
      }
    }

    // Log specific RPC errors that cannot be mitigated or handled by user code.
    // These errors do not produce specific GenerateContentError or CountTokensError cases.
    private func logRPCError(_ error: BackendError) {
      let projectID = firebaseInfo.projectID
      if error.isVertexAIInFirebaseServiceDisabledError() {
        AILog.error(code: .vertexAIInFirebaseAPIDisabled, """
        The Firebase AI SDK requires the Firebase AI API \
        (`firebasevertexai.googleapis.com`) to be enabled in your Firebase project. Enable this API \
        by visiting the Firebase Console at
        https://console.firebase.google.com/project/\(projectID)/genai/ and clicking "Get started". \
        If you enabled this API recently, wait a few minutes for the action to propagate to our \
        systems and then retry.
        """)
      }
    }

    /// Decodes `data` as `type`, logging the raw JSON and the decoding error on failure.
    ///
    /// - Throws: The `JSONDecoder` error, rethrown unchanged.
    private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
      do {
        return try JSONDecoder().decode(type, from: data)
      } catch {
        if let json = String(data: data, encoding: .utf8) {
          AILog.error(code: .loadRequestParseResponseFailedJSON, "JSON response: \(json)")
        }
        AILog.error(
          code: .loadRequestParseResponseFailedJSONError,
          "Error decoding server JSON: \(error)"
        )
        throw error
      }
    }

    #if DEBUG
      /// Wraps `text` in single quotes for a POSIX shell, escaping embedded single quotes.
      private func shellQuoted(_ text: String) -> String {
        "'" + text.replacingOccurrences(of: "'", with: "'\\''") + "'"
      }

      /// Renders `request` as an equivalent `curl` command line for debugging.
      ///
      /// Each header, the URL and the UTF-8 body are shell-quoted. Returns an empty string when
      /// the request has no URL or no UTF-8 body.
      private func cURLCommand(from request: URLRequest) -> String {
        guard let url = request.url,
              let body = request.httpBody,
              let jsonString = String(bytes: body, encoding: .utf8) else {
          return ""
        }
        // Sorted so the output is stable across runs; dictionary iteration order is unspecified.
        let headerArguments = (request.allHTTPHeaderFields ?? [:])
          .sorted { $0.key < $1.key }
          .map { header in "-H " + shellQuoted(header.key + ": " + header.value) + " " }
          .joined()
        let urlArgument = shellQuoted(url.absoluteString) + " "
        return "curl " + headerArguments + urlArgument + "-d " + shellQuoted(jsonString)
      }

      /// Logs the cURL equivalent of `request` when additional logging is enabled.
      private func printCURLCommand(from request: URLRequest) {
        guard AILog.additionalLoggingEnabled() else {
          return
        }
        let command = cURLCommand(from: request)
        os_log(.debug, log: AILog.logObject, """
        \(AILog.service) Creating request with the equivalent cURL command:
        ----- cURL command -----
        \(command, privacy: .private)
        ------------------------
        """)
      }
    #endif // DEBUG
  }