GenerativeAIService.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import FirebaseSharedSwift
  18. import Foundation
  19. import os.log
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  struct GenerativeAIService {
    /// The language of the SDK in the format `gl-<language>/<version>`.
    static let languageTag = "gl-swift/5"

    /// The Firebase SDK version in the format `fire/<version>`.
    static let firebaseVersionTag = "fire/\(FirebaseVersion())"

    /// The Firebase project ID; interpolated into Console links in diagnostic log messages.
    private let projectID: String

    /// Gives permission to talk to the backend.
    private let apiKey: String

    /// Optional App Check provider; when present, its token is attached to every request.
    private let appCheck: AppCheckInterop?

    /// Optional Auth provider; when present and it returns a token, that token is attached.
    private let auth: AuthInterop?

    /// The session used to perform all network requests.
    private let urlSession: URLSession

    init(projectID: String, apiKey: String, appCheck: AppCheckInterop?, auth: AuthInterop?,
         urlSession: URLSession) {
      self.projectID = projectID
      self.apiKey = apiKey
      self.appCheck = appCheck
      self.auth = auth
      self.urlSession = urlSession
    }

    /// Sends `request` to the backend and decodes its single (non-streaming) response.
    ///
    /// - Parameter request: The request to encode and send.
    /// - Returns: The decoded `T.Response`.
    /// - Throws: Errors from building or sending the request; an error if the response is not an
    ///   HTTP response; the error parsed from the body when the status code is not 200; or a
    ///   decoding error when the body cannot be decoded as `T.Response`.
    func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
      let urlRequest = try await self.urlRequest(request: request)

      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (data, rawResponse) = try await urlSession.data(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Any status other than 200 is a failure: log it and surface the server's error payload.
      guard response.statusCode == 200 else {
        FirebaseLogger.vertexAI.error(
          messageID: LogMessageID.loadRequestError.rawValue,
          "The server responded with an error: \(response)"
        )
        if let responseString = String(data: data, encoding: .utf8) {
          FirebaseLogger.vertexAI.error(
            messageID: LogMessageID.loadRequestErrorResponse.rawValue,
            "Response payload: \(responseString)"
          )
        }
        throw parseError(responseData: data)
      }

      return try parseResponse(T.Response.self, from: data)
    }

    /// Sends `request` to the backend and streams the decoded responses as they arrive.
    ///
    /// The body is read line by line. Every line prefixed with `data:` is decoded as a
    /// `T.Response` and yielded. The stream finishes with an error in these cases:
    /// - building or sending the request fails;
    /// - the status code is not 200;
    /// - reading the body fails part-way through;
    /// - a `data:` line fails to decode;
    /// - the body contained lines that were not `data:` lines.
    ///
    /// When the stream terminates (including when the consumer stops iterating), the underlying
    /// task is cancelled so the network read does not keep running in the background.
    @available(macOS 12.0, *)
    func loadRequestStream<T: GenerativeAIRequest>(request: T)
      -> AsyncThrowingStream<T.Response, Error> {
      return AsyncThrowingStream { continuation in
        let task = Task {
          do {
            try await streamResponses(for: request, into: continuation)
            continuation.finish()
          } catch {
            // Every failure, including errors thrown while iterating the byte stream, must reach
            // the consumer. An error escaping an unstructured `Task` is silently dropped, which
            // would leave the stream unfinished and the consumer waiting forever.
            continuation.finish(throwing: error)
          }
        }
        continuation.onTermination = { @Sendable _ in
          task.cancel()
        }
      }
    }

    // MARK: - Private Helpers

    /// Performs the streaming request and yields each decoded `data:` line to `continuation`.
    ///
    /// Returns normally once the whole body has been consumed. The caller is responsible for
    /// finishing `continuation`.
    ///
    /// - Throws: Errors from building or sending the request or reading the body; an error if the
    ///   response is not an HTTP response; the error parsed from the body when the status code is
    ///   not 200; decoding errors for `data:` lines; or the error parsed from any non-`data:` lines.
    @available(macOS 12.0, *)
    private func streamResponses<T: GenerativeAIRequest>(
      for request: T,
      into continuation: AsyncThrowingStream<T.Response, Error>.Continuation
    ) async throws {
      let urlRequest = try await self.urlRequest(request: request)

      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Any status other than 200 is a failure: drain the body so it can be logged and parsed.
      guard response.statusCode == 200 else {
        FirebaseLogger.vertexAI.error(
          messageID: LogMessageID.loadRequestStreamError.rawValue,
          "The server responded with an error: \(response)"
        )
        var responseBody = ""
        for try await line in stream.lines {
          responseBody += line + "\n"
        }
        FirebaseLogger.vertexAI.error(
          messageID: LogMessageID.loadRequestStreamErrorResponse.rawValue,
          "Response payload: \(responseBody)"
        )
        throw parseError(responseBody: responseBody)
      }

      // Received lines that are not server-sent events (SSE); these are not prefixed with "data:".
      var extraLines = ""
      for try await line in stream.lines {
        FirebaseLogger.vertexAI.debug(
          messageID: LogMessageID.loadRequestStreamLine.rawValue,
          "Stream response: \(line)"
        )
        guard line.hasPrefix("data:") else {
          extraLines += line
          continue
        }
        // "data:" is five ASCII characters, so dropping five `Character`s removes exactly it.
        let jsonText = String(line.dropFirst(5))
        let data = try jsonData(jsonText: jsonText)
        let content = try parseResponse(T.Response.self, from: data)
        continuation.yield(content)
      }

      if !extraLines.isEmpty {
        throw parseError(responseBody: extraLines)
      }
    }

    /// Builds the `POST` request for `request`.
    ///
    /// Attaches the API key, client, and content-type headers. Attaches the App Check and Auth
    /// headers when those providers are configured. Sets the snake_case JSON body and any
    /// per-request timeout.
    ///
    /// - Throws: Errors from fetching the Auth token or from encoding `request`.
    private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
      var urlRequest = URLRequest(url: request.url)
      urlRequest.httpMethod = "POST"
      urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
      urlRequest.setValue(
        "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
        forHTTPHeaderField: "x-goog-api-client"
      )
      urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

      if let appCheck {
        let tokenResult = await appCheck.getToken(forcingRefresh: false)
        // The returned token is attached even when `error` is set; the failure is only logged.
        urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
        if let error = tokenResult.error {
          FirebaseLogger.vertexAI.debug(
            messageID: LogMessageID.appCheckTokenFetchError.rawValue,
            "Failed to fetch AppCheck token. Error: \(error)"
          )
        }
      }

      if let auth, let authToken = try await auth.getToken(forcingRefresh: false) {
        urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
      }

      let encoder = JSONEncoder()
      encoder.keyEncodingStrategy = .convertToSnakeCase
      urlRequest.httpBody = try encoder.encode(request)

      if let timeoutInterval = request.options.timeout {
        urlRequest.timeoutInterval = timeoutInterval
      }

      return urlRequest
    }

    /// Downcasts `urlResponse` to `HTTPURLResponse`.
    ///
    /// - Throws: An `NSError` in the `com.google.generative-ai` domain if the response is not an
    ///   HTTP response.
    private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
      // Status codes can only be inspected on HTTP responses; anything else is an internal error.
      guard let response = urlResponse as? HTTPURLResponse else {
        FirebaseLogger.vertexAI.error(
          messageID: LogMessageID.nonHTTPResponseError.rawValue,
          "Response wasn't an HTTP response, internal error \(urlResponse)"
        )
        throw NSError(
          domain: "com.google.generative-ai",
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
        )
      }
      return response
    }

    /// UTF-8 encodes `jsonText`.
    ///
    /// - Throws: An `NSError` in the `com.google.generative-ai` domain if encoding fails.
    private func jsonData(jsonText: String) throws -> Data {
      guard let data = jsonText.data(using: .utf8) else {
        let error = NSError(
          domain: "com.google.generative-ai",
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Could not parse response as UTF8."]
        )
        throw error
      }
      return data
    }

    /// Converts a textual error body into an `Error`; see `parseError(responseData:)`.
    private func parseError(responseBody: String) -> Error {
      do {
        let data = try jsonData(jsonText: responseBody)
        return parseError(responseData: data)
      } catch {
        return error
      }
    }

    /// Decodes an error body as an `RPCError`, logging it.
    ///
    /// If the body is not a decodable `RPCError`, the decoding error itself is returned instead.
    private func parseError(responseData: Data) -> Error {
      do {
        let rpcError = try JSONDecoder().decode(RPCError.self, from: responseData)
        logRPCError(rpcError)
        return rpcError
      } catch {
        // TODO: Return an error about an unrecognized error payload with the response body
        return error
      }
    }

    // Log specific RPC errors that cannot be mitigated or handled by user code.
    // These errors do not produce specific GenerateContentError or CountTokensError cases.
    private func logRPCError(_ error: RPCError) {
      if error.isFirebaseMLServiceDisabledError() {
        FirebaseLogger.vertexAI.error(messageID: LogMessageID.mlServiceDisabledError.rawValue, """
        The Vertex AI for Firebase SDK requires the Firebase ML API `firebaseml.googleapis.com` to \
        be enabled for your project. Get started in the Firebase Console \
        (https://console.firebase.google.com/project/\(projectID)/genai/vertex) or verify that the \
        API is enabled in the Google Cloud Console \
        (https://console.developers.google.com/apis/api/firebaseml.googleapis.com/overview?project=\
        \(projectID)).
        """)
      }
    }

    /// Decodes `data` as `type` with a default `JSONDecoder`.
    ///
    /// On failure, logs the raw JSON (when it is valid UTF-8) and the decoding error, then
    /// rethrows that error.
    private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
      do {
        return try JSONDecoder().decode(type, from: data)
      } catch {
        if let json = String(data: data, encoding: .utf8) {
          FirebaseLogger.vertexAI.error(
            messageID: LogMessageID.jsonResponseError.rawValue,
            "JSON response: \(json)"
          )
        }
        FirebaseLogger.vertexAI.error(
          messageID: LogMessageID.jsonDecodingError.rawValue,
          "Error decoding server JSON: \(error)"
        )
        throw error
      }
    }

    #if DEBUG
      /// Renders `request` as an equivalent `curl` command line.
      ///
      /// Returns an empty string if the request has no URL or no UTF-8 body.
      private func cURLCommand(from request: URLRequest) -> String {
        // Wraps `text` in single quotes for a POSIX shell, escaping embedded single quotes.
        func shellQuoted(_ text: String) -> String {
          return "'" + text.replacingOccurrences(of: "'", with: "'\\''") + "'"
        }

        var returnValue = "curl "
        if let allHeaders = request.allHTTPHeaderFields {
          for (key, value) in allHeaders {
            returnValue += "-H " + shellQuoted("\(key): \(value)") + " "
          }
        }
        guard let url = request.url else { return "" }
        returnValue += shellQuoted(url.absoluteString) + " "
        guard let body = request.httpBody,
              let jsonStr = String(bytes: body, encoding: .utf8) else { return "" }
        returnValue += "-d " + shellQuoted(jsonStr)
        return returnValue
      }

      /// Logs the `curl` equivalent of `request` at debug level.
      ///
      /// The command is logged with `.private` privacy because it includes the request headers,
      /// such as the API key.
      private func printCURLCommand(from request: URLRequest) {
        let command = cURLCommand(from: request)
        let logLevel = FirebaseLoggerLevel.debug
        let messagePrefix = FirebaseLogger.vertexAI
          .messagePrefix(messageID: LogMessageID.curlCommandForRequest.rawValue)
        if FirebaseLogger.isLoggableLevel(logLevel) {
          FirebaseLogger.vertexAI.osLogger().log(level: FirebaseLogger.osLogType(logLevel), """
          \(messagePrefix, privacy: .public) Creating request with the equivalent cURL command:
          ----- cURL command -----
          \(command, privacy: .private)
          ------------------------
          """)
        }
      }
    #endif // DEBUG
  }