|
|
@@ -0,0 +1,210 @@
|
|
|
+// Copyright 2020 Google LLC
|
|
|
+//
|
|
|
+// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+// you may not use this file except in compliance with the License.
|
|
|
+// You may obtain a copy of the License at
|
|
|
+//
|
|
|
+// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+//
|
|
|
+// Unless required by applicable law or agreed to in writing, software
|
|
|
+// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+// See the License for the specific language governing permissions and
|
|
|
+// limitations under the License.
|
|
|
+
|
|
|
+import Foundation
|
|
|
+import FirebaseCore
|
|
|
+import FirebaseInstallations
|
|
|
+
|
|
|
+/// Model info object with details about pending or downloaded model.
|
|
|
+class ModelInfo: NSObject {
|
|
|
+ /// Model name.
|
|
|
+ var name: String
|
|
|
+
|
|
|
+ /// User defaults associated with model.
|
|
|
+ var defaults: UserDefaults
|
|
|
+
|
|
|
+ // TODO: revisit UserDefaultsBacked
|
|
|
+ /// Download URL for the model file, as returned by server.
|
|
|
+ @UserDefaultsBacked var downloadURL: String
|
|
|
+
|
|
|
+ /// Hash of the model, as returned by server.
|
|
|
+ @UserDefaultsBacked var modelHash: String
|
|
|
+
|
|
|
+ /// Size of the model, as returned by server.
|
|
|
+ @UserDefaultsBacked var size: Int
|
|
|
+
|
|
|
+ /// Local path of the model.
|
|
|
+ @UserDefaultsBacked var path: String?
|
|
|
+
|
|
|
+ /// Initialize model info and create user default keys.
|
|
|
+ init(app: FirebaseApp, name: String, defaults: UserDefaults = .firebaseMLDefaults) {
|
|
|
+ self.name = name
|
|
|
+ self.defaults = defaults
|
|
|
+ let bundleID = Bundle.main.bundleIdentifier ?? ""
|
|
|
+ let defaultsPrefix = "\(bundleID).\(app.name).\(name)"
|
|
|
+ _downloadURL = UserDefaultsBacked(
|
|
|
+ key: "\(defaultsPrefix).model-download-url",
|
|
|
+ storage: defaults
|
|
|
+ )
|
|
|
+ _modelHash = UserDefaultsBacked(key: "\(defaultsPrefix).model-hash", storage: defaults)
|
|
|
+ _size = UserDefaultsBacked(key: "\(defaultsPrefix).model-size", storage: defaults)
|
|
|
+ _path = UserDefaultsBacked(key: "\(defaultsPrefix).model-path", storage: defaults)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/// Model info retriever for a model from local user defaults or server.
|
|
|
+class ModelInfoRetriever: NSObject {
|
|
|
+ /// Current Firebase app.
|
|
|
+ var app: FirebaseApp
|
|
|
+ /// Model info associated with model.
|
|
|
+ var modelInfo: ModelInfo?
|
|
|
+ /// Project id.
|
|
|
+ var projectID: String
|
|
|
+ /// Model name.
|
|
|
+ var modelName: String
|
|
|
+ /// Firebase installations.
|
|
|
+ var installations: Installations
|
|
|
+
|
|
|
+ /// Associate model info retriever with current Firebase app, project ID, and model name.
|
|
|
+ init(app: FirebaseApp, projectID: String, modelName: String) {
|
|
|
+ self.app = app
|
|
|
+ self.projectID = projectID
|
|
|
+ self.modelName = modelName
|
|
|
+ installations = Installations.installations(app: app)
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Build custom model object from model info.
|
|
|
+ func buildModel() -> CustomModel? {
|
|
|
+ /// Build custom model only if model info is filled out, and model file is already on device.
|
|
|
+ guard let info = modelInfo, let path = info.path else { return nil }
|
|
|
+ let model = CustomModel(
|
|
|
+ name: info.name,
|
|
|
+ size: info.size,
|
|
|
+ path: path,
|
|
|
+ hash: info.modelHash
|
|
|
+ )
|
|
|
+ return model
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/// Extension to handle fetching model info from server.
|
|
|
+extension ModelInfoRetriever {
|
|
|
+ /// HTTP request headers.
|
|
|
+ static let fisTokenHTTPHeader = "x-goog-firebase-installations-auth"
|
|
|
+ static let hashMatchHTTPHeader = "if-none-match"
|
|
|
+ static let bundleIDHTTPHeader = "x-ios-bundle-identifier"
|
|
|
+
|
|
|
+ /// HTTP response headers.
|
|
|
+ static let etagHTTPHeader = "ETag"
|
|
|
+
|
|
|
+ /// Error descriptions.
|
|
|
+ static let tokenErrorDescription = "Error retrieving FIS token."
|
|
|
+ static let selfDeallocatedErrorDescription = "Self deallocated."
|
|
|
+ static let missingModelHashErrorDescription = "Model hash missing in server response."
|
|
|
+ static let invalidHTTPResponseErrorDescription =
|
|
|
+ "Could not get a valid HTTP response from server."
|
|
|
+
|
|
|
+ /// Construct model fetch base URL.
|
|
|
+ var modelInfoFetchURL: URL {
|
|
|
+ var components = URLComponents()
|
|
|
+ components.scheme = "https"
|
|
|
+ components.host = "firebaseml.googleapis.com"
|
|
|
+ components.path = "/v1beta2/projects/\(projectID)/models/\(modelName):download"
|
|
|
+ // TODO: handle nil
|
|
|
+ return components.url!
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Construct model fetch URL request.
|
|
|
+ func getModelInfoFetchURLRequest(token: String) -> URLRequest {
|
|
|
+ var request = URLRequest(url: modelInfoFetchURL)
|
|
|
+ request.httpMethod = "GET"
|
|
|
+ // TODO: Check if bundle ID needs to be part of the request header.
|
|
|
+ let bundleID = Bundle.main.bundleIdentifier ?? ""
|
|
|
+ request.setValue(bundleID, forHTTPHeaderField: ModelInfoRetriever.bundleIDHTTPHeader)
|
|
|
+ request.setValue(token, forHTTPHeaderField: ModelInfoRetriever.fisTokenHTTPHeader)
|
|
|
+ if let info = modelInfo, info.modelHash.count > 0 {
|
|
|
+ request.setValue(info.modelHash, forHTTPHeaderField: ModelInfoRetriever.hashMatchHTTPHeader)
|
|
|
+ }
|
|
|
+ return request
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Get model info from server.
|
|
|
+ func downloadModelInfo(completion: @escaping (DownloadError?) -> Void) {
|
|
|
+ /// Get FIS token.
|
|
|
+ installations.authToken { [weak self] tokenResult, error in
|
|
|
+ guard let self = self else {
|
|
|
+ completion(.internalError(description: ModelInfoRetriever.selfDeallocatedErrorDescription))
|
|
|
+ return
|
|
|
+ }
|
|
|
+ guard let result = tokenResult
|
|
|
+ else {
|
|
|
+ completion(.internalError(description: ModelInfoRetriever.tokenErrorDescription))
|
|
|
+ return
|
|
|
+ }
|
|
|
+ /// Get model info fetch URL with appropriate HTTP headers.
|
|
|
+ let request = self.getModelInfoFetchURLRequest(token: result.authToken)
|
|
|
+
|
|
|
+ /// Download model info.
|
|
|
+ // TODO: Consider moving request to a separate method
|
|
|
+ let session = URLSession(configuration: .ephemeral)
|
|
|
+ let dataTask = session.dataTask(with: request) { [weak self]
|
|
|
+ data, response, error in
|
|
|
+ guard let self = self else {
|
|
|
+ completion(.internalError(description: ModelInfoRetriever
|
|
|
+ .selfDeallocatedErrorDescription))
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if let downloadError = error {
|
|
|
+ completion(.internalError(description: downloadError.localizedDescription))
|
|
|
+ } else {
|
|
|
+ guard let httpResponse = response as? HTTPURLResponse else {
|
|
|
+ completion(.internalError(description: ModelInfoRetriever
|
|
|
+ .invalidHTTPResponseErrorDescription))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // TODO: Handle more http status codes
|
|
|
+ switch httpResponse.statusCode {
|
|
|
+ case 200:
|
|
|
+ guard let modelHash = httpResponse
|
|
|
+ .allHeaderFields[ModelInfoRetriever.etagHTTPHeader] as? String else {
|
|
|
+ completion(.internalError(description: ModelInfoRetriever
|
|
|
+ .missingModelHashErrorDescription))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ guard let data = data else {
|
|
|
+ completion(.internalError(description: ModelInfoRetriever
|
|
|
+ .invalidHTTPResponseErrorDescription))
|
|
|
+ return
|
|
|
+ }
|
|
|
+ self.saveModelInfo(data: data, modelHash: modelHash)
|
|
|
+ completion(nil)
|
|
|
+ case 304:
|
|
|
+ completion(nil)
|
|
|
+ default:
|
|
|
+ completion(.notFound)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ dataTask.resume()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Save model info to user defaults.
|
|
|
+ func saveModelInfo(data: Data, modelHash: String) {
|
|
|
+ // TODO: Save model info to user defaults
|
|
|
+ modelInfo?.modelHash = modelHash
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/// Named user defaults for FirebaseML.
|
|
|
+extension UserDefaults {
|
|
|
+ static var firebaseMLDefaults: UserDefaults {
|
|
|
+ let suiteName = "com.google.firebase.ml"
|
|
|
+ // TODO: handle nil gracefully
|
|
|
+ let defaults = UserDefaults(suiteName: suiteName)!
|
|
|
+ return defaults
|
|
|
+ }
|
|
|
+}
|