GenerativeAIService.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import Foundation
  18. import os.log
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  struct GenerativeAIService {
    /// The language of the SDK in the format `gl-<language>/<version>`.
    static let languageTag = "gl-swift/5"

    /// The Firebase SDK version in the format `fire/<version>`.
    static let firebaseVersionTag = "fire/\(FirebaseVersion())"

    /// Firebase app metadata (API key, app/project IDs, App Check and Auth providers) used to
    /// populate request headers.
    private let firebaseInfo: FirebaseInfo

    /// Configuration for the backend API used by this model.
    private let apiConfig: APIConfig

    /// Decoder for response payloads, configured from `apiConfig`.
    private let jsonDecoder: JSONDecoder

    /// Encoder for request bodies, configured from `apiConfig`.
    private let jsonEncoder: JSONEncoder

    /// Session used to issue every network request.
    private let urlSession: URLSession

    init(firebaseInfo: FirebaseInfo, apiConfig: APIConfig, urlSession: URLSession) {
      self.firebaseInfo = firebaseInfo
      self.apiConfig = apiConfig
      jsonDecoder = JSONDecoder(apiConfig: apiConfig)
      jsonEncoder = JSONEncoder(apiConfig: apiConfig)
      self.urlSession = urlSession
    }

    /// Sends `request` and decodes the complete (non-streaming) response.
    ///
    /// - Parameter request: The request to encode and send.
    /// - Returns: The decoded `T.Response`.
    /// - Throws: Errors from building the request (e.g. fetching an Auth token or encoding the
    ///   body), transport errors from `URLSession`, the error parsed from the response body when
    ///   the status code is not 200, or a decoding error if a 200 body cannot be decoded.
    func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
      let urlRequest = try await self.urlRequest(request: request)
      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (data, rawResponse) = try await urlSession.data(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Any status other than 200 is treated as an error; the body is parsed to describe it.
      guard response.statusCode == 200 else {
        VertexLog.error(
          code: .loadRequestResponseError,
          "The server responded with an error: \(response)"
        )
        if let responseString = String(data: data, encoding: .utf8) {
          VertexLog.error(
            code: .loadRequestResponseErrorPayload,
            "Response payload: \(responseString)"
          )
        }
        throw parseError(responseData: data)
      }

      return try parseResponse(T.Response.self, from: data)
    }

    /// Sends `request` and streams the decoded server-sent events (SSE) of the response.
    ///
    /// Each `data:` line of the response body is decoded as one `T.Response` and yielded in order.
    /// The stream finishes with an error if:
    /// - the request cannot be built;
    /// - the transport fails, including mid-stream;
    /// - the status code is not 200;
    /// - a chunk fails to decode;
    /// - the body contains non-SSE lines, which are parsed as an error payload.
    ///
    /// Terminating the stream early (for example, when the consumer stops iterating) cancels the
    /// underlying network task.
    @available(macOS 12.0, *)
    func loadRequestStream<T: GenerativeAIRequest>(request: T)
      -> AsyncThrowingStream<T.Response, Error> where T: Sendable {
      return AsyncThrowingStream { continuation in
        let task = Task {
          do {
            try await streamResponses(for: request, into: continuation)
            continuation.finish()
          } catch {
            // Every failure, including errors thrown while reading the byte stream, must reach the
            // consumer. An error escaping this `Task` would be discarded with the Task's result and
            // the stream would never be finished.
            continuation.finish(throwing: error)
          }
        }
        // Stop the network work when the consumer cancels or abandons iteration.
        continuation.onTermination = { _ in task.cancel() }
      }
    }

    // MARK: - Private Helpers

    /// Performs the streaming request and yields each decoded SSE chunk to `continuation`.
    ///
    /// This function does not finish `continuation`. The caller finishes it, either normally or
    /// with the error thrown from here.
    private func streamResponses<T: GenerativeAIRequest>(
      for request: T,
      into continuation: AsyncThrowingStream<T.Response, Error>.Continuation
    ) async throws {
      let urlRequest = try await self.urlRequest(request: request)
      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)
      // Ensure we received an HTTP response before inspecting the status code.
      let response = try httpResponse(urlResponse: rawResponse)

      // Verify the status code is 200; otherwise read the whole body and parse it as an error.
      guard response.statusCode == 200 else {
        VertexLog.error(
          code: .loadRequestStreamResponseError,
          "The server responded with an error: \(response)"
        )
        var responseBody = ""
        for try await line in stream.lines {
          responseBody += line + "\n"
        }
        VertexLog.error(
          code: .loadRequestStreamResponseErrorPayload,
          "Response payload: \(responseBody)"
        )
        throw parseError(responseBody: responseBody)
      }

      // Received lines that are not server-sent events (not prefixed with "data:"). If any are
      // present, they are parsed as an error payload once the stream ends.
      var extraLines: [String] = []
      for try await line in stream.lines {
        VertexLog.debug(code: .loadRequestStreamResponseLine, "Stream response: \(line)")
        if line.hasPrefix("data:") {
          // "data:" is 5 ASCII characters, so dropping 5 `Character`s removes exactly the prefix.
          let jsonText = String(line.dropFirst(5))
          let data = try jsonData(jsonText: jsonText)
          let content = try parseResponse(T.Response.self, from: data)
          continuation.yield(content)
        } else if !line.isEmpty {
          extraLines.append(line)
        }
      }

      if !extraLines.isEmpty {
        // Rejoin with newlines so tokens on adjacent lines are not fused before JSON decoding.
        throw parseError(responseBody: extraLines.joined(separator: "\n"))
      }
    }

    /// Builds the HTTP `POST` request for `request`.
    ///
    /// The request carries:
    /// - the API key, client tag, and JSON content-type headers;
    /// - an App Check token header when App Check is configured;
    /// - an Auth header when Auth returns a token;
    /// - app ID and version headers when data collection is enabled;
    /// - the JSON-encoded body and the request's timeout.
    ///
    /// - Throws: Errors from fetching the Auth token or from encoding the request body.
    private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
      var urlRequest = URLRequest(url: request.requestURL(apiConfig: apiConfig))
      urlRequest.httpMethod = "POST"
      urlRequest.setValue(firebaseInfo.apiKey, forHTTPHeaderField: "x-goog-api-key")
      urlRequest.setValue(
        "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
        forHTTPHeaderField: "x-goog-api-client"
      )
      urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

      if let appCheck = firebaseInfo.appCheck {
        let tokenResult = await appCheck.getToken(forcingRefresh: false)
        // The token is attached even when `tokenResult.error` is set.
        // NOTE(review): presumably App Check returns a placeholder token on failure; confirm.
        urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
        if let error = tokenResult.error {
          VertexLog.error(
            code: .appCheckTokenFetchFailed,
            "Failed to fetch AppCheck token. Error: \(error)"
          )
        }
      }

      if let auth = firebaseInfo.auth, let authToken = try await auth.getToken(
        forcingRefresh: false
      ) {
        urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
      }

      if firebaseInfo.app.isDataCollectionDefaultEnabled {
        urlRequest.setValue(firebaseInfo.firebaseAppID, forHTTPHeaderField: "X-Firebase-AppId")
        if let appVersion = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String {
          urlRequest.setValue(appVersion, forHTTPHeaderField: "X-Firebase-AppVersion")
        }
      }

      urlRequest.httpBody = try jsonEncoder.encode(request)
      urlRequest.timeoutInterval = request.options.timeout
      return urlRequest
    }

    /// Casts `urlResponse` to `HTTPURLResponse`.
    ///
    /// - Throws: `URLError(.badServerResponse)` if the response is not an HTTP response.
    private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
      // The following condition should always be true: "Whenever you make HTTP URL load requests, any
      // response objects you get back from the URLSession, NSURLConnection, or NSURLDownload class
      // are instances of the HTTPURLResponse class."
      guard let response = urlResponse as? HTTPURLResponse else {
        VertexLog.error(
          code: .generativeAIServiceNonHTTPResponse,
          "Response wasn't an HTTP response, internal error \(urlResponse)"
        )
        throw URLError(
          .badServerResponse,
          userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
        )
      }
      return response
    }

    /// Encodes `jsonText` as UTF-8 data.
    ///
    /// - Throws: `DecodingError.dataCorrupted` if the text cannot be encoded as UTF-8.
    private func jsonData(jsonText: String) throws -> Data {
      guard let data = jsonText.data(using: .utf8) else {
        throw DecodingError.dataCorrupted(DecodingError.Context(
          codingPath: [],
          debugDescription: "Could not parse response as UTF8."
        ))
      }
      return data
    }

    /// Converts an error response body (as text) into an `Error` to surface to callers.
    private func parseError(responseBody: String) -> Error {
      do {
        let data = try jsonData(jsonText: responseBody)
        return parseError(responseData: data)
      } catch {
        return error
      }
    }

    /// Decodes an error response body into a `BackendError`.
    ///
    /// If the body is not a recognizable `BackendError`, the decoding error itself is returned.
    private func parseError(responseData: Data) -> Error {
      do {
        let rpcError = try jsonDecoder.decode(BackendError.self, from: responseData)
        logRPCError(rpcError)
        return rpcError
      } catch {
        // TODO: Return an error about an unrecognized error payload with the response body
        return error
      }
    }

    // Log specific RPC errors that cannot be mitigated or handled by user code.
    // These errors do not produce specific GenerateContentError or CountTokensError cases.
    private func logRPCError(_ error: BackendError) {
      let projectID = firebaseInfo.projectID
      if error.isVertexAIInFirebaseServiceDisabledError() {
        VertexLog.error(code: .vertexAIInFirebaseAPIDisabled, """
        The Vertex AI in Firebase SDK requires the Vertex AI in Firebase API \
        (`firebasevertexai.googleapis.com`) to be enabled in your Firebase project. Enable this API \
        by visiting the Firebase Console at
        https://console.firebase.google.com/project/\(projectID)/genai/ and clicking "Get started". \
        If you enabled this API recently, wait a few minutes for the action to propagate to our \
        systems and then retry.
        """)
      }
    }

    /// Decodes `data` as `type`, logging the raw JSON and the decoding error on failure.
    ///
    /// - Throws: The error thrown by `jsonDecoder`, rethrown unchanged after logging.
    private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
      do {
        return try jsonDecoder.decode(type, from: data)
      } catch {
        if let json = String(data: data, encoding: .utf8) {
          VertexLog.error(code: .loadRequestParseResponseFailedJSON, "JSON response: \(json)")
        }
        VertexLog.error(
          code: .loadRequestParseResponseFailedJSONError,
          "Error decoding server JSON: \(error)"
        )
        throw error
      }
    }

    #if DEBUG
      /// Builds a shell `curl` command equivalent to `request`, for debugging.
      ///
      /// Returns an empty string if the request has no URL or no UTF-8 body.
      private func cURLCommand(from request: URLRequest) -> String {
        guard let url = request.url else { return "" }
        guard let body = request.httpBody,
              let jsonStr = String(bytes: body, encoding: .utf8) else { return "" }

        var returnValue = "curl "
        for (key, value) in request.allHTTPHeaderFields ?? [:] {
          let header = "\(key): \(value)"
          returnValue += "-H \(shellQuoted(header)) "
        }
        returnValue += "\(shellQuoted(url.absoluteString)) "
        returnValue += "-d \(shellQuoted(jsonStr))"
        return returnValue
      }

      /// Wraps `text` in single quotes for a POSIX shell, escaping embedded single quotes.
      private func shellQuoted(_ text: String) -> String {
        return "'" + text.replacingOccurrences(of: "'", with: "'\\''") + "'"
      }

      /// Logs the equivalent cURL command for `request` when additional logging is enabled.
      private func printCURLCommand(from request: URLRequest) {
        guard VertexLog.additionalLoggingEnabled() else {
          return
        }
        let command = cURLCommand(from: request)
        os_log(.debug, log: VertexLog.logObject, """
        \(VertexLog.service) Creating request with the equivalent cURL command:
        ----- cURL command -----
        \(command, privacy: .private)
        ------------------------
        """)
      }
    #endif // DEBUG
  }