// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAppCheckInterop
import FirebaseAuthInterop
import FirebaseCore
import Foundation
import os.log

/// Sends generative AI requests over HTTP and decodes the JSON responses.
///
/// Every request is a `POST` carrying the Firebase API key, SDK client headers and, when the
/// corresponding Firebase components are available, App Check and Firebase Auth tokens.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct GenerativeAIService {
  /// The language of the SDK in the format `gl-<language>/<version>`.
  static let languageTag = "gl-swift/5"

  /// The Firebase SDK version in the format `fire/<version>`.
  static let firebaseVersionTag = "fire/\(FirebaseVersion())"

  private let firebaseInfo: FirebaseInfo

  private let urlSession: URLSession

  /// When `true`, App Check tokens must come from `getLimitedUseToken(completion:)`.
  private let useLimitedUseAppCheckTokens: Bool

  init(firebaseInfo: FirebaseInfo, urlSession: URLSession, useLimitedUseAppCheckTokens: Bool) {
    self.firebaseInfo = firebaseInfo
    self.urlSession = urlSession
    self.useLimitedUseAppCheckTokens = useLimitedUseAppCheckTokens
  }

  /// Sends `request` and decodes the complete (non-streaming) response.
  ///
  /// - Parameter request: The request, JSON-encoded as the body of a `POST`.
  /// - Returns: The decoded `T.Response`.
  /// - Throws: Errors from building the request or from `URLSession`; a `URLError` if the
  ///   response is not HTTP; for non-200 status codes, the decoded `BackendError` (or the
  ///   error from trying to decode it); decoding errors for a malformed success payload.
  func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
    let urlRequest = try await self.urlRequest(request: request)

    #if DEBUG
      printCURLCommand(from: urlRequest)
    #endif

    let (data, rawResponse) = try await urlSession.data(for: urlRequest)
    let response = try httpResponse(urlResponse: rawResponse)

    // Anything other than 200 is a failure whose body describes the error.
    guard response.statusCode == 200 else {
      AILog.error(
        code: .loadRequestResponseError,
        "The server responded with an error: \(response)"
      )
      if let responseString = String(data: data, encoding: .utf8) {
        AILog.error(
          code: .loadRequestResponseErrorPayload,
          "Response payload: \(responseString)"
        )
      }
      throw parseError(responseData: data)
    }

    return try parseResponse(T.Response.self, from: data)
  }

  /// Sends `request` and streams each decoded server-sent event (SSE) as it arrives.
  ///
  /// Every failure, including network errors raised while reading the byte stream, finishes
  /// the returned stream with that error. When the consumer stops iterating, the underlying
  /// network task is cancelled.
  ///
  /// - Parameter request: The request, JSON-encoded as the body of a `POST`.
  /// - Returns: A stream of decoded `T.Response` values, one per `data:` line.
  @available(macOS 12.0, *)
  func loadRequestStream<T: GenerativeAIRequest>(request: T)
    -> AsyncThrowingStream<T.Response, Error> where T: Sendable {
    return AsyncThrowingStream { continuation in
      let task = Task {
        do {
          try await streamResponses(for: request, into: continuation)
          continuation.finish()
        } catch {
          // Single exit point for errors. Without it, an error thrown by `for try await`
          // would escape the Task closure and never reach the consumer.
          continuation.finish(throwing: error)
        }
      }
      // Stop the download when the consumer cancels or drops the stream. Cancelling an
      // already-finished task is a no-op, so this is also safe after normal completion.
      continuation.onTermination = { _ in
        task.cancel()
      }
    }
  }

  // MARK: - Private Helpers

  /// Performs the streaming request and yields each decoded `data:` event to `continuation`.
  ///
  /// Does not finish `continuation`; the caller does that, based on whether this returns or
  /// throws.
  ///
  /// - Throws: Request-building, network and decoding errors; for non-200 status codes, or
  ///   when the stream contains non-SSE lines, the result of `parseError(responseBody:)`.
  private func streamResponses<T: GenerativeAIRequest>(
    for request: T,
    into continuation: AsyncThrowingStream<T.Response, Error>.Continuation
  ) async throws {
    let urlRequest = try await self.urlRequest(request: request)

    #if DEBUG
      printCURLCommand(from: urlRequest)
    #endif

    let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)
    let response = try httpResponse(urlResponse: rawResponse)

    // Verify the status code is 200; otherwise, drain the body to build a descriptive error.
    guard response.statusCode == 200 else {
      AILog.error(
        code: .loadRequestStreamResponseError,
        "The server responded with an error: \(response)"
      )
      var responseBody = ""
      for try await line in stream.lines {
        responseBody += line + "\n"
      }
      AILog.error(
        code: .loadRequestStreamResponseErrorPayload,
        "Response payload: \(responseBody)"
      )
      throw parseError(responseBody: responseBody)
    }

    // Received lines that are not server-sent events (SSE); these are not prefixed with "data:"
    var extraLines = ""
    for try await line in stream.lines {
      AILog.debug(code: .loadRequestStreamResponseLine, "Stream response: \(line)")

      guard line.hasPrefix("data:") else {
        extraLines += line
        continue
      }

      // Drop the 5-character `data:` prefix; the remainder is the JSON payload.
      let jsonText = String(line.dropFirst(5))
      let data = try jsonData(jsonText: jsonText)
      let content = try parseResponse(T.Response.self, from: data)
      continuation.yield(content)
    }

    // Non-SSE content (e.g. a bare JSON error body) is treated as an error response.
    if !extraLines.isEmpty {
      throw parseError(responseBody: extraLines)
    }
  }

  /// Builds the `POST` request for `request`.
  ///
  /// Attaches the API key, client and content-type headers. Also attaches, when available,
  /// the App Check token and the Firebase Auth token. When data collection is enabled by
  /// default, it adds the Firebase app ID and app version headers.
  ///
  /// - Throws: Errors from fetching tokens or from JSON-encoding `request`.
  private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
    var urlRequest = URLRequest(url: request.url)
    urlRequest.httpMethod = "POST"
    urlRequest.setValue(firebaseInfo.apiKey, forHTTPHeaderField: "x-goog-api-key")
    urlRequest.setValue(
      "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)",
      forHTTPHeaderField: "x-goog-api-client"
    )
    urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

    if let appCheck = firebaseInfo.appCheck {
      let tokenResult = try await fetchAppCheckToken(appCheck: appCheck)
      // The token is sent even when `error` is set (presumably a placeholder token in that
      // case); the failure is only logged here.
      urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
      if let error = tokenResult.error {
        AILog.error(
          code: .appCheckTokenFetchFailed,
          "Failed to fetch AppCheck token. Error: \(error)"
        )
      }
    }

    if let auth = firebaseInfo.auth, let authToken = try await auth.getToken(
      forcingRefresh: false
    ) {
      urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization")
    }

    if firebaseInfo.app.isDataCollectionDefaultEnabled {
      urlRequest.setValue(firebaseInfo.firebaseAppID, forHTTPHeaderField: "X-Firebase-AppId")
      if let appVersion = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String {
        urlRequest.setValue(appVersion, forHTTPHeaderField: "X-Firebase-AppVersion")
      }
    }

    let encoder = JSONEncoder()
    urlRequest.httpBody = try encoder.encode(request)
    urlRequest.timeoutInterval = request.options.timeout

    return urlRequest
  }

  /// Returns the App Check token to attach to a request.
  ///
  /// When `useLimitedUseAppCheckTokens` is enabled, a limited-use token is required.
  /// Otherwise a standard token is fetched without forcing a refresh.
  ///
  /// - Throws: An `NSError` when limited-use tokens are required but the App Check provider
  ///   does not implement `getLimitedUseToken(completion:)`.
  private func fetchAppCheckToken(appCheck: AppCheckInterop) async throws
    -> FIRAppCheckTokenResultInterop {
    if useLimitedUseAppCheckTokens {
      if let token = await getLimitedUseAppCheckToken(appCheck: appCheck) {
        return token
      }

      let errorMessage =
        "The provided App Check token provider doesn't implement getLimitedUseToken(), but useLimitedUseAppCheckTokens was enabled."

      // Throw in every build configuration so callers get a catchable error rather than a
      // crash. (A `#if Debug` guard would never be active: compilation conditions are
      // case-sensitive, and this file uses `DEBUG`.)
      throw NSError(
        domain: "\(Constants.baseErrorDomain).\(Self.self)",
        code: AILog.MessageCode.appCheckTokenFetchFailed.rawValue,
        userInfo: [NSLocalizedDescriptionKey: errorMessage]
      )
    }

    return await appCheck.getToken(forcingRefresh: false)
  }

  /// Requests a limited-use App Check token, bridging the optional Objective-C callback API.
  ///
  /// - Returns: The token result, or `nil` when limited-use tokens are disabled or the provider
  ///   does not implement `getLimitedUseToken(completion:)`.
  private func getLimitedUseAppCheckToken(appCheck: AppCheckInterop) async
    -> FIRAppCheckTokenResultInterop? {
    // At the moment, `await` doesn’t get along with Objective-C’s optional protocol methods.
    await withCheckedContinuation { (continuation: CheckedContinuation<
      FIRAppCheckTokenResultInterop?,
      Never
    >) in
      guard
        useLimitedUseAppCheckTokens,
        // `getLimitedUseToken(completion:)` is an optional protocol method. Optional binding
        // is performed to make sure `continuation` is called even if the method’s not implemented.
        let limitedUseTokenClosure = appCheck.getLimitedUseToken
      else {
        return continuation.resume(returning: nil)
      }

      limitedUseTokenClosure { tokenResult in
        // The placeholder token should be used in the case of App Check error.
        continuation.resume(returning: tokenResult)
      }
    }
  }

  /// Casts `urlResponse` to `HTTPURLResponse`.
  ///
  /// - Throws: `URLError(.badServerResponse)` if the response is not an HTTP response.
  private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
    // The following condition should always be true: "Whenever you make HTTP URL load requests, any
    // response objects you get back from the URLSession, NSURLConnection, or NSURLDownload class
    // are instances of the HTTPURLResponse class."
    guard let response = urlResponse as? HTTPURLResponse else {
      AILog.error(
        code: .generativeAIServiceNonHTTPResponse,
        "Response wasn't an HTTP response, internal error \(urlResponse)"
      )
      throw URLError(
        .badServerResponse,
        userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
      )
    }
    return response
  }

  /// Encodes `jsonText` as UTF-8 data.
  ///
  /// - Throws: `DecodingError.dataCorrupted` if the text cannot be encoded as UTF-8.
  private func jsonData(jsonText: String) throws -> Data {
    guard let data = jsonText.data(using: .utf8) else {
      throw DecodingError.dataCorrupted(DecodingError.Context(
        codingPath: [],
        debugDescription: "Could not parse response as UTF8."
      ))
    }
    return data
  }

  /// Converts an error response body (as text) into an `Error`; see `parseError(responseData:)`.
  private func parseError(responseBody: String) -> Error {
    do {
      let data = try jsonData(jsonText: responseBody)
      return parseError(responseData: data)
    } catch {
      return error
    }
  }

  /// Decodes an error response body as a `BackendError`.
  ///
  /// - Returns: The decoded `BackendError` (after logging it), or the decoding error if the
  ///   payload is not a recognizable backend error.
  private func parseError(responseData: Data) -> Error {
    do {
      let rpcError = try JSONDecoder().decode(BackendError.self, from: responseData)
      logRPCError(rpcError)
      return rpcError
    } catch {
      // TODO: Return an error about an unrecognized error payload with the response body
      return error
    }
  }

  // Log specific RPC errors that cannot be mitigated or handled by user code.
  // These errors do not produce specific GenerateContentError or CountTokensError cases.
  private func logRPCError(_ error: BackendError) {
    let projectID = firebaseInfo.projectID
    if error.isVertexAIInFirebaseServiceDisabledError() {
      AILog.error(code: .vertexAIInFirebaseAPIDisabled, """
      The Firebase AI SDK requires the Firebase AI API \
      (`firebasevertexai.googleapis.com`) to be enabled in your Firebase project. Enable this API \
      by visiting the Firebase Console at
      https://console.firebase.google.com/project/\(projectID)/genai/ and clicking "Get started". \
      If you enabled this API recently, wait a few minutes for the action to propagate to our \
      systems and then retry.
      """)
    }
  }

  /// Decodes `data` as `type`, logging the raw JSON and the error on failure.
  ///
  /// - Throws: The `JSONDecoder` error, rethrown after logging.
  private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
    do {
      return try JSONDecoder().decode(type, from: data)
    } catch {
      if let json = String(data: data, encoding: .utf8) {
        AILog.error(code: .loadRequestParseResponseFailedJSON, "JSON response: \(json)")
      }
      AILog.error(
        code: .loadRequestParseResponseFailedJSONError,
        "Error decoding server JSON: \(error)"
      )
      throw error
    }
  }

  #if DEBUG
    /// Builds a `curl` command equivalent to `request`, for debugging.
    ///
    /// Header fields, the URL and the body are each single-quoted with embedded quotes escaped.
    /// Returns an empty string if the request has no URL or no UTF-8 body.
    private func cURLCommand(from request: URLRequest) -> String {
      guard let url = request.url,
            let body = request.httpBody,
            let jsonString = String(bytes: body, encoding: .utf8) else {
        return ""
      }

      var command = "curl "
      for (key, value) in request.allHTTPHeaderFields ?? [:] {
        let header = "\(key): \(value)"
        command += "-H \(shellQuoted(header)) "
      }
      command += "\(shellQuoted(url.absoluteString)) "
      command += "-d \(shellQuoted(jsonString))"
      return command
    }

    /// Wraps `text` in single quotes for a POSIX shell, escaping any embedded single quotes.
    private func shellQuoted(_ text: String) -> String {
      return "'" + text.replacingOccurrences(of: "'", with: "'\\''") + "'"
    }

    /// Logs the equivalent cURL command for `request` when additional logging is enabled.
    ///
    /// The command contains credentials (API key, tokens), so it is logged as private.
    private func printCURLCommand(from request: URLRequest) {
      guard AILog.additionalLoggingEnabled() else {
        return
      }
      let command = cURLCommand(from: request)
      os_log(.debug, log: AILog.logObject, """
      \(AILog.service) Creating request with the equivalent cURL command:
      ----- cURL command -----
      \(command, privacy: .private)
      ------------------------
      """)
    }
  #endif // DEBUG
}