ModelDownloader.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  16. import FirebaseInstallations
  /// Failures that can occur while fetching a custom model from the server.
  public enum DownloadError: Error {
    /// The server has no model registered under the requested name.
    case notFound

    /// The caller lacks the permissions required for this operation.
    case permissionDenied

    /// The conditions required to start the download were not satisfied.
    case failedPrecondition

    /// The device does not have enough free space to store the model.
    case notEnoughSpace

    /// The supplied model name is malformed.
    case invalidArgument

    /// Any other failure; `description` carries the details.
    case internalError(description: String)
  }
  /// Failures that can occur while locating or managing models already stored on device.
  public enum DownloadedModelError: Error {
    /// A file-system operation failed; `description` carries the details.
    case fileIOError(description: String)

    /// No matching model exists on device.
    case notFound

    /// Any other failure; `description` carries the details.
    case internalError(description: String)
  }
  /// Strategy used when asking for a custom model.
  public enum ModelDownloadType {
    /// Use the copy already stored on device.
    case localModel

    /// Use the on-device copy now, then refresh it from the server in the background.
    case localModelUpdateInBackground

    /// Always ask the server for the newest version.
    case latestModel
  }
  /// Downloader to manage custom model downloads.
  ///
  /// One instance is cached per Firebase app name; obtain it through `modelDownloader()` or
  /// `modelDownloader(app:)`.
  public class ModelDownloader {
    /// Name of the app associated with this instance of ModelDownloader.
    private let appName: String
    /// Current Firebase app options.
    private let options: FirebaseOptions
    /// Installations instance for current Firebase app.
    private let installations: Installations
    /// User defaults for model info.
    private let userDefaults: UserDefaults
    /// Telemetry logger tied to this instance of model downloader.
    let telemetryLogger: TelemetryLogger?
    /// Shared dictionary mapping app name to a specific instance of model downloader.
    // TODO: Switch to using Firebase components.
    private static var modelDownloaderDictionary: [String: ModelDownloader] = [:]

    /// Private init for downloader.
    /// - Parameters:
    ///   - app: Firebase app whose name, options and installations back this downloader.
    ///   - defaults: Store passed to model-info lookups and download tasks.
    private init(app: FirebaseApp, defaults: UserDefaults = .firebaseMLDefaults) {
      appName = app.name
      options = app.options
      installations = Installations.installations(app: app)
      // Respect Firebase-wide data collection setting.
      telemetryLogger = TelemetryLogger(app: app)
      userDefaults = defaults
      // Evict cached downloaders when a Firebase app is deleted.
      let notificationName = "FIRAppDeleteNotification"
      NotificationCenter.default.addObserver(
        self,
        selector: #selector(deleteModelDownloader),
        name: Notification.Name(notificationName),
        object: nil
      )
    }

    /// Handles app deletion notification by removing the deleted app's cached downloader.
    /// - Parameter notification: Expected to carry the deleted app's name under "FIRAppNameKey".
    @objc private func deleteModelDownloader(notification: Notification) {
      let userInfoKey = "FIRAppNameKey"
      guard let deletedAppName = notification.userInfo?[userInfoKey] as? String else { return }
      ModelDownloader.modelDownloaderDictionary.removeValue(forKey: deletedAppName)
      // TODO: Clean up user defaults
      // TODO: Clean up local instances of app
    }

    /// Model downloader with default app.
    /// - Returns: The cached downloader for the default Firebase app.
    /// - Note: Traps via `fatalError` when no default app has been configured.
    public static func modelDownloader() -> ModelDownloader {
      guard let defaultApp = FirebaseApp.app() else {
        fatalError(ModelDownloader.ErrorDescription.defaultAppNotConfigured)
      }
      return modelDownloader(app: defaultApp)
    }

    /// Model Downloader with custom app.
    /// - Parameter app: Firebase app to bind the downloader to.
    /// - Returns: The cached downloader for `app.name`, creating and caching one on first use.
    public static func modelDownloader(app: FirebaseApp) -> ModelDownloader {
      if let cached = modelDownloaderDictionary[app.name] {
        return cached
      }
      let downloader = ModelDownloader(app: app)
      modelDownloaderDictionary[app.name] = downloader
      return downloader
    }

    /// Downloads a custom model to device or gets a custom model already on device, w/ optional handler for progress.
    /// - Parameters:
    ///   - modelName: Name of the model to fetch.
    ///   - downloadType: Whether to use the on-device copy, use it and refresh it in the background,
    ///     or always ask the server for the latest model.
    ///   - conditions: Download conditions. NOTE(review): not read by this method body; confirm
    ///     whether it should be forwarded to the download.
    ///   - progressHandler: Optional progress callback for foreground downloads.
    ///   - completion: Receives the model or a `DownloadError`. Local hits are delivered on the main queue.
    public func getModel(name modelName: String, downloadType: ModelDownloadType,
                         conditions: ModelDownloadConditions,
                         progressHandler: ((Float) -> Void)? = nil,
                         completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
      switch downloadType {
      case .localModel, .localModelUpdateInBackground:
        guard let localModel = getLocalModel(modelName: modelName) else {
          // Nothing usable on device: fall back to the server.
          getRemoteModel(
            modelName: modelName,
            progressHandler: progressHandler,
            completion: completion
          )
          return
        }
        DispatchQueue.main.async {
          completion(.success(localModel))
        }
        if case .localModelUpdateInBackground = downloadType {
          updateModelInBackground(modelName: modelName)
        }
      case .latestModel:
        getRemoteModel(
          modelName: modelName,
          progressHandler: progressHandler,
          completion: completion
        )
      }
    }

    /// Refreshes `modelName` from the server on a utility-QoS queue.
    /// Failures are only logged, since the caller has already received the on-device model.
    private func updateModelInBackground(modelName: String) {
      DispatchQueue.global(qos: .utility).async { [weak self] in
        self?.getRemoteModel(modelName: modelName, progressHandler: nil) { result in
          guard case .failure = result else { return }
          DeviceLogger.logEvent(
            level: .info,
            category: .modelDownload,
            message: ModelDownloader.ErrorDescription.backgroundModelDownload,
            messageCode: .backgroundDownloadError
          )
        }
      }
    }

    /// Gets all downloaded models.
    /// - Parameter completion: Receives the set of on-device models, or a `DownloadedModelError`
    ///   if any file in the models directory cannot be matched to valid stored model info.
    ///   Always invoked asynchronously on the main queue, matching `getModel`.
    public func listDownloadedModels(completion: @escaping (Result<Set<CustomModel>,
      DownloadedModelError>) -> Void) {
      let result = collectDownloadedModels()
      DispatchQueue.main.async {
        completion(result)
      }
    }

    /// Builds the set of models found in the models directory.
    /// Fails fast if a file name cannot be parsed, has no stored info, or its stored path is stale.
    private func collectDownloadedModels() -> Result<Set<CustomModel>, DownloadedModelError> {
      do {
        let modelPaths = try ModelFileManager.contentsOfModelsDirectory()
        var customModels = Set<CustomModel>()
        for path in modelPaths {
          guard let modelName = ModelFileManager.getModelNameFromFilePath(path) else {
            return .failure(.internalError(
              description: ModelDownloader.ErrorDescription.parseModelName(path.absoluteString)
            ))
          }
          guard let modelInfo = getLocalModelInfo(modelName: modelName) else {
            return .failure(.internalError(
              description: ModelDownloader.ErrorDescription.retrieveLocalModelInfo
            ))
          }
          // Stored info must point at this exact file; otherwise the record is outdated.
          guard modelInfo.path == path.absoluteString else {
            return .failure(.internalError(
              description: ModelDownloader.ErrorDescription.outdatedModelPath
            ))
          }
          customModels.insert(CustomModel(localModelInfo: modelInfo))
        }
        return .success(customModels)
      } catch let error as DownloadedModelError {
        return .failure(error)
      } catch {
        return .failure(.internalError(description: error.localizedDescription))
      }
    }

    /// Deletes a custom model from device.
    /// - Parameters:
    ///   - modelName: Name of the model whose file should be removed.
    ///   - completion: Receives success, `.notFound` when no reachable local copy exists, or the
    ///     file-removal error. Always invoked asynchronously on the main queue, matching `getModel`.
    public func deleteDownloadedModel(name modelName: String,
                                      completion: @escaping (Result<Void, DownloadedModelError>)
                                        -> Void) {
      let result = removeDownloadedModel(modelName: modelName)
      DispatchQueue.main.async {
        completion(result)
      }
    }

    /// Removes the on-device file for `modelName` and reports the outcome.
    private func removeDownloadedModel(modelName: String) -> Result<Void, DownloadedModelError> {
      guard let localModelInfo = getLocalModelInfo(modelName: modelName),
            let localPath = URL(string: localModelInfo.path) else {
        return .failure(.notFound)
      }
      do {
        try ModelFileManager.removeFile(at: localPath)
        return .success(())
      } catch let error as DownloadedModelError {
        return .failure(error)
      } catch {
        return .failure(.internalError(description: error.localizedDescription))
      }
    }
  }
  extension ModelDownloader {
    /// Returns the stored info for `modelName`, but only if the file it points to is reachable on device.
    private func getLocalModelInfo(modelName: String) -> LocalModelInfo? {
      let storedInfo = LocalModelInfo(
        fromDefaults: userDefaults,
        name: modelName,
        appName: appName
      )
      // Info can outlive its file; a record whose file is missing counts as "no local model".
      // TODO: Consider deleting local model info in user defaults.
      guard let info = storedInfo,
            let fileURL = URL(string: info.path),
            ModelFileManager.isFileReachable(at: fileURL) else {
        return nil
      }
      return info
    }

    /// Wraps the reachable on-device model info for `modelName` in a `CustomModel`, if there is one.
    private func getLocalModel(modelName: String) -> CustomModel? {
      return getLocalModelInfo(modelName: modelName).map { CustomModel(localModelInfo: $0) }
    }

    /// Asks the server for model info and either starts a download or serves the on-device copy.
    /// Retriever failures and "not modified" outcomes are delivered on the main queue; new
    /// downloads hand `completion` to `ModelDownloadTask`.
    private func getRemoteModel(modelName: String,
                                progressHandler: ((Float) -> Void)? = nil,
                                completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
      let cachedInfo = getLocalModelInfo(modelName: modelName)
      let retriever = ModelInfoRetriever(
        modelName: modelName,
        options: options,
        installations: installations,
        appName: appName,
        localModelInfo: cachedInfo
      )
      retriever.downloadModelInfo { outcome in
        let deliverOnMain: (Result<CustomModel, DownloadError>) -> Void = { result in
          DispatchQueue.main.async {
            completion(result)
          }
        }
        switch outcome {
        case let .failure(error):
          deliverOnMain(.failure(error))
        case .success(.notModified):
          // The server confirmed our copy is current; serve it unless it vanished meanwhile
          // (local info or file wiped between the request and the response).
          // TODO: Consider handling: If model file is deleted after local model info is retrieved but before model info network call.
          if let onDevice = self.getLocalModel(modelName: modelName) {
            deliverOnMain(.success(onDevice))
          } else {
            deliverOnMain(.failure(.internalError(
              description: ModelDownloader.ErrorDescription.deletedLocalModelInfo
            )))
          }
        case let .success(.modelInfo(remoteModelInfo)):
          // The server returned fresh model info, so fetch the model file.
          let task = ModelDownloadTask(
            remoteModelInfo: remoteModelInfo,
            appName: self.appName,
            defaults: self.userDefaults,
            modelInfoRetriever: retriever,
            telemetryLogger: self.telemetryLogger,
            progressHandler: progressHandler,
            completion: completion
          )
          // TODO: Is it possible for download URL to expire here?
          task.resumeModelDownload()
        }
      }
    }
  }
  /// Model downloader extension for testing.
  extension ModelDownloader {
    /// Returns the downloader cached for `app.name`, creating and caching one backed by `defaults` on a miss.
    ///
    /// `defaults` is used only when nothing is cached yet; an already-cached instance is returned as-is.
    // TODO: Consider using protocols
    static func modelDownloaderWithDefaults(_ defaults: UserDefaults,
                                            app: FirebaseApp) -> ModelDownloader {
      guard let cached = modelDownloaderDictionary[app.name] else {
        let fresh = ModelDownloader(app: app, defaults: defaults)
        modelDownloaderDictionary[app.name] = fresh
        return fresh
      }
      return cached
    }
  }
  /// Possible error messages while using model downloader.
  extension ModelDownloader {
    /// Fixed message text attached to downloader errors, traps and log events.
    private enum ErrorDescription {
      static let defaultAppNotConfigured = "Default Firebase app not configured."
      static let retrieveLocalModelInfo = "Failed to get stored model info for model file."
      static let outdatedModelPath = "Outdated model paths in local storage."
      static let deletedLocalModelInfo = "Model unavailable due to deleted local model info."

      /// Kept as `StaticString`, presumably because `DeviceLogger.logEvent(message:)` requires one — confirm.
      static let backgroundModelDownload: StaticString = "Failed to update model in background."

      /// Builds the message for a model file whose name could not be parsed from `path`.
      static let parseModelName: (String) -> String = { path in
        "Unable to parse model file name at \(path)."
      }
    }
  }