// GenerativeAIService.swift

  // Copyright 2023 Google LLC
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  // http://www.apache.org/licenses/LICENSE-2.0
  //
  // Unless required by applicable law or agreed to in writing, software
  // distributed under the License is distributed on an "AS IS" BASIS,
  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  // See the License for the specific language governing permissions and
  // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseCore
  16. import Foundation
  /// Performs generative AI requests over HTTP and decodes the responses, either as a single
  /// payload (`loadRequest(request:)`) or incrementally from a stream of `data:`-prefixed lines
  /// (`loadRequestStream(request:)`).
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
  struct GenerativeAIService {
    /// The language of the SDK in the format `gl-<language>/<version>`.
    static let languageTag = "gl-swift/5"

    /// The Firebase SDK version in the format `fire/<version>`.
    static let firebaseVersionTag = "fire/\(FirebaseVersion())"

    /// Domain used for the `NSError`s created by this service.
    private static let errorDomain = "com.google.generative-ai"

    /// Gives permission to talk to the backend.
    private let apiKey: String

    /// Optional App Check provider; when present, its token is attached to every request.
    private let appCheck: AppCheckInterop?

    /// The session used to perform all network requests.
    private let urlSession: URLSession

    init(apiKey: String, appCheck: AppCheckInterop?, urlSession: URLSession) {
      self.apiKey = apiKey
      self.appCheck = appCheck
      self.urlSession = urlSession
    }

    /// Sends `request` and decodes the complete response body.
    ///
    /// - Parameter request: The request to encode and send.
    /// - Returns: The decoded `T.Response`.
    /// - Throws: Errors from building or sending the request, the error parsed from the response
    ///   body when the status code is not 200, or a decoding error when the body cannot be
    ///   decoded as `T.Response`.
    func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
      let urlRequest = try await urlRequest(request: request)

      #if DEBUG
        printCURLCommand(from: urlRequest)
      #endif

      let (data, rawResponse) = try await urlSession.data(for: urlRequest)
      let response = try httpResponse(urlResponse: rawResponse)

      // Verify the status code is 200
      guard response.statusCode == 200 else {
        Logging.default.error("[GoogleGenerativeAI] The server responded with an error: \(response)")
        if let responseString = String(data: data, encoding: .utf8) {
          Logging.network.error("[GoogleGenerativeAI] Response payload: \(responseString)")
        }

        throw parseError(responseData: data)
      }

      return try parseResponse(T.Response.self, from: data)
    }

    /// Sends `request` and yields one decoded `T.Response` for every `data:`-prefixed line received.
    ///
    /// Every failure (building the request, transport errors before or during streaming, a non-200
    /// status code, undecodable payloads, or trailing lines without the `data:` prefix) finishes
    /// the stream by throwing, so consumers are never left waiting on a stream that will not end.
    /// The underlying task is cancelled as soon as the stream terminates.
    ///
    /// - Parameter request: The request to encode and send.
    /// - Returns: A stream of decoded responses.
    @available(macOS 12.0, *)
    func loadRequestStream<T: GenerativeAIRequest>(request: T)
      -> AsyncThrowingStream<T.Response, Error> {
      return AsyncThrowingStream { continuation in
        let task = Task {
          // A single do/catch guarantees the continuation is finished on every exit path,
          // including errors thrown while iterating `stream.lines`.
          do {
            let urlRequest = try await self.urlRequest(request: request)

            #if DEBUG
              printCURLCommand(from: urlRequest)
            #endif

            let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)

            // Verify the response is an HTTP response
            let response = try httpResponse(urlResponse: rawResponse)

            // Verify the status code is 200
            guard response.statusCode == 200 else {
              Logging.default
                .error("[GoogleGenerativeAI] The server responded with an error: \(response)")
              var responseBody = ""
              for try await line in stream.lines {
                responseBody += line + "\n"
              }

              Logging.network.error("[GoogleGenerativeAI] Response payload: \(responseBody)")
              throw parseError(responseBody: responseBody)
            }

            // Received lines that are not server-sent events (SSE); these are not prefixed with "data:"
            var extraLines = ""

            for try await line in stream.lines {
              Logging.network.debug("[GoogleGenerativeAI] Stream response: \(line)")

              if line.hasPrefix("data:") {
                // We can assume 5 characters since it's utf-8 encoded, removing `data:`.
                let jsonText = String(line.dropFirst(5))
                let data = try jsonData(jsonText: jsonText)

                // Handle the content.
                let content = try parseResponse(T.Response.self, from: data)
                continuation.yield(content)
              } else {
                extraLines += line
              }
            }

            // Anything that was not an SSE line is treated as an error payload.
            if !extraLines.isEmpty {
              throw parseError(responseBody: extraLines)
            }

            continuation.finish()
          } catch {
            continuation.finish(throwing: error)
          }
        }

        // Stop the network request when the consumer stops iterating or is cancelled.
        continuation.onTermination = { _ in task.cancel() }
      }
    }

    // MARK: - Private Helpers

    /// Builds the POST `URLRequest` for `request`: API key, client and content-type headers, the
    /// App Check token (when an App Check provider was supplied), the snake_case JSON body and the
    /// optional timeout from `request.options`.
    ///
    /// - Throws: An encoding error if `request` cannot be encoded as JSON.
    private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
      var urlRequest = URLRequest(url: request.url)
      urlRequest.httpMethod = "POST"
      urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
      urlRequest.setValue(
        "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
        forHTTPHeaderField: "x-goog-api-client"
      )
      urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

      if let appCheck {
        let tokenResult = await appCheck.getToken(forcingRefresh: false)
        // The token from the result is attached even when the result carries an error; the error
        // itself is only logged.
        urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
        if let error = tokenResult.error {
          Logging.default
            .debug("[GoogleGenerativeAI] Failed to fetch AppCheck token. Error: \(error)")
        }
      }

      let encoder = JSONEncoder()
      encoder.keyEncodingStrategy = .convertToSnakeCase
      urlRequest.httpBody = try encoder.encode(request)

      if let timeoutInterval = request.options.timeout {
        urlRequest.timeoutInterval = timeoutInterval
      }

      return urlRequest
    }

    /// Casts `urlResponse` to `HTTPURLResponse`.
    ///
    /// - Throws: An `NSError` (code `-1`) when the response is not an HTTP response.
    private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
      // Verify the response is an HTTP response
      guard let response = urlResponse as? HTTPURLResponse else {
        Logging.default
          .error(
            "[GoogleGenerativeAI] Response wasn't an HTTP response, internal error \(urlResponse)"
          )
        throw NSError(
          domain: GenerativeAIService.errorDomain,
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
        )
      }

      return response
    }

    /// Converts `jsonText` to UTF-8 encoded `Data`.
    ///
    /// - Throws: An `NSError` (code `-1`) when the text cannot be represented as UTF-8.
    private func jsonData(jsonText: String) throws -> Data {
      guard let data = jsonText.data(using: .utf8) else {
        throw NSError(
          domain: GenerativeAIService.errorDomain,
          code: -1,
          userInfo: [NSLocalizedDescriptionKey: "Could not parse response as UTF8."]
        )
      }

      return data
    }

    /// Produces the error described by a textual response body; if the body cannot be converted
    /// to data, the conversion error is returned instead.
    private func parseError(responseBody: String) -> Error {
      do {
        let data = try jsonData(jsonText: responseBody)
        return parseError(responseData: data)
      } catch {
        return error
      }
    }

    /// Decodes `responseData` as an `RPCError`; if that fails, the decoding error is returned.
    private func parseError(responseData: Data) -> Error {
      do {
        return try JSONDecoder().decode(RPCError.self, from: responseData)
      } catch {
        // TODO: Return an error about an unrecognized error payload with the response body
        return error
      }
    }

    /// Decodes `data` as `type`, logging the payload and the decoding error before rethrowing.
    private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
      do {
        return try JSONDecoder().decode(type, from: data)
      } catch {
        if let json = String(data: data, encoding: .utf8) {
          Logging.network.error("[GoogleGenerativeAI] JSON response: \(json)")
        }
        Logging.default.error("[GoogleGenerativeAI] Error decoding server JSON: \(error)")
        throw error
      }
    }

    #if DEBUG
      /// Builds a cURL command equivalent to `request`, or an empty string when the request has no
      /// URL or no UTF-8 body. Headers are emitted sorted by name and every argument is
      /// single-quoted with embedded single quotes escaped.
      private func cURLCommand(from request: URLRequest) -> String {
        guard let url = request.url else { return "" }
        guard let body = request.httpBody,
              let jsonStr = String(bytes: body, encoding: .utf8) else { return "" }

        // Wraps `text` in single quotes, escaping embedded single quotes for a POSIX shell.
        func shellQuoted(_ text: String) -> String {
          return "'" + text.replacingOccurrences(of: "'", with: "'\\''") + "'"
        }

        var returnValue = "curl "
        let headers = (request.allHTTPHeaderFields ?? [:]).sorted { $0.key < $1.key }
        for (key, value) in headers {
          let header = "\(key): \(value)"
          returnValue += "-H " + shellQuoted(header) + " "
        }
        returnValue += shellQuoted(url.absoluteString) + " "
        returnValue += "-d " + shellQuoted(jsonStr)

        return returnValue
      }

      /// Logs the cURL command equivalent to `request` at debug level, marked private.
      private func printCURLCommand(from request: URLRequest) {
        let command = cURLCommand(from: request)
        Logging.verbose.debug("""
        [GoogleGenerativeAI] Creating request with the equivalent cURL command:
        ----- cURL command -----
        \(command, privacy: .private)
        ------------------------
        """)
      }
    #endif // DEBUG
  }