| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- // Copyright 2023 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- import FirebaseAppCheckInterop
- import FirebaseAuthInterop
- import FirebaseCore
- import Foundation
- import os.log
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct GenerativeAIService {
  /// The language of the SDK in the format `gl-<language>/<version>`.
  static let languageTag = "gl-swift/5"

  /// The Firebase SDK version in the format `fire/<version>`.
  static let firebaseVersionTag = "fire/\(FirebaseVersion())"

  private let firebaseInfo: FirebaseInfo

  private let urlSession: URLSession

  init(firebaseInfo: FirebaseInfo, urlSession: URLSession) {
    self.firebaseInfo = firebaseInfo
    self.urlSession = urlSession
  }

  /// Sends `request` and decodes the single (non-streaming) response body.
  ///
  /// - Parameter request: The request to send; see `urlRequest(request:)` for how it is turned
  ///   into an HTTP POST.
  /// - Returns: The response body decoded as `T.Response`.
  /// - Throws: Errors from building the request or from `URLSession`, a `URLError` if the
  ///   response is not an HTTP response, the error returned by `parseError(responseData:)` when
  ///   the status code is not 200, or a decoding error if the body cannot be decoded.
  func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
    let urlRequest = try await self.urlRequest(request: request)

    #if DEBUG
      printCURLCommand(from: urlRequest)
    #endif

    let (data, rawResponse) = try await urlSession.data(for: urlRequest)
    let response = try httpResponse(urlResponse: rawResponse)

    // Verify the status code is 200
    guard response.statusCode == 200 else {
      AILog.error(
        code: .loadRequestResponseError,
        "The server responded with an error: \(response)"
      )
      if let responseString = String(data: data, encoding: .utf8) {
        AILog.error(
          code: .loadRequestResponseErrorPayload,
          "Response payload: \(responseString)"
        )
      }
      throw parseError(responseData: data)
    }

    return try parseResponse(T.Response.self, from: data)
  }

  /// Sends `request` and streams back each decoded server-sent event (`data:` line).
  ///
  /// Every failure (request construction, transport, non-200 status, malformed event, or
  /// unexpected non-SSE payload) finishes the returned stream with that error. If the consumer
  /// stops iterating, the underlying producer task is cancelled.
  ///
  /// - Parameter request: The request to send.
  /// - Returns: A stream yielding one `T.Response` per `data:` line received.
  @available(macOS 12.0, *)
  func loadRequestStream<T: GenerativeAIRequest>(request: T)
    -> AsyncThrowingStream<T.Response, Error> where T: Sendable {
    return AsyncThrowingStream { continuation in
      let task = Task {
        // Catch everything here: an error escaping an unstructured `Task` is silently dropped,
        // which would leave the continuation unfinished and the consumer waiting forever.
        do {
          try await streamResponses(for: request, into: continuation)
          continuation.finish()
        } catch {
          continuation.finish(throwing: error)
        }
      }
      // Stop downloading as soon as the consumer stops listening (break, cancellation, etc.).
      continuation.onTermination = { @Sendable _ in
        task.cancel()
      }
    }
  }

  // MARK: - Private Helpers

  /// Performs the streaming request and yields every decoded `data:` event into `continuation`.
  ///
  /// Returns normally when the response has been fully consumed; the caller finishes the stream.
  ///
  /// - Throws: Any error that should terminate the stream (see `loadRequestStream(request:)`).
  private func streamResponses<T: GenerativeAIRequest>(
    for request: T,
    into continuation: AsyncThrowingStream<T.Response, Error>.Continuation
  ) async throws {
    let urlRequest = try await self.urlRequest(request: request)

    #if DEBUG
      printCURLCommand(from: urlRequest)
    #endif

    let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)
    let response = try httpResponse(urlResponse: rawResponse)

    // Verify the status code is 200
    guard response.statusCode == 200 else {
      AILog.error(
        code: .loadRequestStreamResponseError,
        "The server responded with an error: \(response)"
      )
      var responseBody = ""
      for try await line in stream.lines {
        responseBody += line + "\n"
      }
      AILog.error(
        code: .loadRequestStreamResponseErrorPayload,
        "Response payload: \(responseBody)"
      )
      throw parseError(responseBody: responseBody)
    }

    // Received lines that are not server-sent events (SSE); these are not prefixed with "data:".
    // They typically carry an error payload, which is parsed once the stream has ended.
    var nonSSELines: [String] = []
    for try await line in stream.lines {
      AILog.debug(code: .loadRequestStreamResponseLine, "Stream response: \(line)")
      if line.hasPrefix("data:") {
        // `data:` is 5 ASCII characters, so dropping 5 `Character`s removes exactly the prefix.
        let data = try jsonData(jsonText: String(line.dropFirst(5)))
        let content = try parseResponse(T.Response.self, from: data)
        continuation.yield(content)
      } else if !line.isEmpty {
        // Empty lines never contributed to the payload; skipping them keeps the
        // "was there any non-SSE content?" decision below unchanged.
        nonSSELines.append(line)
      }
    }

    if !nonSSELines.isEmpty {
      throw parseError(responseBody: nonSSELines.joined(separator: "\n"))
    }
  }

  /// Builds the HTTP POST for `request`, attaching API key, client, App Check, Auth, and
  /// (when data collection is enabled) app identification headers, plus the JSON-encoded body.
  ///
  /// - Throws: Errors from fetching the Auth token or from encoding `request`.
  private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
    var urlRequest = URLRequest(url: request.url)
    urlRequest.httpMethod = "POST"
    urlRequest.setValue(firebaseInfo.apiKey, forHTTPHeaderField: "x-goog-api-key")
    urlRequest.setValue(
      "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
      forHTTPHeaderField: "x-goog-api-client"
    )
    urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

    if let appCheck = firebaseInfo.appCheck {
      let tokenResult = await appCheck.getToken(forcingRefresh: false)
      // The token is attached even when `error` is set; the error is only logged here.
      urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
      if let error = tokenResult.error {
        AILog.error(
          code: .appCheckTokenFetchFailed,
          "Failed to fetch AppCheck token. Error: \(error)"
        )
      }
    }

    if let auth = firebaseInfo.auth, let authToken = try await auth.getToken(
      forcingRefresh: false
    ) {
      urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
    }

    if firebaseInfo.app.isDataCollectionDefaultEnabled {
      urlRequest.setValue(firebaseInfo.firebaseAppID, forHTTPHeaderField: "X-Firebase-AppId")
      if let appVersion = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String {
        urlRequest.setValue(appVersion, forHTTPHeaderField: "X-Firebase-AppVersion")
      }
    }

    let encoder = JSONEncoder()
    urlRequest.httpBody = try encoder.encode(request)
    urlRequest.timeoutInterval = request.options.timeout

    return urlRequest
  }

  /// Casts `urlResponse` to `HTTPURLResponse`, throwing `URLError(.badServerResponse)` otherwise.
  private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
    // The following condition should always be true: "Whenever you make HTTP URL load requests, any
    // response objects you get back from the URLSession, NSURLConnection, or NSURLDownload class
    // are instances of the HTTPURLResponse class."
    guard let response = urlResponse as? HTTPURLResponse else {
      AILog.error(
        code: .generativeAIServiceNonHTTPResponse,
        "Response wasn't an HTTP response, internal error \(urlResponse)"
      )
      throw URLError(
        .badServerResponse,
        userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
      )
    }
    return response
  }

  /// Encodes `jsonText` as UTF-8 bytes, throwing `DecodingError.dataCorrupted` on failure.
  private func jsonData(jsonText: String) throws -> Data {
    guard let data = jsonText.data(using: .utf8) else {
      throw DecodingError.dataCorrupted(DecodingError.Context(
        codingPath: [],
        debugDescription: "Could not parse response as UTF8."
      ))
    }
    return data
  }

  /// Converts a textual error payload into an `Error` (see `parseError(responseData:)`).
  private func parseError(responseBody: String) -> Error {
    do {
      let data = try jsonData(jsonText: responseBody)
      return parseError(responseData: data)
    } catch {
      return error
    }
  }

  /// Decodes `responseData` as a `BackendError` (logging known unfixable cases); if decoding
  /// fails, the decoding error itself is returned.
  private func parseError(responseData: Data) -> Error {
    do {
      let rpcError = try JSONDecoder().decode(BackendError.self, from: responseData)
      logRPCError(rpcError)
      return rpcError
    } catch {
      // TODO: Return an error about an unrecognized error payload with the response body
      return error
    }
  }

  // Log specific RPC errors that cannot be mitigated or handled by user code.
  // These errors do not produce specific GenerateContentError or CountTokensError cases.
  private func logRPCError(_ error: BackendError) {
    let projectID = firebaseInfo.projectID
    if error.isVertexAIInFirebaseServiceDisabledError() {
      AILog.error(code: .vertexAIInFirebaseAPIDisabled, """
      The Firebase AI SDK requires the Firebase AI API \
      (`firebasevertexai.googleapis.com`) to be enabled in your Firebase project. Enable this API \
      by visiting the Firebase Console at
      https://console.firebase.google.com/project/\(projectID)/genai/ and clicking "Get started". \
      If you enabled this API recently, wait a few minutes for the action to propagate to our \
      systems and then retry.
      """)
    }
  }

  /// Decodes `data` as `type`, logging the raw JSON and the decoding error before rethrowing.
  private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
    do {
      return try JSONDecoder().decode(type, from: data)
    } catch {
      if let json = String(data: data, encoding: .utf8) {
        AILog.error(code: .loadRequestParseResponseFailedJSON, "JSON response: \(json)")
      }
      AILog.error(
        code: .loadRequestParseResponseFailedJSONError,
        "Error decoding server JSON: \(error)"
      )
      throw error
    }
  }

  #if DEBUG
    /// Wraps `value` in single quotes for a POSIX shell, escaping embedded single quotes.
    private func shellQuoted(_ value: String) -> String {
      return "'" + value.replacingOccurrences(of: "'", with: "'\\''") + "'"
    }

    /// Renders `request` as an equivalent `curl` command line for debugging.
    ///
    /// Headers are emitted in sorted order for stable output; `-d` is added only when the
    /// request has a UTF-8 body. Returns `""` if the request has no URL.
    private func cURLCommand(from request: URLRequest) -> String {
      guard let url = request.url else { return "" }

      var arguments = ["curl"]
      let headers = (request.allHTTPHeaderFields ?? [:]).sorted { $0.key < $1.key }
      for (field, value) in headers {
        let header = "\(field): \(value)"
        arguments.append("-H \(shellQuoted(header))")
      }
      arguments.append(shellQuoted(url.absoluteString))
      if let body = request.httpBody, let bodyText = String(bytes: body, encoding: .utf8) {
        arguments.append("-d \(shellQuoted(bodyText))")
      }
      return arguments.joined(separator: " ")
    }

    /// Logs the cURL equivalent of `request` when additional logging is enabled.
    private func printCURLCommand(from request: URLRequest) {
      guard AILog.additionalLoggingEnabled() else {
        return
      }
      let command = cURLCommand(from: request)
      os_log(.debug, log: AILog.logObject, """
      \(AILog.service) Creating request with the equivalent cURL command:
      ----- cURL command -----
      \(command, privacy: .private)
      ------------------------
      """)
    }
  #endif // DEBUG
}
|