GenerativeAIService.swift 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import Foundation
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
  struct GenerativeAIService {
    /// The language of the SDK in the format `gl-<language>/<version>`.
    static let languageTag = "gl-swift/5"

    /// The Firebase SDK version in the format `fire/<version>`.
    static let firebaseVersionTag = "fire/\(FirebaseVersion())"

    /// Gives permission to talk to the backend; sent as the `x-goog-api-key` header.
    private let apiKey: String

    /// Optional App Check provider; when present, its token is sent as `X-Firebase-AppCheck`.
    private let appCheck: AppCheckInterop?

    /// Optional Auth provider; when it yields a token, it is sent as
    /// `Authorization: Firebase <token>`.
    private let auth: AuthInterop?

    /// The session used to perform every request.
    private let urlSession: URLSession

    init(apiKey: String, appCheck: AppCheckInterop?, auth: AuthInterop?, urlSession: URLSession) {
      self.apiKey = apiKey
      self.appCheck = appCheck
      self.auth = auth
      self.urlSession = urlSession
    }

    /// Sends `request` and decodes a single (non-streaming) response.
    ///
    /// - Parameter request: The request, encoded as the JSON body of a `POST`.
    /// - Returns: The decoded `T.Response`.
    /// - Throws: Errors from building the request or from `URLSession`; an `NSError` when the
    ///   response is not an HTTP response; the result of `parseError(responseData:)` for any
    ///   status other than 200; or a decoding error when the body is not a valid `T.Response`.
    func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
      let urlRequest = try await urlRequest(request: request)

      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (data, rawResponse) = try await urlSession.data(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Anything other than 200 is an error; the body is decoded as an `RPCError` if possible.
      guard response.statusCode == 200 else {
        Logging.default.error("[FirebaseVertexAI] The server responded with an error: \(response)")
        if let responseString = String(data: data, encoding: .utf8) {
          Logging.network.error("[FirebaseVertexAI] Response payload: \(responseString)")
        }
        throw parseError(responseData: data)
      }

      return try parseResponse(T.Response.self, from: data)
    }

    /// Sends `request` and streams the decoded server-sent events (SSE) of the response.
    ///
    /// Each line prefixed with `data:` is decoded as a `T.Response` and yielded. The stream
    /// finishes with an error if the request cannot be built or sent, the status is not 200,
    /// reading the body fails, a chunk fails to decode, or the body contains non-empty lines
    /// that are not prefixed with `data:`. When the stream terminates (including when the
    /// consumer stops iterating), the underlying network task is cancelled.
    @available(macOS 12.0, *)
    func loadRequestStream<T: GenerativeAIRequest>(request: T)
      -> AsyncThrowingStream<T.Response, Error> {
      return AsyncThrowingStream { continuation in
        let task = Task {
          do {
            try await streamResponses(request: request, into: continuation)
            continuation.finish()
          } catch {
            // Every failure must finish the stream: an error escaping this `Task` would
            // otherwise be discarded and the consumer would never see the stream end.
            continuation.finish(throwing: error)
          }
        }
        // Stop the network work once the stream is terminated (finished or cancelled).
        continuation.onTermination = { _ in
          task.cancel()
        }
      }
    }

    // MARK: - Private Helpers

    /// Performs the streaming request and yields each decoded SSE `data:` chunk to `continuation`.
    ///
    /// Does not finish `continuation`; the caller finishes it depending on whether this
    /// returns normally or throws.
    @available(macOS 12.0, *)
    private func streamResponses<T: GenerativeAIRequest>(
      request: T,
      into continuation: AsyncThrowingStream<T.Response, Error>.Continuation
    ) async throws {
      let urlRequest = try await self.urlRequest(request: request)

      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Anything other than 200 is an error; collect the whole body to parse it as one.
      guard response.statusCode == 200 else {
        Logging.default
          .error("[FirebaseVertexAI] The server responded with an error: \(response)")
        var responseBody = ""
        for try await line in stream.lines {
          responseBody += line + "\n"
        }
        Logging.network.error("[FirebaseVertexAI] Response payload: \(responseBody)")
        throw parseError(responseBody: responseBody)
      }

      // Received lines that are not server-sent events (not prefixed with "data:"); if any are
      // collected, they are treated as an error payload once the body has been consumed.
      var extraLines = ""
      for try await line in stream.lines {
        Logging.network.debug("[FirebaseVertexAI] Stream response: \(line)")
        if line.hasPrefix("data:") {
          // "data:" is 5 ASCII characters, so dropping 5 `Character`s removes exactly the prefix.
          let jsonText = String(line.dropFirst(5))
          let data = try jsonData(jsonText: jsonText)
          let content = try parseResponse(T.Response.self, from: data)
          continuation.yield(content)
        } else if !line.isEmpty {
          // Keep line breaks so the collected payload stays readable; empty lines add nothing.
          extraLines += line + "\n"
        }
      }

      if !extraLines.isEmpty {
        throw parseError(responseBody: extraLines)
      }
    }

    /// Builds the `POST` request for `request`: API key, client tags, JSON content type,
    /// optional App Check / Auth headers, the snake_case-encoded body, and an optional timeout.
    private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
      var urlRequest = URLRequest(url: request.url)
      urlRequest.httpMethod = "POST"
      urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
      urlRequest.setValue(
        "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
        forHTTPHeaderField: "x-goog-api-client"
      )
      urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

      if let appCheck {
        let tokenResult = await appCheck.getToken(forcingRefresh: false)
        // The returned token is attached even when `tokenResult.error` is set; the error is
        // only logged.
        urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
        if let error = tokenResult.error {
          Logging.default
            .debug("[FirebaseVertexAI] Failed to fetch AppCheck token. Error: \(error)")
        }
      }

      if let auth, let authToken = try await auth.getToken(forcingRefresh: false) {
        urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
      }

      let encoder = JSONEncoder()
      encoder.keyEncodingStrategy = .convertToSnakeCase
      urlRequest.httpBody = try encoder.encode(request)

      if let timeoutInterval = request.options.timeout {
        urlRequest.timeoutInterval = timeoutInterval
      }

      return urlRequest
    }

    /// Casts `urlResponse` to `HTTPURLResponse`, throwing an `NSError` if it is not one.
    private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
      // Ensure the response is an HTTP response (the status code is checked by the caller).
      guard let response = urlResponse as? HTTPURLResponse else {
        Logging.default
          .error(
            "[FirebaseVertexAI] Response wasn't an HTTP response, internal error \(urlResponse)"
          )
        throw NSError(
          domain: "com.google.generative-ai",
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
        )
      }
      return response
    }

    /// Encodes `jsonText` as UTF-8 data, throwing an `NSError` if the conversion fails.
    private func jsonData(jsonText: String) throws -> Data {
      guard let data = jsonText.data(using: .utf8) else {
        let error = NSError(
          domain: "com.google.generative-ai",
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Could not parse response as UTF8."]
        )
        throw error
      }
      return data
    }

    /// Converts a textual error body into an `Error` via `parseError(responseData:)`.
    private func parseError(responseBody: String) -> Error {
      do {
        let data = try jsonData(jsonText: responseBody)
        return parseError(responseData: data)
      } catch {
        return error
      }
    }

    /// Decodes `responseData` as an `RPCError`; if that fails, returns the decoding error.
    private func parseError(responseData: Data) -> Error {
      do {
        return try JSONDecoder().decode(RPCError.self, from: responseData)
      } catch {
        // TODO: Return an error about an unrecognized error payload with the response body
        return error
      }
    }

    /// Decodes `data` as `type` with a default `JSONDecoder`, logging the payload on failure.
    private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
      do {
        return try JSONDecoder().decode(type, from: data)
      } catch {
        if let json = String(data: data, encoding: .utf8) {
          Logging.network.error("[FirebaseVertexAI] JSON response: \(json)")
        }
        Logging.default.error("[FirebaseVertexAI] Error decoding server JSON: \(error)")
        throw error
      }
    }

    #if DEBUG
      /// Wraps `value` in single quotes for a POSIX shell, escaping embedded single quotes.
      private func shellQuoted(_ value: String) -> String {
        "'" + value.replacingOccurrences(of: "'", with: "'\\''") + "'"
      }

      /// Builds a `curl` command equivalent to `request`, for debug logging only.
      ///
      /// Returns an empty string when the request has no URL or no UTF-8 body.
      private func cURLCommand(from request: URLRequest) -> String {
        var returnValue = "curl "
        if let allHeaders = request.allHTTPHeaderFields {
          // Sort so the logged command is stable between runs.
          for (key, value) in allHeaders.sorted(by: { $0.key < $1.key }) {
            let header = shellQuoted("\(key): \(value)")
            returnValue += "-H \(header) "
          }
        }
        guard let url = request.url else { return "" }
        returnValue += "\(shellQuoted(url.absoluteString)) "
        guard let body = request.httpBody,
              let jsonStr = String(bytes: body, encoding: .utf8) else { return "" }
        returnValue += "-d \(shellQuoted(jsonStr))"
        return returnValue
      }

      /// Logs the `curl` equivalent of `request`; the command text is marked private because
      /// it contains request headers such as the API key.
      private func printCURLCommand(from request: URLRequest) {
        let command = cURLCommand(from: request)
        Logging.verbose.debug("""
        [FirebaseVertexAI] Creating request with the equivalent cURL command:
        ----- cURL command -----
        \(command, privacy: .private)
        ------------------------
        """)
      }
    #endif // DEBUG
  }