ModelInfoRetriever.swift 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. // Copyright 2020 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  16. import FirebaseInstallations
  /// Decoded body of the server's model info response: the model download URL, an expiry
  /// time (presumably of that URL — TODO confirm), and the model size in bytes, all as strings.
  struct ModelInfoResponse: Codable {
    /// Maps each property onto its key in the server's JSON payload.
    enum CodingKeys: String, CodingKey {
      case downloadURL = "downloadUri", expireTime, size = "sizeBytes"
    }

    var downloadURL: String
    var expireTime: String
    var size: String
  }
  /// Retrieves model info for one named model of a Firebase app, from local user defaults
  /// or from the server, and turns it into a `CustomModel` once the model file is on device.
  class ModelInfoRetriever: NSObject {
    /// Firebase app that owns the model.
    var app: FirebaseApp
    /// Model info for the model, if any has been loaded or fetched yet.
    var modelInfo: ModelInfo?
    /// Name of the model.
    var modelName: String
    /// Firebase installations instance for `app`.
    var installations: Installations
    /// User defaults associated with the model's info.
    var defaults: UserDefaults

    /// Creates a retriever bound to `app` and `modelName`.
    /// - Parameters:
    ///   - app: Firebase app the model belongs to.
    ///   - modelName: Name of the model to retrieve info for.
    ///   - defaults: User defaults to associate with the model info; `.firebaseMLDefaults` by default.
    init(app: FirebaseApp, modelName: String, defaults: UserDefaults = .firebaseMLDefaults) {
      installations = Installations.installations(app: app)
      self.app = app
      self.modelName = modelName
      self.defaults = defaults
    }

    /// Builds a custom model from the current model info.
    /// - Returns: `nil` unless model info exists and records a local path for the model file.
    func buildModel() -> CustomModel? {
      // Without model info there is nothing to describe.
      guard let info = modelInfo else { return nil }
      // A missing path means the model file is not on device yet.
      guard let path = info.path else { return nil }
      return CustomModel(name: info.name, size: info.size, path: path, hash: info.modelHash)
    }
  }
  /// Extension to handle fetching model info from server.
  extension ModelInfoRetriever {
    /// HTTP request headers.
    static let fisTokenHTTPHeader = "x-goog-firebase-installations-auth"
    static let hashMatchHTTPHeader = "if-none-match"
    static let bundleIDHTTPHeader = "x-ios-bundle-identifier"
    /// HTTP response headers.
    static let etagHTTPHeader = "Etag"
    /// Error descriptions.
    static let tokenErrorDescription = "Error retrieving FIS token."
    static let selfDeallocatedErrorDescription = "Self deallocated."
    static let missingModelHashErrorDescription = "Model hash missing in server response."
    static let invalidHTTPResponseErrorDescription =
      "Could not get a valid HTTP response from server."
    static let invalidModelInfoErrorDescription =
      "Could not decode model info from server response."

    /// URL used to fetch this model's info, built from the app's project ID, the model name
    /// and the app's API key. A missing project ID is substituted with an empty string.
    var modelInfoFetchURL: URL {
      let projectID = app.options.projectID ?? ""
      var components = URLComponents()
      components.scheme = "https"
      components.host = "firebaseml.googleapis.com"
      components.path = "/v1beta2/projects/\(projectID)/models/\(modelName):download"
      components.queryItems = [URLQueryItem(name: "key", value: app.options.apiKey)]
      // `URLComponents.url` returns nil only when a host is set and the path does not start
      // with "/"; the path above always does, so nil here would be a programmer error.
      guard let url = components.url else {
        preconditionFailure("Could not build model info fetch URL for model \(modelName).")
      }
      return url
    }

    /// Builds the GET request for model info.
    /// - Parameter token: Firebase Installations auth token sent in `fisTokenHTTPHeader`.
    /// - Returns: A request carrying the bundle ID, the auth token, and, when a model hash is
    ///   already known, that hash in `if-none-match` so the server can answer 304 if unchanged.
    func getModelInfoFetchURLRequest(token: String) -> URLRequest {
      var request = URLRequest(url: modelInfoFetchURL)
      request.httpMethod = "GET"
      // TODO: Check if bundle ID needs to be part of the request header.
      let bundleID = Bundle.main.bundleIdentifier ?? ""
      request.setValue(bundleID, forHTTPHeaderField: ModelInfoRetriever.bundleIDHTTPHeader)
      request.setValue(token, forHTTPHeaderField: ModelInfoRetriever.fisTokenHTTPHeader)
      if let info = modelInfo, !info.modelHash.isEmpty {
        request.setValue(info.modelHash, forHTTPHeaderField: ModelInfoRetriever.hashMatchHTTPHeader)
      }
      return request
    }

    /// Gets a Firebase Installations auth token.
    /// - Parameter completion: Called with the token, or with an internal error described by
    ///   `tokenErrorDescription` when no token result is returned.
    func getAuthToken(completion: @escaping (Result<String, DownloadError>) -> Void) {
      installations.authToken { tokenResult, _ in
        guard let result = tokenResult else {
          completion(.failure(.internalError(description: ModelInfoRetriever.tokenErrorDescription)))
          return
        }
        completion(.success(result.authToken))
      }
    }

    /// Gets model info from the server and, on a fresh (200) response, saves it.
    /// - Parameter completion: Called with `nil` when model info was saved or is unchanged (304),
    ///   `.notFound` on 404, and an internal error otherwise. The call may come on a background
    ///   queue (URLSession completion) — NOTE(review): callers should not assume the main thread.
    func downloadModelInfo(completion: @escaping (DownloadError?) -> Void) {
      getAuthToken { result in
        switch result {
        case let .success(authToken):
          let request = self.getModelInfoFetchURLRequest(token: authToken)
          // TODO: revisit using ephemeral session with Etag
          let session = URLSession(configuration: .ephemeral)
          let dataTask = session.dataTask(with: request) { [weak self] data, response, error in
            guard let self = self else {
              completion(.internalError(description: ModelInfoRetriever
                  .selfDeallocatedErrorDescription))
              return
            }
            completion(self.processModelInfoResponse(data: data, response: response, error: error))
          }
          dataTask.resume()
          // A session created per request is kept alive until invalidated; this lets the
          // in-flight task finish and then releases the session instead of leaking it.
          session.finishTasksAndInvalidate()
        case let .failure(error):
          completion(error)
        }
      }
    }

    /// Interprets a model info server response, saving model info when the server sends it.
    /// - Returns: `nil` when model info was saved or is unchanged, otherwise the error to report.
    private func processModelInfoResponse(data: Data?,
                                          response: URLResponse?,
                                          error: Error?) -> DownloadError? {
      if let downloadError = error {
        return .internalError(description: downloadError.localizedDescription)
      }
      guard let httpResponse = response as? HTTPURLResponse else {
        return .internalError(description: ModelInfoRetriever.invalidHTTPResponseErrorDescription)
      }
      switch httpResponse.statusCode {
      case 200:
        guard let modelHash = ModelInfoRetriever.headerValue(
          ModelInfoRetriever.etagHTTPHeader,
          in: httpResponse
        ) else {
          return .internalError(description: ModelInfoRetriever.missingModelHashErrorDescription)
        }
        guard let data = data else {
          return .internalError(description: ModelInfoRetriever.invalidHTTPResponseErrorDescription)
        }
        // Previously a body that failed to decode was still reported as success.
        guard saveModelInfo(data: data, modelHash: modelHash) else {
          return .internalError(description: ModelInfoRetriever.invalidModelInfoErrorDescription)
        }
        return nil
      case 304:
        // Not Modified: the hash sent in `if-none-match` still matches the server's model.
        return nil
      case 404:
        return .notFound
      // TODO: Handle more http status codes
      default:
        return .internalError(
          description: "Server returned with error - \(httpResponse.statusCode)."
        )
      }
    }

    /// Looks up a response header value. HTTP field names are case-insensitive, but lookups on
    /// the bridged `allHeaderFields` dictionary may not be, so an exact match is tried first and
    /// a case-insensitive scan is used as a fallback.
    private static func headerValue(_ field: String, in response: HTTPURLResponse) -> String? {
      if let exact = response.allHeaderFields[field] as? String {
        return exact
      }
      return response.allHeaderFields.first(where: {
        ($0.key as? String)?.caseInsensitiveCompare(field) == .orderedSame
      })?.value as? String
    }

    /// Decodes model info from a server response body and stores it as this retriever's
    /// `modelInfo` (presumably persisted through `defaults` by `ModelInfo` — not shown here).
    /// - Parameters:
    ///   - data: JSON body of the server response.
    ///   - modelHash: Model hash taken from the response's ETag header.
    /// - Returns: `true` if the body decoded and `modelInfo` was updated, `false` otherwise.
    @discardableResult
    func saveModelInfo(data: Data, modelHash: String) -> Bool {
      guard let modelInfoJSON = try? JSONDecoder().decode(ModelInfoResponse.self, from: data)
      else { return false }
      let modelInfo = ModelInfo(app: app, name: modelName, defaults: defaults)
      modelInfo.downloadURL = modelInfoJSON.downloadURL
      // TODO: Possibly improve handling invalid server responses.
      modelInfo.size = Int(modelInfoJSON.size) ?? 0
      modelInfo.modelHash = modelHash
      self.modelInfo = modelInfo
      return true
    }
  }