ModelInfoRetriever.swift 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. // Copyright 2020 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  16. import FirebaseInstallations
  /// Describes a model that is either pending download or already downloaded.
  ///
  /// The server-provided details are stored through `UserDefaultsBacked` properties whose keys
  /// are prefixed with the main bundle identifier, the Firebase app name, and the model name.
  class ModelInfo: NSObject {
    /// Name of the model.
    var name: String

    /// User defaults store associated with the model.
    var defaults: UserDefaults

    // TODO: revisit UserDefaultsBacked

    /// URL to download the model file from, as returned by the server.
    @UserDefaultsBacked var downloadURL: String

    /// Hash of the model, as returned by the server.
    @UserDefaultsBacked var modelHash: String

    /// Size of the model, as returned by the server.
    @UserDefaultsBacked var size: Int

    /// Local path of the model file, if there is one.
    @UserDefaultsBacked var path: String?

    /// Creates the model info and sets up the user defaults key for each stored property.
    ///
    /// - Parameters:
    ///   - app: Firebase app whose name becomes part of every user defaults key.
    ///   - name: Model name; also part of every user defaults key.
    ///   - defaults: Store handed to each `UserDefaultsBacked` wrapper.
    init(app: FirebaseApp, name: String, defaults: UserDefaults = .firebaseMLDefaults) {
      self.name = name
      self.defaults = defaults

      // An empty string stands in for the bundle identifier when the main bundle has none.
      let bundleIdentifier = Bundle.main.bundleIdentifier ?? ""
      let keyPrefix = "\(bundleIdentifier).\(app.name).\(name)"

      _downloadURL = UserDefaultsBacked(key: "\(keyPrefix).model-download-url", storage: defaults)
      _modelHash = UserDefaultsBacked(key: "\(keyPrefix).model-hash", storage: defaults)
      _size = UserDefaultsBacked(key: "\(keyPrefix).model-size", storage: defaults)
      _path = UserDefaultsBacked(key: "\(keyPrefix).model-path", storage: defaults)
    }
  }
  /// Retrieves model info for a model from local user defaults or from the server.
  class ModelInfoRetriever: NSObject {
    /// Firebase app this retriever works with.
    var app: FirebaseApp

    /// Info for the model, if any is available.
    var modelInfo: ModelInfo?

    /// ID of the project.
    var projectID: String

    /// Name of the model.
    var modelName: String

    /// Firebase installations instance obtained for `app`.
    var installations: Installations

    /// Ties the retriever to a Firebase app, a project ID, and a model name.
    ///
    /// - Parameters:
    ///   - app: Firebase app; also used to obtain the `Installations` instance.
    ///   - projectID: ID of the project.
    ///   - modelName: Name of the model.
    init(app: FirebaseApp, projectID: String, modelName: String) {
      installations = Installations.installations(app: app)
      self.app = app
      self.projectID = projectID
      self.modelName = modelName
    }

    /// Builds a custom model object out of the current model info.
    ///
    /// - Returns: A `CustomModel` made from the info's name, size, path, and hash, or `nil` when
    ///   there is no model info or the info has no local path.
    func buildModel() -> CustomModel? {
      // A model can only be built once its info is present...
      guard let info = modelInfo else { return nil }
      // ...and the info records a local path for the model file.
      guard let localPath = info.path else { return nil }

      return CustomModel(
        name: info.name,
        size: info.size,
        path: localPath,
        hash: info.modelHash
      )
    }
  }
  /// Extension to handle fetching model info from server.
  extension ModelInfoRetriever {
    /// HTTP request headers.
    static let fisTokenHTTPHeader = "x-goog-firebase-installations-auth"
    static let hashMatchHTTPHeader = "if-none-match"
    static let bundleIDHTTPHeader = "x-ios-bundle-identifier"

    /// HTTP response headers.
    static let etagHTTPHeader = "ETag"

    /// Error descriptions.
    static let tokenErrorDescription = "Error retrieving FIS token."
    static let selfDeallocatedErrorDescription = "Self deallocated."
    static let missingModelHashErrorDescription = "Model hash missing in server response."
    static let invalidHTTPResponseErrorDescription =
      "Could not get a valid HTTP response from server."

    /// Model info fetch URL, built from the project ID and model name.
    var modelInfoFetchURL: URL {
      var components = URLComponents()
      components.scheme = "https"
      components.host = "firebaseml.googleapis.com"
      components.path = "/v1beta2/projects/\(projectID)/models/\(modelName):download"
      // TODO: handle nil
      // NOTE(review): the path always starts with "/", which should keep `url` non-nil when a
      // host is set; confirm against the URLComponents documentation before relying on it.
      return components.url!
    }

    /// Builds the GET request used to fetch model info.
    ///
    /// - Parameter token: FIS auth token, sent in the `fisTokenHTTPHeader` field.
    /// - Returns: A request for `modelInfoFetchURL` carrying the bundle ID, the token, and, when
    ///   a non-empty model hash is already known, that hash in the `hashMatchHTTPHeader` field.
    func getModelInfoFetchURLRequest(token: String) -> URLRequest {
      var request = URLRequest(url: modelInfoFetchURL)
      request.httpMethod = "GET"
      // TODO: Check if bundle ID needs to be part of the request header.
      let bundleID = Bundle.main.bundleIdentifier ?? ""
      request.setValue(bundleID, forHTTPHeaderField: ModelInfoRetriever.bundleIDHTTPHeader)
      request.setValue(token, forHTTPHeaderField: ModelInfoRetriever.fisTokenHTTPHeader)
      if let info = modelInfo, !info.modelHash.isEmpty {
        request.setValue(info.modelHash, forHTTPHeaderField: ModelInfoRetriever.hashMatchHTTPHeader)
      }
      return request
    }

    /// Gets model info from the server.
    ///
    /// Fetches a FIS auth token, then sends the model info request. A 200 response has its ETag
    /// saved as the model hash; a 304 response is treated as success with nothing to save.
    ///
    /// - Parameter completion: Called with `nil` on a 200 or 304 response, with `.notFound` for
    ///   any other status code, and with `.internalError` when the retriever was deallocated,
    ///   the token could not be retrieved, the request failed, the response was not an HTTP
    ///   response, or a 200 response lacked an ETag header or a body.
    func downloadModelInfo(completion: @escaping (DownloadError?) -> Void) {
      // Get FIS token.
      installations.authToken { [weak self] tokenResult, _ in
        guard let self = self else {
          completion(.internalError(description: ModelInfoRetriever.selfDeallocatedErrorDescription))
          return
        }
        guard let result = tokenResult else {
          completion(.internalError(description: ModelInfoRetriever.tokenErrorDescription))
          return
        }

        // Get model info fetch URL with appropriate HTTP headers.
        let request = self.getModelInfoFetchURLRequest(token: result.authToken)

        // Download model info.
        // TODO: Consider moving request to a separate method
        let session = URLSession(configuration: .ephemeral)
        let dataTask = session.dataTask(with: request) { [weak self] data, response, error in
          guard let self = self else {
            completion(
              .internalError(description: ModelInfoRetriever.selfDeallocatedErrorDescription)
            )
            return
          }
          if let downloadError = error {
            completion(.internalError(description: downloadError.localizedDescription))
            return
          }
          guard let httpResponse = response as? HTTPURLResponse else {
            completion(
              .internalError(description: ModelInfoRetriever.invalidHTTPResponseErrorDescription)
            )
            return
          }

          // TODO: Handle more http status codes
          switch httpResponse.statusCode {
          case 200:
            // HTTP header names are case-insensitive, but a plain subscript on the bridged
            // dictionary only matches the exact spelling, so compare the names ignoring case.
            let etagField = httpResponse.allHeaderFields.first { field in
              (field.key as? String)?
                .caseInsensitiveCompare(ModelInfoRetriever.etagHTTPHeader) == .orderedSame
            }
            guard let modelHash = etagField?.value as? String else {
              completion(
                .internalError(description: ModelInfoRetriever.missingModelHashErrorDescription)
              )
              return
            }
            guard let data = data else {
              completion(
                .internalError(description: ModelInfoRetriever.invalidHTTPResponseErrorDescription)
              )
              return
            }
            self.saveModelInfo(data: data, modelHash: modelHash)
            completion(nil)
          case 304:
            completion(nil)
          default:
            completion(.notFound)
          }
        }
        dataTask.resume()
        // The session exists only for this one request: invalidate it so it is released once the
        // task finishes instead of lingering. The already-started task still runs to completion.
        session.finishTasksAndInvalidate()
      }
    }

    /// Saves model info to user defaults.
    ///
    /// Only the model hash is stored for now; `data` is not used yet. Nothing is stored when
    /// `modelInfo` is `nil`.
    ///
    /// - Parameters:
    ///   - data: Body of the server response.
    ///   - modelHash: Model hash taken from the response's ETag header.
    func saveModelInfo(data: Data, modelHash: String) {
      // TODO: Save model info to user defaults
      modelInfo?.modelHash = modelHash
    }
  }
  /// Named user defaults for FirebaseML.
  extension UserDefaults {
    /// User defaults for the `com.google.firebase.ml` suite.
    ///
    /// `UserDefaults(suiteName:)` is failable; rather than crashing on a force-unwrap, this
    /// falls back to `UserDefaults.standard` when the suite cannot be created.
    static var firebaseMLDefaults: UserDefaults {
      let suiteName = "com.google.firebase.ml"
      return UserDefaults(suiteName: suiteName) ?? UserDefaults.standard
    }
  }
  188. }