ModelInfoRetriever.swift 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  16. import FirebaseInstallations
  17. /// URL Session to use while retrieving model info.
  18. protocol ModelInfoRetrieverSession {
  19. func getModelInfo(with request: URLRequest,
  20. completion: @escaping (Data?, URLResponse?, Error?) -> Void)
  21. }
  22. /// Extension to customize data task requests.
  23. extension URLSession: ModelInfoRetrieverSession {
  24. func getModelInfo(with request: URLRequest,
  25. completion: @escaping (Data?, URLResponse?, Error?) -> Void) {
  26. let task = dataTask(with: request) { data, response, error in
  27. completion(data, response, error)
  28. }
  29. task.resume()
  30. }
  31. }
  32. /// Model info response object.
  33. private struct ModelInfoResponse: Codable {
  34. var downloadURL: String
  35. var urlExpiryTime: String
  36. var size: String
  37. }
  38. /// Properties for server response keys.
  39. private extension ModelInfoResponse {
  40. enum CodingKeys: String, CodingKey {
  41. case downloadURL = "downloadUri"
  42. case urlExpiryTime = "expireTime"
  43. case size = "sizeBytes"
  44. }
  45. }
  46. /// Downloading model info will return new model info only if it different from local model info.
  47. enum DownloadModelInfoResult {
  48. case notModified
  49. case modelInfo(RemoteModelInfo)
  50. }
  51. /// Model info retrieval error codes.
  52. enum ModelInfoErrorCode {
  53. case noError
  54. case noHash
  55. case httpError(code: Int)
  56. case connectionFailed
  57. case hashMismatch
  58. }
  59. /// Model info retriever for a model from local user defaults or server.
  60. class ModelInfoRetriever {
  61. /// Model name.
  62. private let modelName: String
  63. /// URL session for model info request.
  64. private let session: ModelInfoRetrieverSession
  65. /// Current Firebase app project ID.
  66. private let projectID: String
  67. /// Current Firebase app API key.
  68. private let apiKey: String
  69. /// Current Firebase app name.
  70. private let appName: String
  71. /// Local model info to validate model freshness.
  72. private let localModelInfo: LocalModelInfo?
  73. /// Telemetry logger.
  74. private let telemetryLogger: TelemetryLogger?
  75. /// Auth token provider.
  76. typealias AuthTokenProvider = (_ completion: @escaping (Result<String, DownloadError>) -> Void)
  77. -> Void
  78. private let authTokenProvider: AuthTokenProvider
  79. /// Associate model info retriever with current Firebase app, and model name.
  80. init(modelName: String,
  81. projectID: String,
  82. apiKey: String,
  83. authTokenProvider: @escaping AuthTokenProvider,
  84. appName: String,
  85. localModelInfo: LocalModelInfo? = nil,
  86. session: ModelInfoRetrieverSession? = nil,
  87. telemetryLogger: TelemetryLogger? = nil) {
  88. self.modelName = modelName
  89. self.projectID = projectID
  90. self.apiKey = apiKey
  91. self.appName = appName
  92. self.localModelInfo = localModelInfo
  93. self.authTokenProvider = authTokenProvider
  94. self.telemetryLogger = telemetryLogger
  95. if let urlSession = session {
  96. self.session = urlSession
  97. } else {
  98. self.session = URLSession(configuration: .ephemeral)
  99. }
  100. }
  101. convenience init(modelName: String,
  102. projectID: String,
  103. apiKey: String,
  104. installations: Installations,
  105. appName: String,
  106. localModelInfo: LocalModelInfo? = nil,
  107. session: ModelInfoRetrieverSession? = nil,
  108. telemetryLogger: TelemetryLogger? = nil) {
  109. self.init(modelName: modelName,
  110. projectID: projectID,
  111. apiKey: apiKey,
  112. authTokenProvider: ModelInfoRetriever.authTokenProvider(installation: installations),
  113. appName: appName,
  114. localModelInfo: localModelInfo,
  115. session: session,
  116. telemetryLogger: telemetryLogger)
  117. }
  118. private static func authTokenProvider(installation: Installations) -> AuthTokenProvider {
  119. return { completion in
  120. installation.authToken { tokenResult, error in
  121. guard let result = tokenResult
  122. else {
  123. completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
  124. .authToken)))
  125. return
  126. }
  127. completion(.success(result.authToken))
  128. }
  129. }
  130. }
  131. /// Get model info from server.
  132. func downloadModelInfo(completion: @escaping (Result<DownloadModelInfoResult, DownloadError>)
  133. -> Void) {
  134. authTokenProvider { result in
  135. switch result {
  136. /// Successfully received FIS token.
  137. case let .success(authToken):
  138. DeviceLogger.logEvent(level: .debug,
  139. message: ModelInfoRetriever.DebugDescription
  140. .receivedAuthToken,
  141. messageCode: .validAuthToken)
  142. /// Get model info fetch URL with appropriate HTTP headers.
  143. guard let request = self.getModelInfoFetchURLRequest(token: authToken) else {
  144. DeviceLogger.logEvent(level: .debug,
  145. message: ModelInfoRetriever.ErrorDescription
  146. .invalidModelInfoFetchURL,
  147. messageCode: .invalidModelInfoFetchURL)
  148. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  149. status: .failed,
  150. errorCode: .connectionFailed)
  151. completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
  152. .invalidModelInfoFetchURL)))
  153. return
  154. }
  155. /// Download model info.
  156. self.session.getModelInfo(with: request) {
  157. data, response, error in
  158. if let downloadError = error {
  159. let description = ModelInfoRetriever.ErrorDescription
  160. .failedModelInfoRetrieval(downloadError.localizedDescription)
  161. DeviceLogger.logEvent(level: .debug,
  162. message: description,
  163. messageCode: .modelInfoRetrievalError)
  164. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  165. status: .failed,
  166. errorCode: .connectionFailed)
  167. completion(.failure(.internalError(description: description)))
  168. } else {
  169. guard let httpResponse = response as? HTTPURLResponse else {
  170. DeviceLogger.logEvent(level: .debug,
  171. message: ModelInfoRetriever.ErrorDescription
  172. .invalidHTTPResponse,
  173. messageCode: .invalidHTTPResponse)
  174. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  175. status: .failed,
  176. errorCode: .connectionFailed)
  177. completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
  178. .invalidHTTPResponse)))
  179. return
  180. }
  181. DeviceLogger.logEvent(level: .debug,
  182. message: ModelInfoRetriever.DebugDescription
  183. .receivedServerResponse,
  184. messageCode: .validHTTPResponse)
  185. switch httpResponse.statusCode {
  186. case 200:
  187. guard let modelHash = httpResponse
  188. .allHeaderFields[ModelInfoRetriever.etagHTTPHeader] as? String else {
  189. DeviceLogger.logEvent(level: .debug,
  190. message: ModelInfoRetriever.ErrorDescription.missingModelHash,
  191. messageCode: .missingModelHash)
  192. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  193. status: .failed,
  194. errorCode: .noHash)
  195. completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
  196. .missingModelHash)))
  197. return
  198. }
  199. guard let data = data else {
  200. DeviceLogger.logEvent(level: .debug,
  201. message: ModelInfoRetriever.ErrorDescription
  202. .invalidHTTPResponse,
  203. messageCode: .invalidHTTPResponse)
  204. completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
  205. .invalidHTTPResponse)))
  206. return
  207. }
  208. do {
  209. let modelInfo = try self.getRemoteModelInfoFromResponse(data, modelHash: modelHash)
  210. DeviceLogger.logEvent(level: .debug,
  211. message: ModelInfoRetriever.DebugDescription
  212. .modelInfoDownloaded,
  213. messageCode: .modelInfoDownloaded)
  214. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  215. status: .updateAvailable,
  216. errorCode: .noError)
  217. completion(.success(.modelInfo(modelInfo)))
  218. } catch {
  219. let description = ModelInfoRetriever.ErrorDescription
  220. .invalidmodelInfoJSON(error.localizedDescription)
  221. DeviceLogger.logEvent(level: .debug,
  222. message: description,
  223. messageCode: .invalidModelInfoJSON)
  224. completion(
  225. .failure(.internalError(description: description))
  226. )
  227. }
  228. case 304:
  229. /// For this case to occur, local model info has to already be available on device.
  230. // TODO: Is this needed? Currently handles the case if model info disappears between request and response
  231. guard let localInfo = self.localModelInfo else {
  232. DeviceLogger.logEvent(level: .debug,
  233. message: ModelInfoRetriever.ErrorDescription
  234. .unexpectedModelInfoDeletion,
  235. messageCode: .modelInfoDeleted)
  236. completion(
  237. .failure(.internalError(description: ModelInfoRetriever.ErrorDescription
  238. .unexpectedModelInfoDeletion))
  239. )
  240. return
  241. }
  242. guard let modelHash = httpResponse
  243. .allHeaderFields[ModelInfoRetriever.etagHTTPHeader] as? String else {
  244. DeviceLogger.logEvent(level: .debug,
  245. message: ModelInfoRetriever.ErrorDescription
  246. .missingModelHash,
  247. messageCode: .noModelHash)
  248. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  249. status: .failed,
  250. errorCode: .noHash)
  251. completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
  252. .missingModelHash)))
  253. return
  254. }
  255. guard modelHash == localInfo.modelHash else {
  256. DeviceLogger.logEvent(level: .debug,
  257. message: ModelInfoRetriever.ErrorDescription
  258. .modelHashMismatch,
  259. messageCode: .modelHashMismatchError)
  260. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  261. status: .failed,
  262. errorCode: .hashMismatch)
  263. completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
  264. .modelHashMismatch)))
  265. return
  266. }
  267. DeviceLogger.logEvent(level: .debug,
  268. message: ModelInfoRetriever.DebugDescription
  269. .modelInfoUnmodified,
  270. messageCode: .modelInfoUnmodified)
  271. completion(.success(.notModified))
  272. case 400:
  273. let description = ModelInfoRetriever.ErrorDescription.invalidModelName(self.modelName)
  274. DeviceLogger.logEvent(level: .debug,
  275. message: description,
  276. messageCode: .invalidModelName)
  277. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  278. status: .failed,
  279. errorCode: .httpError(code: httpResponse
  280. .statusCode))
  281. completion(.failure(.invalidArgument))
  282. case 401, 403:
  283. DeviceLogger.logEvent(level: .debug,
  284. message: ModelInfoRetriever.ErrorDescription.permissionDenied,
  285. messageCode: .permissionDenied)
  286. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  287. status: .failed,
  288. errorCode: .httpError(code: httpResponse
  289. .statusCode))
  290. completion(.failure(.permissionDenied))
  291. case 404:
  292. let description = ModelInfoRetriever.ErrorDescription.modelNotFound(self.modelName)
  293. DeviceLogger.logEvent(level: .debug,
  294. message: description,
  295. messageCode: .modelNotFound)
  296. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  297. status: .failed,
  298. errorCode: .httpError(code: httpResponse
  299. .statusCode))
  300. completion(.failure(.notFound))
  301. // TODO: Handle more http status codes
  302. default:
  303. let description = ModelInfoRetriever.ErrorDescription
  304. .modelInfoRetrievalFailed(httpResponse.statusCode)
  305. DeviceLogger.logEvent(level: .debug,
  306. message: description,
  307. messageCode: .modelInfoRetrievalError)
  308. self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
  309. status: .failed,
  310. errorCode: .httpError(code: httpResponse
  311. .statusCode))
  312. completion(.failure(.internalError(description: description)))
  313. }
  314. }
  315. }
  316. /// FIS token error.
  317. case .failure:
  318. DeviceLogger.logEvent(level: .debug,
  319. message: ModelInfoRetriever.ErrorDescription
  320. .authToken,
  321. messageCode: .authTokenError)
  322. completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
  323. .authToken)))
  324. return
  325. }
  326. }
  327. }
  328. }
  329. /// Extension with helper methods to handle fetching model info from server.
  330. extension ModelInfoRetriever {
  331. /// HTTP request headers.
  332. private static let fisTokenHTTPHeader = "x-goog-firebase-installations-auth"
  333. private static let hashMatchHTTPHeader = "if-none-match"
  334. private static let bundleIDHTTPHeader = "x-ios-bundle-identifier"
  335. /// HTTP response headers.
  336. private static let etagHTTPHeader = "Etag"
  337. /// Construct model fetch base URL.
  338. var modelInfoFetchURL: URL? {
  339. var components = URLComponents()
  340. components.scheme = "https"
  341. components.host = "firebaseml.googleapis.com"
  342. components.path = "/v1beta2/projects/\(projectID)/models/\(modelName):download"
  343. components.queryItems = [URLQueryItem(name: "key", value: apiKey)]
  344. return components.url
  345. }
  346. /// Construct model fetch URL request.
  347. func getModelInfoFetchURLRequest(token: String) -> URLRequest? {
  348. guard let fetchURL = modelInfoFetchURL else { return nil }
  349. var request = URLRequest(url: fetchURL)
  350. request.httpMethod = "GET"
  351. // TODO: Check if bundle ID needs to be part of the request header.
  352. let bundleID = Bundle.main.bundleIdentifier ?? ""
  353. request.setValue(bundleID, forHTTPHeaderField: ModelInfoRetriever.bundleIDHTTPHeader)
  354. request.setValue(token, forHTTPHeaderField: ModelInfoRetriever.fisTokenHTTPHeader)
  355. /// Get model hash if local model info is available on device.
  356. if let modelInfo = localModelInfo {
  357. request.setValue(
  358. modelInfo.modelHash,
  359. forHTTPHeaderField: ModelInfoRetriever.hashMatchHTTPHeader
  360. )
  361. }
  362. return request
  363. }
  364. /// Parse date from string - used to get download URL expiry time.
  365. private static func getDateFromString(_ strDate: String) -> Date? {
  366. if #available(iOS 11, macOS 10.13, macCatalyst 13.0, tvOS 11.0, watchOS 4.0, *) {
  367. let dateFormatter = ISO8601DateFormatter()
  368. dateFormatter.timeZone = TimeZone(secondsFromGMT: 0)
  369. dateFormatter.formatOptions = [.withFractionalSeconds]
  370. return dateFormatter.date(from: strDate)
  371. } else {
  372. let dateFormatter = DateFormatter()
  373. dateFormatter.locale = Locale(identifier: "en-US_POSIX")
  374. dateFormatter.dateFormat = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"
  375. dateFormatter.timeZone = TimeZone(secondsFromGMT: 0)
  376. return dateFormatter.date(from: strDate)
  377. }
  378. }
  379. /// Return model info created from server response.
  380. private func getRemoteModelInfoFromResponse(_ data: Data,
  381. modelHash: String) throws -> RemoteModelInfo {
  382. let decoder = JSONDecoder()
  383. guard let modelInfoJSON = try? decoder.decode(ModelInfoResponse.self, from: data) else {
  384. throw DownloadError
  385. .internalError(description: ModelInfoRetriever.ErrorDescription.decodeModelInfoResponse)
  386. }
  387. // TODO: Possibly improve handling invalid server responses.
  388. guard let downloadURL = URL(string: modelInfoJSON.downloadURL) else {
  389. throw DownloadError
  390. .internalError(description: ModelInfoRetriever.ErrorDescription
  391. .invalidModelDownloadURL)
  392. }
  393. let modelHash = modelHash
  394. let size = Int(modelInfoJSON.size) ?? 0
  395. guard let expiryTime = ModelInfoRetriever.getDateFromString(modelInfoJSON.urlExpiryTime) else {
  396. throw DownloadError
  397. .internalError(description: ModelInfoRetriever.ErrorDescription
  398. .invalidModelDownloadURLExpiryTime)
  399. }
  400. return RemoteModelInfo(
  401. name: modelName,
  402. downloadURL: downloadURL,
  403. modelHash: modelHash,
  404. size: size,
  405. urlExpiryTime: expiryTime
  406. )
  407. }
  408. }
  409. /// Possible error messages for model info retrieval.
  410. extension ModelInfoRetriever {
  411. /// Debug descriptions.
  412. private enum DebugDescription {
  413. static let receivedServerResponse = "Received a valid response from model info server."
  414. static let receivedAuthToken = "Generated valid auth token."
  415. static let modelInfoDownloaded = "Successfully downloaded model info."
  416. static let modelInfoUnmodified = "Local model info matches the latest on server."
  417. }
  418. /// Error descriptions.
  419. private enum ErrorDescription {
  420. static let authToken = "Error retrieving auth token."
  421. static let selfDeallocated = "Self deallocated."
  422. static let missingModelHash = "Model hash missing in model info server response."
  423. static let invalidModelInfoFetchURL = "Unable to create URL to fetch model info."
  424. static let invalidHTTPResponse =
  425. "Could not get a valid HTTP response for model info retrieval."
  426. static let invalidmodelInfoJSON = { (error: String) in
  427. "Failed to parse model info: \(error)"
  428. }
  429. static let failedModelInfoRetrieval = { (error: String) in
  430. "Failed to retrieve model info: \(error)"
  431. }
  432. static let unexpectedModelInfoDeletion =
  433. "Model info was deleted unexpectedly."
  434. static let serverResponse = { (errorCode: Int) in
  435. "Server returned with HTTP error code: \(errorCode)."
  436. }
  437. static let invalidModelName = { (name: String) in
  438. "Invalid model name: \(name)"
  439. }
  440. static let modelNotFound = { (name: String) in
  441. "No model found with name: \(name)"
  442. }
  443. static let modelInfoRetrievalFailed = { (code: Int) in
  444. "Model info retrieval failed with HTTP error code: \(code)"
  445. }
  446. static let decodeModelInfoResponse =
  447. "Unable to decode model info response from server."
  448. static let invalidModelDownloadURL =
  449. "Invalid model download URL from server."
  450. static let invalidModelDownloadURLExpiryTime =
  451. "Invalid download URL expiry time from server."
  452. static let modelHashMismatch = "Unexpected model hash value."
  453. static let permissionDenied = "Invalid or missing permissions to retrieve model info."
  454. static let anotherDownloadInProgress = "Model info download already in progress."
  455. }
  456. }