ModelDownloader.swift 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  16. import FirebaseInstallations
  17. /// Possible errors with model downloading.
  18. public enum DownloadError: Error, Equatable {
  19. /// No model with this name found on server.
  20. case notFound
  21. /// Caller does not have necessary permissions for this operation.
  22. case permissionDenied
  23. /// Conditions not met to perform download.
  24. case failedPrecondition
  25. /// Not enough space for model on device.
  26. case notEnoughSpace
  27. /// Malformed model name.
  28. case invalidArgument
  29. /// Other errors with description.
  30. case internalError(description: String)
  31. }
  32. /// Possible errors with locating model on device.
  33. public enum DownloadedModelError: Error {
  34. /// File system error.
  35. case fileIOError(description: String)
  36. /// Model not found on device.
  37. case notFound
  38. /// Other errors with description.
  39. case internalError(description: String)
  40. }
  41. /// Extension to handle internally meaningful errors.
  42. extension DownloadError {
  43. static let expiredDownloadURL: DownloadError = {
  44. DownloadError.internalError(description: "Expired model download URL.")
  45. }()
  46. }
  47. /// Possible ways to get a custom model.
  48. public enum ModelDownloadType {
  49. /// Get local model stored on device.
  50. case localModel
  51. /// Get local model on device and update to latest model from server in the background.
  52. case localModelUpdateInBackground
  53. /// Get latest model from server.
  54. case latestModel
  55. }
  56. /// Downloader to manage custom model downloads.
  57. public class ModelDownloader {
  58. /// Name of the app associated with this instance of ModelDownloader.
  59. private let appName: String
  60. /// Current Firebase app options.
  61. private let options: FirebaseOptions
  62. /// Installations instance for current Firebase app.
  63. private let installations: Installations
  64. /// User defaults for model info.
  65. private let userDefaults: UserDefaults
  66. /// Telemetry logger tied to this instance of model downloader.
  67. let telemetryLogger: TelemetryLogger?
  68. /// Number of retries in case of model download URL expiry.
  69. var numberOfRetries: Int = 1
  70. /// Handler that always runs on the main thread
  71. let mainQueueHandler = { handler in
  72. DispatchQueue.main.async {
  73. handler
  74. }
  75. }
  76. /// Download task associated with the model currently being downloaded.
  77. private var currentDownloadTask: [String: ModelDownloadTask] = [:]
  78. /// DispatchQueue to manage download task dictionary.
  79. let taskSerialQueue = DispatchQueue(label: "downloadtask.serial.queue")
  80. /// Shared dictionary mapping app name to a specific instance of model downloader.
  81. // TODO: Switch to using Firebase components.
  82. private static var modelDownloaderDictionary: [String: ModelDownloader] = [:]
  83. /// Private init for downloader.
  84. private init(app: FirebaseApp, defaults: UserDefaults = .firebaseMLDefaults) {
  85. appName = app.name
  86. options = app.options
  87. installations = Installations.installations(app: app)
  88. /// Respect Firebase-wide data collection setting.
  89. telemetryLogger = TelemetryLogger(app: app)
  90. userDefaults = defaults
  91. let notificationName = "FIRAppDeleteNotification"
  92. NotificationCenter.default.addObserver(
  93. self,
  94. selector: #selector(deleteModelDownloader),
  95. name: Notification.Name(notificationName),
  96. object: nil
  97. )
  98. }
  99. /// Handles app deletion notification.
  100. @objc private func deleteModelDownloader(notification: Notification) {
  101. let userInfoKey = "FIRAppNameKey"
  102. if let userInfo = notification.userInfo,
  103. let appName = userInfo[userInfoKey] as? String {
  104. ModelDownloader.modelDownloaderDictionary.removeValue(forKey: appName)
  105. // TODO: Do we need to force deinit downloader instance?
  106. // TODO: Clean up user defaults
  107. // TODO: Clean up local instances of app
  108. DeviceLogger.logEvent(level: .debug,
  109. message: ModelDownloader.DebugDescription.deleteModelDownloader,
  110. messageCode: .downloaderInstanceDeleted)
  111. }
  112. }
  113. /// Model downloader with default app.
  114. public static func modelDownloader() -> ModelDownloader {
  115. guard let defaultApp = FirebaseApp.app() else {
  116. fatalError(ModelDownloader.ErrorDescription.defaultAppNotConfigured)
  117. }
  118. return modelDownloader(app: defaultApp)
  119. }
  120. /// Model Downloader with custom app.
  121. public static func modelDownloader(app: FirebaseApp) -> ModelDownloader {
  122. if let downloader = modelDownloaderDictionary[app.name] {
  123. DeviceLogger.logEvent(level: .debug,
  124. message: ModelDownloader.DebugDescription.retrieveModelDownloader,
  125. messageCode: .downloaderInstanceRetrieved)
  126. return downloader
  127. } else {
  128. let downloader = ModelDownloader(app: app)
  129. modelDownloaderDictionary[app.name] = downloader
  130. DeviceLogger.logEvent(level: .debug,
  131. message: ModelDownloader.DebugDescription.createModelDownloader,
  132. messageCode: .downloaderInstanceCreated)
  133. return downloader
  134. }
  135. }
  136. /// Downloads a custom model to device or gets a custom model already on device, w/ optional handler for progress.
  137. public func getModel(name modelName: String,
  138. downloadType: ModelDownloadType,
  139. conditions: ModelDownloadConditions,
  140. progressHandler: ((Float) -> Void)? = nil,
  141. completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
  142. switch downloadType {
  143. case .localModel:
  144. if let localModel = getLocalModel(modelName: modelName) {
  145. DeviceLogger.logEvent(level: .debug,
  146. message: ModelDownloader.DebugDescription.localModelFound,
  147. messageCode: .localModelFound)
  148. mainQueueHandler(completion(.success(localModel)))
  149. } else {
  150. getRemoteModel(
  151. modelName: modelName,
  152. conditions: conditions,
  153. progressHandler: progressHandler,
  154. completion: completion
  155. )
  156. }
  157. case .localModelUpdateInBackground:
  158. if let localModel = getLocalModel(modelName: modelName) {
  159. DeviceLogger.logEvent(level: .debug,
  160. message: ModelDownloader.DebugDescription.localModelFound,
  161. messageCode: .localModelFound)
  162. mainQueueHandler(completion(.success(localModel)))
  163. telemetryLogger?.logModelDownloadEvent(
  164. eventName: .modelDownload,
  165. status: .scheduled,
  166. downloadErrorCode: .noError
  167. )
  168. DispatchQueue.global(qos: .utility).async { [weak self] in
  169. self?.getRemoteModel(
  170. modelName: modelName,
  171. conditions: conditions,
  172. progressHandler: nil,
  173. completion: { result in
  174. switch result {
  175. case let .success(model):
  176. DeviceLogger.logEvent(level: .debug,
  177. message: ModelDownloader.DebugDescription
  178. .backgroundModelDownloaded,
  179. messageCode: .backgroundModelDownloaded)
  180. self?.telemetryLogger?.logModelDownloadEvent(
  181. eventName: .modelDownload,
  182. status: .succeeded,
  183. model: model,
  184. downloadErrorCode: .noError
  185. )
  186. case .failure:
  187. DeviceLogger.logEvent(level: .debug,
  188. message: ModelDownloader.ErrorDescription
  189. .backgroundModelDownload,
  190. messageCode: .backgroundDownloadError)
  191. self?.telemetryLogger?.logModelDownloadEvent(
  192. eventName: .modelDownload,
  193. status: .failed,
  194. downloadErrorCode: .downloadFailed
  195. )
  196. }
  197. }
  198. )
  199. }
  200. } else {
  201. getRemoteModel(
  202. modelName: modelName,
  203. conditions: conditions,
  204. progressHandler: progressHandler,
  205. completion: completion
  206. )
  207. }
  208. case .latestModel:
  209. getRemoteModel(
  210. modelName: modelName,
  211. conditions: conditions,
  212. progressHandler: progressHandler,
  213. completion: completion
  214. )
  215. }
  216. }
  217. /// Gets all downloaded models.
  218. public func listDownloadedModels(completion: @escaping (Result<Set<CustomModel>,
  219. DownloadedModelError>) -> Void) {
  220. do {
  221. let modelPaths = try ModelFileManager.contentsOfModelsDirectory()
  222. var customModels = Set<CustomModel>()
  223. for path in modelPaths {
  224. guard let modelName = ModelFileManager.getModelNameFromFilePath(path) else {
  225. let description = ModelDownloader.ErrorDescription.parseModelName(path.absoluteString)
  226. DeviceLogger.logEvent(level: .debug,
  227. message: description,
  228. messageCode: .modelNameParseError)
  229. mainQueueHandler(completion(.failure(.internalError(description: description))))
  230. return
  231. }
  232. guard let modelInfo = getLocalModelInfo(modelName: modelName) else {
  233. let description = ModelDownloader.ErrorDescription.noLocalModelInfo(modelName)
  234. DeviceLogger.logEvent(level: .debug,
  235. message: description,
  236. messageCode: .noLocalModelInfo)
  237. mainQueueHandler(completion(.failure(.internalError(description: description))))
  238. return
  239. }
  240. guard let pathURL = URL(string: modelInfo.path)?.standardizedFileURL,
  241. pathURL == path.standardizedFileURL else {
  242. DeviceLogger.logEvent(level: .debug,
  243. message: ModelDownloader.ErrorDescription.outdatedModelPath,
  244. messageCode: .outdatedModelPathError)
  245. mainQueueHandler(completion(.failure(.internalError(description: ModelDownloader
  246. .ErrorDescription.outdatedModelPath))))
  247. return
  248. }
  249. let model = CustomModel(localModelInfo: modelInfo)
  250. customModels.insert(model)
  251. }
  252. DeviceLogger.logEvent(level: .debug,
  253. message: ModelDownloader.DebugDescription.allLocalModelsFound,
  254. messageCode: .allLocalModelsFound)
  255. completion(.success(customModels))
  256. } catch let error as DownloadedModelError {
  257. DeviceLogger.logEvent(level: .debug,
  258. message: ModelDownloader.ErrorDescription.listModelsFailed(error),
  259. messageCode: .listModelsError)
  260. mainQueueHandler(completion(.failure(error)))
  261. } catch {
  262. DeviceLogger.logEvent(level: .debug,
  263. message: ModelDownloader.ErrorDescription.listModelsFailed(error),
  264. messageCode: .listModelsError)
  265. mainQueueHandler(completion(.failure(.internalError(description: error
  266. .localizedDescription))))
  267. }
  268. }
  269. /// Deletes a custom model from device.
  270. public func deleteDownloadedModel(name modelName: String,
  271. completion: @escaping (Result<Void, DownloadedModelError>)
  272. -> Void) {
  273. guard let localModelInfo = getLocalModelInfo(modelName: modelName),
  274. let localPath = URL(string: localModelInfo.path)
  275. else {
  276. DeviceLogger.logEvent(level: .debug,
  277. message: ModelDownloader.ErrorDescription.modelNotFound(modelName),
  278. messageCode: .modelNotFound)
  279. mainQueueHandler(completion(.failure(.notFound)))
  280. return
  281. }
  282. do {
  283. try ModelFileManager.removeFile(at: localPath)
  284. DeviceLogger.logEvent(level: .debug,
  285. message: ModelDownloader.DebugDescription.modelDeleted,
  286. messageCode: .modelDeleted)
  287. mainQueueHandler(completion(.success(())))
  288. } catch let error as DownloadedModelError {
  289. mainQueueHandler(completion(.failure(error)))
  290. } catch {
  291. mainQueueHandler(completion(.failure(.internalError(description: error
  292. .localizedDescription))))
  293. }
  294. }
  295. }
  296. extension ModelDownloader {
  297. /// Return local model info only if the model info is available and the corresponding model file is already on device.
  298. private func getLocalModelInfo(modelName: String) -> LocalModelInfo? {
  299. guard let localModelInfo = LocalModelInfo(
  300. fromDefaults: userDefaults,
  301. name: modelName,
  302. appName: appName
  303. ) else {
  304. let description = ModelDownloader.DebugDescription.noLocalModelInfo(modelName)
  305. DeviceLogger.logEvent(level: .debug,
  306. message: description,
  307. messageCode: .noLocalModelInfo)
  308. return nil
  309. }
  310. /// There is local model info on device, but no model file at the expected path.
  311. guard let localPath = URL(string: localModelInfo.path),
  312. ModelFileManager.isFileReachable(at: localPath) else {
  313. // TODO: Consider deleting local model info in user defaults.
  314. return nil
  315. }
  316. return localModelInfo
  317. }
  318. /// Get model saved on device if available.
  319. private func getLocalModel(modelName: String) -> CustomModel? {
  320. guard let localModelInfo = getLocalModelInfo(modelName: modelName) else { return nil }
  321. let model = CustomModel(localModelInfo: localModelInfo)
  322. return model
  323. }
  324. /// Download and get model from server, unless the latest model is already available on device.
  325. private func getRemoteModel(modelName: String,
  326. conditions: ModelDownloadConditions,
  327. progressHandler: ((Float) -> Void)? = nil,
  328. completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
  329. let localModelInfo = getLocalModelInfo(modelName: modelName)
  330. guard let projectID = options.projectID, let apiKey = options.apiKey else {
  331. DeviceLogger.logEvent(level: .debug,
  332. message: ModelDownloader.ErrorDescription.invalidOptions,
  333. messageCode: .invalidOptions)
  334. completion(.failure(.internalError(description: ModelDownloader.ErrorDescription
  335. .invalidOptions)))
  336. return
  337. }
  338. let modelInfoRetriever = ModelInfoRetriever(
  339. modelName: modelName,
  340. projectID: projectID,
  341. apiKey: apiKey,
  342. installations: installations,
  343. appName: appName,
  344. localModelInfo: localModelInfo,
  345. telemetryLogger: telemetryLogger
  346. )
  347. let downloader = ModelFileDownloader(conditions: conditions)
  348. downloadInfoAndModel(
  349. modelName: modelName,
  350. modelInfoRetriever: modelInfoRetriever,
  351. downloader: downloader,
  352. conditions: conditions,
  353. progressHandler: progressHandler,
  354. completion: completion
  355. )
  356. }
  357. func downloadInfoAndModel(modelName: String,
  358. modelInfoRetriever: ModelInfoRetriever,
  359. downloader: FileDownloader,
  360. conditions: ModelDownloadConditions,
  361. progressHandler: ((Float) -> Void)? = nil,
  362. completion: @escaping (Result<CustomModel, DownloadError>)
  363. -> Void) {
  364. modelInfoRetriever.downloadModelInfo { result in
  365. switch result {
  366. case let .success(downloadModelInfoResult):
  367. switch downloadModelInfoResult {
  368. /// New model info was downloaded from server.
  369. case let .modelInfo(remoteModelInfo):
  370. let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
  371. if let progressHandler = progressHandler {
  372. self.mainQueueHandler(progressHandler(progress))
  373. }
  374. }
  375. let taskCompletion: ModelDownloadTask.Completion = { result in
  376. switch result {
  377. case let .success(model):
  378. self.mainQueueHandler(completion(.success(model)))
  379. case let .failure(error):
  380. switch error {
  381. case .notFound:
  382. self.mainQueueHandler(completion(.failure(.notFound)))
  383. case .invalidArgument:
  384. self.mainQueueHandler(completion(.failure(.invalidArgument)))
  385. case .permissionDenied:
  386. self.mainQueueHandler(completion(.failure(.permissionDenied)))
  387. /// This is the error returned when URL expired.
  388. case .expiredDownloadURL:
  389. /// Check if retries are allowed.
  390. guard self.numberOfRetries > 0 else {
  391. self
  392. .mainQueueHandler(
  393. completion(.failure(.internalError(description: ModelDownloader
  394. .ErrorDescription
  395. .expiredModelInfo)))
  396. )
  397. return
  398. }
  399. self.numberOfRetries -= 1
  400. DeviceLogger.logEvent(level: .debug,
  401. message: ModelDownloader.DebugDescription.retryDownload,
  402. messageCode: .retryDownload)
  403. self.downloadInfoAndModel(
  404. modelName: modelName,
  405. modelInfoRetriever: modelInfoRetriever,
  406. downloader: downloader,
  407. conditions: conditions,
  408. progressHandler: progressHandler,
  409. completion: completion
  410. )
  411. default:
  412. self.mainQueueHandler(completion(.failure(error)))
  413. }
  414. }
  415. self.taskSerialQueue.async {
  416. /// Stop keeping track of current download task.
  417. self.currentDownloadTask.removeValue(forKey: modelName)
  418. }
  419. }
  420. self.taskSerialQueue.sync {
  421. /// Merge duplicate requests if there is already a download in progress for the same model.
  422. if let downloadTask = self.currentDownloadTask[modelName],
  423. downloadTask.canMergeRequests() {
  424. downloadTask.merge(
  425. newProgressHandler: taskProgressHandler,
  426. newCompletion: taskCompletion
  427. )
  428. DeviceLogger.logEvent(level: .debug,
  429. message: ModelDownloader.DebugDescription.mergingRequests,
  430. messageCode: .mergeRequests)
  431. if downloadTask.canResume() {
  432. downloadTask.resume()
  433. }
  434. } else {
  435. let downloadTask = ModelDownloadTask(
  436. remoteModelInfo: remoteModelInfo,
  437. appName: self.appName,
  438. defaults: self.userDefaults,
  439. downloader: downloader,
  440. progressHandler: taskProgressHandler,
  441. completion: taskCompletion,
  442. telemetryLogger: self.telemetryLogger
  443. )
  444. /// Keep track of current download task to allow for merging duplicate requests.
  445. self.currentDownloadTask[modelName] = downloadTask
  446. downloadTask.resume()
  447. }
  448. }
  449. /// Local model info is the latest model info.
  450. case .notModified:
  451. guard let localModel = self.getLocalModel(modelName: modelName) else {
  452. /// This can only happen if either local model info or the model file was suddenly wiped out in the middle of model info request and server response
  453. // TODO: Consider handling: If model file is deleted after local model info is retrieved but before model info network call.
  454. self
  455. .mainQueueHandler(completion(.failure(.internalError(description: ModelDownloader
  456. .ErrorDescription.deletedLocalModelInfo))))
  457. return
  458. }
  459. self.mainQueueHandler(completion(.success(localModel)))
  460. }
  461. case let .failure(error):
  462. self.mainQueueHandler(completion(.failure(error)))
  463. }
  464. }
  465. }
  466. }
  467. /// Model downloader extension for testing.
  468. extension ModelDownloader {
  469. /// Model downloader instance for testing.
  470. // TODO: Consider using protocols
  471. static func modelDownloaderWithDefaults(_ defaults: UserDefaults,
  472. app: FirebaseApp) -> ModelDownloader {
  473. if let downloader = modelDownloaderDictionary[app.name] {
  474. return downloader
  475. } else {
  476. let downloader = ModelDownloader(app: app, defaults: defaults)
  477. modelDownloaderDictionary[app.name] = downloader
  478. return downloader
  479. }
  480. }
  481. }
  482. /// Possible error messages while using model downloader.
  483. extension ModelDownloader {
  484. /// Debug descriptions.
  485. private enum DebugDescription {
  486. static let deleteModelDownloader = "Model downloader instance deleted due to app deletion."
  487. static let retrieveModelDownloader =
  488. "Initialized with existing downloader instance associated with this app."
  489. static let createModelDownloader =
  490. "Initialized with new downloader instance associated with this app."
  491. static let localModelFound = "Found local model on device."
  492. static let backgroundModelDownloaded = "Downloaded latest model in the background."
  493. static let allLocalModelsFound = "Found and listed all local models."
  494. static let modelDeleted = "Model deleted successfully."
  495. static let retryDownload = "Retrying download."
  496. static let noLocalModelInfo = { (name: String) in
  497. "No local model info for model file named: \(name)."
  498. }
  499. static let mergingRequests = "Merging duplicate download requests."
  500. }
  501. /// Error descriptions.
  502. private enum ErrorDescription {
  503. static let defaultAppNotConfigured =
  504. "Default Firebase app not configured."
  505. static let parseModelName = { (path: String) in
  506. "List models failed due to unexpected model file name at \(path)."
  507. }
  508. static let invalidOptions = "Unable to retrieve project ID and/or API key for Firebase app."
  509. static let listModelsFailed = { (error: Error) in
  510. "Unable to list models, failed with error: \(error)"
  511. }
  512. static let modelNotFound = { (name: String) in
  513. "Model deletion failed due to no model found with name: \(name)"
  514. }
  515. static let noLocalModelInfo = { (name: String) in
  516. "List models failed due to no local model info for model file named: \(name)."
  517. }
  518. static let modelDownloadFailed = { (error: Error) in
  519. "Model download failed with error: \(error)"
  520. }
  521. static let modelInfoRetrievalFailed = { (error: Error) in
  522. "Model info retrieval failed with error: \(error)"
  523. }
  524. static let outdatedModelPath =
  525. "List models failed due to outdated model paths in local storage."
  526. static let deletedLocalModelInfo =
  527. "Model unavailable due to deleted local model info."
  528. static let backgroundModelDownload =
  529. "Failed to update model in background."
  530. static let expiredModelInfo = "Unable to update expired model info."
  531. }
  532. }