GenerativeAIService.swift 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import Foundation
  18. import os.log
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. struct GenerativeAIService {
  /// SDK language tag, formatted as `gl-<language>/<version>`.
  static let languageTag = "gl-swift/5"

  /// Firebase SDK version tag, formatted as `fire/<version>`.
  static let firebaseVersionTag = "fire/\(FirebaseVersion())"

  /// Firebase configuration; supplies the API key, project ID, and the optional App Check and
  /// Auth providers consulted while building requests.
  private let firebaseInfo: FirebaseInfo

  /// Session on which every network request issued by this service is performed.
  private let urlSession: URLSession

  /// Creates a service bound to the given Firebase configuration and URL session.
  ///
  /// - Parameters:
  ///   - firebaseInfo: Firebase configuration used to authenticate outgoing requests.
  ///   - urlSession: Session used to perform the network calls.
  init(firebaseInfo: FirebaseInfo, urlSession: URLSession) {
    self.urlSession = urlSession
    self.firebaseInfo = firebaseInfo
  }
  /// Sends `request` to the backend and decodes the complete (non-streaming) response.
  ///
  /// - Parameter request: The request to encode and send.
  /// - Returns: The decoded `T.Response`.
  /// - Throws: Any error raised while building the request or performing the network call, the
  ///   error parsed from the payload when the status code is not 200, or the decoding error when
  ///   a 200 response cannot be decoded.
  func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
    let urlRequest = try await urlRequest(request: request)

    #if DEBUG
      printCURLCommand(from: urlRequest)
    #endif

    let (payload, rawResponse) = try await urlSession.data(for: urlRequest)
    let serverResponse = try httpResponse(urlResponse: rawResponse)

    // Anything other than 200 is a failure: log what was received, then surface the parsed error.
    if serverResponse.statusCode != 200 {
      VertexLog.error(
        code: .loadRequestResponseError,
        "The server responded with an error: \(serverResponse)"
      )
      if let payloadText = String(data: payload, encoding: .utf8) {
        VertexLog.error(
          code: .loadRequestResponseErrorPayload,
          "Response payload: \(payloadText)"
        )
      }
      throw parseError(responseData: payload)
    }

    return try parseResponse(T.Response.self, from: payload)
  }
  /// Sends `request` to the backend and streams the decoded responses as they arrive.
  ///
  /// Each received line prefixed with `data:` is decoded into a `T.Response` and yielded. Lines
  /// without that prefix are accumulated and, once the stream ends, reported as an error payload.
  ///
  /// Every failure finishes the returned stream with the error. This includes errors thrown
  /// while iterating the response lines (for example, a connection dropped mid-stream), so a
  /// consumer is never left waiting on a stream that will not terminate. When the returned stream
  /// terminates (finished, or the consumer stops iterating / is cancelled), the underlying task
  /// is cancelled so the network transfer does not keep running unobserved.
  ///
  /// - Parameter request: The request to encode and send.
  /// - Returns: A stream of decoded responses that finishes normally at the end of the server
  ///   stream, or finishes by throwing the first error encountered.
  @available(macOS 12.0, *)
  func loadRequestStream<T: GenerativeAIRequest>(request: T)
    -> AsyncThrowingStream<T.Response, Error> where T: Sendable {
    return AsyncThrowingStream { continuation in
      let task = Task {
        do {
          let urlRequest = try await self.urlRequest(request: request)

          #if DEBUG
            printCURLCommand(from: urlRequest)
          #endif

          let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)

          // The response must be an HTTP response before its status code can be inspected.
          let response = try httpResponse(urlResponse: rawResponse)

          // Verify the status code is 200
          guard response.statusCode == 200 else {
            VertexLog.error(
              code: .loadRequestStreamResponseError,
              "The server responded with an error: \(response)"
            )
            // Drain the body so the error payload can be logged and parsed.
            var responseBody = ""
            for try await line in stream.lines {
              responseBody += line + "\n"
            }
            VertexLog.error(
              code: .loadRequestStreamResponseErrorPayload,
              "Response payload: \(responseBody)"
            )
            throw parseError(responseBody: responseBody)
          }

          // Received lines that are not server-sent events (SSE); these are not prefixed with
          // "data:"
          var extraLines = ""

          for try await line in stream.lines {
            VertexLog.debug(code: .loadRequestStreamResponseLine, "Stream response: \(line)")

            if line.hasPrefix("data:") {
              // We can assume 5 characters since it's utf-8 encoded, removing `data:`.
              let jsonText = String(line.dropFirst(5))
              let data = try jsonData(jsonText: jsonText)

              // Handle the content.
              let content = try parseResponse(T.Response.self, from: data)
              continuation.yield(content)
            } else {
              extraLines += line
            }
          }

          guard extraLines.isEmpty else {
            throw parseError(responseBody: extraLines)
          }

          continuation.finish()
        } catch {
          // Single exit point for failures: the stream always terminates with the error.
          continuation.finish(throwing: error)
        }
      }

      // Stop the network work once nobody is consuming the stream any more. Cancelling after a
      // regular finish is harmless because the task has already completed its work.
      continuation.onTermination = { _ in
        task.cancel()
      }
    }
  }
  140. // MARK: - Private Helpers
  /// Builds the `POST` request for `request`: API key and client headers, optional App Check
  /// and Auth tokens, the JSON-encoded body, and the timeout from the request options.
  ///
  /// - Parameter request: The request whose `url`, `options.timeout`, and encoded form are used.
  /// - Returns: The fully configured `URLRequest`.
  /// - Throws: Errors from fetching the Auth token or from JSON-encoding `request`.
  private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
    var httpRequest = URLRequest(url: request.url)
    httpRequest.httpMethod = "POST"
    httpRequest.timeoutInterval = request.options.timeout

    let clientTag = "\(Self.languageTag) \(Self.firebaseVersionTag)"
    httpRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
    httpRequest.setValue(clientTag, forHTTPHeaderField: "x-goog-api-client")
    httpRequest.setValue(firebaseInfo.apiKey, forHTTPHeaderField: "x-goog-api-key")

    if let appCheck = firebaseInfo.appCheck {
      let appCheckResult = await appCheck.getToken(forcingRefresh: false)
      // The returned token is attached even when an error is reported; the error is only logged.
      httpRequest.setValue(appCheckResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
      if let fetchError = appCheckResult.error {
        VertexLog.error(
          code: .appCheckTokenFetchFailed,
          "Failed to fetch AppCheck token. Error: \(fetchError)"
        )
      }
    }

    if let auth = firebaseInfo.auth {
      if let authToken = try await auth.getToken(forcingRefresh: false) {
        httpRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
      }
    }

    // TODO: wait for release approval.
    // if firebaseInfo.app.isDataCollectionDefaultEnabled {
    //   urlRequest.setValue(firebaseInfo.googleAppID, forHTTPHeaderField: "X-Firebase-AppId")
    // }

    // Encoding stays last so token fetching (and its logging) happens before any encoding error.
    httpRequest.httpBody = try JSONEncoder().encode(request)

    return httpRequest
  }
  /// Casts a generic `URLResponse` to `HTTPURLResponse`.
  ///
  /// - Parameter urlResponse: The response returned by the URL session.
  /// - Returns: `urlResponse` as an `HTTPURLResponse`.
  /// - Throws: `URLError(.badServerResponse)` when the response is not an HTTP response.
  private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
    // The cast is expected to always succeed: "Whenever you make HTTP URL load requests, any
    // response objects you get back from the URLSession, NSURLConnection, or NSURLDownload class
    // are instances of the HTTPURLResponse class."
    if let httpURLResponse = urlResponse as? HTTPURLResponse {
      return httpURLResponse
    }

    VertexLog.error(
      code: .generativeAIServiceNonHTTPResponse,
      "Response wasn't an HTTP response, internal error \(urlResponse)"
    )
    let userInfo = [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
    throw URLError(.badServerResponse, userInfo: userInfo)
  }
  /// Encodes `jsonText` as UTF-8 bytes.
  ///
  /// - Parameter jsonText: The text to encode.
  /// - Returns: The UTF-8 representation of `jsonText`.
  /// - Throws: `DecodingError.dataCorrupted` when the text cannot be encoded as UTF-8.
  private func jsonData(jsonText: String) throws -> Data {
    if let encoded = jsonText.data(using: .utf8) {
      return encoded
    }
    let context = DecodingError.Context(
      codingPath: [],
      debugDescription: "Could not parse response as UTF8."
    )
    throw DecodingError.dataCorrupted(context)
  }
  /// Converts a textual error payload into an `Error`.
  ///
  /// - Parameter responseBody: The raw response body text.
  /// - Returns: The result of `parseError(responseData:)` for the UTF-8 encoded body, or the
  ///   encoding error when the body cannot be converted to data.
  private func parseError(responseBody: String) -> Error {
    switch Result(catching: { try jsonData(jsonText: responseBody) }) {
    case let .success(payload):
      return parseError(responseData: payload)
    case let .failure(encodingError):
      return encodingError
    }
  }
  /// Decodes an error payload into a `BackendError`, logging it when it is a known RPC error.
  ///
  /// - Parameter responseData: The raw error payload.
  /// - Returns: The decoded `BackendError`, or the decoding error when the payload does not match.
  private func parseError(responseData: Data) -> Error {
    let rpcError: BackendError
    do {
      rpcError = try JSONDecoder().decode(BackendError.self, from: responseData)
    } catch {
      // TODO: Return an error about an unrecognized error payload with the response body
      return error
    }

    logRPCError(rpcError)
    return rpcError
  }
  /// Logs RPC errors that user code cannot mitigate or handle.
  ///
  /// These errors do not produce specific GenerateContentError or CountTokensError cases. The
  /// only case logged here is the Vertex AI in Firebase API being disabled for the project.
  ///
  /// - Parameter error: The backend error to inspect.
  private func logRPCError(_ error: BackendError) {
    let consoleProjectID = firebaseInfo.projectID

    guard error.isVertexAIInFirebaseServiceDisabledError() else {
      return
    }

    VertexLog.error(code: .vertexAIInFirebaseAPIDisabled, """
    The Vertex AI in Firebase SDK requires the Vertex AI in Firebase API \
    (`firebasevertexai.googleapis.com`) to be enabled in your Firebase project. Enable this API \
    by visiting the Firebase Console at
    https://console.firebase.google.com/project/\(consoleProjectID)/genai/ and clicking "Get started". \
    If you enabled this API recently, wait a few minutes for the action to propagate to our \
    systems and then retry.
    """)
  }
  /// Decodes `data` as JSON into `type`, logging the raw JSON and the error on failure.
  ///
  /// - Parameters:
  ///   - type: The `Decodable` type to produce.
  ///   - data: The JSON payload.
  /// - Returns: The decoded value.
  /// - Throws: The error raised by `JSONDecoder` when decoding fails.
  private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
    let outcome = Result { try JSONDecoder().decode(type, from: data) }

    if case let .failure(decodingFailure) = outcome {
      if let rawJSON = String(data: data, encoding: .utf8) {
        VertexLog.error(code: .loadRequestParseResponseFailedJSON, "JSON response: \(rawJSON)")
      }
      VertexLog.error(
        code: .loadRequestParseResponseFailedJSONError,
        "Error decoding server JSON: \(decodingFailure)"
      )
    }

    // Returns the decoded value, or rethrows the decoding error that was just logged.
    return try outcome.get()
  }
  246. #if DEBUG
  /// Renders `request` as an equivalent `curl` command line.
  ///
  /// - Parameter request: The request to render.
  /// - Returns: The command text, or an empty string when the request has no URL, no body, or a
  ///   body that is not valid UTF-8.
  private func cURLCommand(from request: URLRequest) -> String {
    guard let url = request.url,
          let body = request.httpBody,
          let bodyText = String(bytes: body, encoding: .utf8) else {
      return ""
    }

    // One `-H` flag per header, each followed by a separating space.
    let headerFlags = (request.allHTTPHeaderFields ?? [:])
      .map { "-H '\($0.key): \($0.value)' " }
      .joined()

    // Close the quote, emit an escaped quote, and reopen, so the body stays single-quoted.
    let quotedBody = bodyText.replacingOccurrences(of: "'", with: "'\\''")

    return "curl " + headerFlags + "'\(url.absoluteString)' " + "-d '\(quotedBody)'"
  }
  /// Logs the cURL equivalent of `request` at debug level when additional logging is enabled.
  ///
  /// The command itself is interpolated with `.private` privacy.
  ///
  /// - Parameter request: The request to describe.
  private func printCURLCommand(from request: URLRequest) {
    if !VertexLog.additionalLoggingEnabled() {
      return
    }

    let curlText = cURLCommand(from: request)
    os_log(.debug, log: VertexLog.logObject, """
    \(VertexLog.service) Creating request with the equivalent cURL command:
    ----- cURL command -----
    \(curlText, privacy: .private)
    ------------------------
    """)
  }
  274. #endif // DEBUG
  275. }