ModelDownloader.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  16. import FirebaseInstallations
  17. /// Possible errors with model downloading.
  18. public enum DownloadError: Error, Equatable {
  19. /// No model with this name found on server.
  20. case notFound
  21. /// Caller does not have necessary permissions for this operation.
  22. case permissionDenied
  23. /// Conditions not met to perform download.
  24. case failedPrecondition
  25. /// Not enough space for model on device.
  26. case notEnoughSpace
  27. /// Malformed model name.
  28. case invalidArgument
  29. /// Other errors with description.
  30. case internalError(description: String)
  31. }
  32. /// Possible errors with locating model on device.
  33. public enum DownloadedModelError: Error, Equatable {
  34. /// File system error.
  35. case fileIOError(description: String)
  36. /// Model not found on device.
  37. case notFound
  38. /// Other errors with description.
  39. case internalError(description: String)
  40. }
  41. /// Possible ways to get a custom model.
  42. public enum ModelDownloadType {
  43. /// Get local model stored on device.
  44. case localModel
  45. /// Get local model on device and update to latest model from server in the background.
  46. case localModelUpdateInBackground
  47. /// Get latest model from server.
  48. case latestModel
  49. }
  50. /// Downloader to manage custom model downloads.
  51. public class ModelDownloader {
  52. /// Name of the app associated with this instance of ModelDownloader.
  53. private let appName: String
  54. /// Current Firebase app options.
  55. private let options: FirebaseOptions
  56. /// Installations instance for current Firebase app.
  57. private let installations: Installations
  58. /// User defaults for model info.
  59. private let userDefaults: UserDefaults
  60. /// Shared dictionary mapping app name to a specific instance of model downloader.
  61. // TODO: Switch to using Firebase components.
  62. private static var modelDownloaderDictionary: [String: ModelDownloader] = [:]
  63. /// Private init for downloader.
  64. private init(app: FirebaseApp, defaults: UserDefaults = .firebaseMLDefaults) {
  65. appName = app.name
  66. options = app.options
  67. installations = Installations.installations(app: app)
  68. userDefaults = defaults
  69. NotificationCenter.default.addObserver(
  70. self,
  71. selector: #selector(deleteModelDownloader),
  72. name: Notification.Name("FIRAppDeleteNotification"),
  73. object: nil
  74. )
  75. }
  76. /// Handles app deletion notification.
  77. @objc private func deleteModelDownloader(notification: Notification) {
  78. if let userInfo = notification.userInfo,
  79. let appName = userInfo["FIRAppNameKey"] as? String {
  80. ModelDownloader.modelDownloaderDictionary.removeValue(forKey: appName)
  81. // TODO: Clean up user defaults
  82. // TODO: Clean up local instances of app
  83. }
  84. }
  85. /// Model downloader with default app.
  86. public static func modelDownloader() -> ModelDownloader {
  87. guard let defaultApp = FirebaseApp.app() else {
  88. fatalError("Default Firebase app not configured.")
  89. }
  90. return modelDownloader(app: defaultApp)
  91. }
  92. /// Model Downloader with custom app.
  93. public static func modelDownloader(app: FirebaseApp) -> ModelDownloader {
  94. if let downloader = modelDownloaderDictionary[app.name] {
  95. return downloader
  96. } else {
  97. let downloader = ModelDownloader(app: app)
  98. modelDownloaderDictionary[app.name] = downloader
  99. return downloader
  100. }
  101. }
  102. /// Downloads a custom model to device or gets a custom model already on device, w/ optional handler for progress.
  103. public func getModel(name modelName: String, downloadType: ModelDownloadType,
  104. conditions: ModelDownloadConditions,
  105. progressHandler: ((Float) -> Void)? = nil,
  106. completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
  107. switch downloadType {
  108. case .localModel:
  109. if let localModel = getLocalModel(modelName: modelName) {
  110. DispatchQueue.main.async {
  111. completion(.success(localModel))
  112. }
  113. } else {
  114. getRemoteModel(
  115. modelName: modelName,
  116. progressHandler: progressHandler,
  117. completion: completion
  118. )
  119. }
  120. case .localModelUpdateInBackground:
  121. if let localModel = getLocalModel(modelName: modelName) {
  122. DispatchQueue.main.async {
  123. completion(.success(localModel))
  124. }
  125. DispatchQueue.global(qos: .utility).async { [weak self] in
  126. self?.getRemoteModel(
  127. modelName: modelName,
  128. progressHandler: nil,
  129. completion: { result in
  130. switch result {
  131. // TODO: Log outcome of background download
  132. case .success: break
  133. case .failure: break
  134. }
  135. }
  136. )
  137. }
  138. } else {
  139. getRemoteModel(
  140. modelName: modelName,
  141. progressHandler: progressHandler,
  142. completion: completion
  143. )
  144. }
  145. case .latestModel:
  146. getRemoteModel(
  147. modelName: modelName,
  148. progressHandler: progressHandler,
  149. completion: completion
  150. )
  151. }
  152. }
  153. /// Gets all downloaded models.
  154. public func listDownloadedModels(completion: @escaping (Result<Set<CustomModel>,
  155. DownloadedModelError>) -> Void) {
  156. do {
  157. let modelPaths = try ModelFileManager.contentsOfModelsDirectory()
  158. var customModels = Set<CustomModel>()
  159. for path in modelPaths {
  160. guard let modelName = ModelFileManager.getModelNameFromFilePath(path) else {
  161. completion(.failure(.internalError(description: "Invalid model file name.")))
  162. return
  163. }
  164. guard let modelInfo = getLocalModelInfo(modelName: modelName) else {
  165. completion(
  166. .failure(.internalError(description: "Failed to get model info for model file."))
  167. )
  168. return
  169. }
  170. guard modelInfo.path == path.absoluteString else {
  171. completion(
  172. .failure(.internalError(description: "Outdated model paths in local storage."))
  173. )
  174. return
  175. }
  176. let model = CustomModel(localModelInfo: modelInfo)
  177. customModels.insert(model)
  178. }
  179. completion(.success(customModels))
  180. } catch let error as DownloadedModelError {
  181. completion(.failure(error))
  182. } catch {
  183. completion(.failure(.internalError(description: error.localizedDescription)))
  184. }
  185. }
  186. /// Deletes a custom model from device.
  187. public func deleteDownloadedModel(name modelName: String,
  188. completion: @escaping (Result<Void, DownloadedModelError>)
  189. -> Void) {
  190. // TODO: Delete previously downloaded model
  191. guard let localModelInfo = getLocalModelInfo(modelName: modelName),
  192. let localPath = URL(string: localModelInfo.path)
  193. else {
  194. completion(.failure(.notFound))
  195. return
  196. }
  197. do {
  198. try ModelFileManager.removeFile(at: localPath)
  199. completion(.success(()))
  200. } catch let error as DownloadedModelError {
  201. completion(.failure(error))
  202. } catch {
  203. completion(.failure(.internalError(description: error.localizedDescription)))
  204. }
  205. }
  206. }
  207. extension ModelDownloader {
  208. /// Return local model info only if the model info is available and the corresponding model file is already on device.
  209. private func getLocalModelInfo(modelName: String) -> LocalModelInfo? {
  210. guard let localModelInfo = LocalModelInfo(
  211. fromDefaults: userDefaults,
  212. name: modelName,
  213. appName: appName
  214. ) else {
  215. return nil
  216. }
  217. /// There is local model info on device, but no model file at the expected path.
  218. guard let localPath = URL(string: localModelInfo.path),
  219. ModelFileManager.isFileReachable(at: localPath) else {
  220. // TODO: Consider deleting local model info in user defaults.
  221. return nil
  222. }
  223. return localModelInfo
  224. }
  225. /// Get model saved on device if available.
  226. private func getLocalModel(modelName: String) -> CustomModel? {
  227. guard let localModelInfo = getLocalModelInfo(modelName: modelName) else { return nil }
  228. let model = CustomModel(localModelInfo: localModelInfo)
  229. return model
  230. }
  231. /// Download and get model from server, unless the latest model is already available on device.
  232. private func getRemoteModel(modelName: String,
  233. progressHandler: ((Float) -> Void)? = nil,
  234. completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
  235. let localModelInfo = getLocalModelInfo(modelName: modelName)
  236. let modelInfoRetriever = ModelInfoRetriever(
  237. modelName: modelName,
  238. options: options,
  239. installations: installations,
  240. appName: appName,
  241. localModelInfo: localModelInfo
  242. )
  243. modelInfoRetriever.downloadModelInfo { result in
  244. switch result {
  245. case let .success(downloadModelInfoResult):
  246. switch downloadModelInfoResult {
  247. /// New model info was downloaded from server.
  248. case let .modelInfo(remoteModelInfo):
  249. let downloadTask = ModelDownloadTask(
  250. remoteModelInfo: remoteModelInfo,
  251. appName: self.appName,
  252. defaults: self.userDefaults,
  253. progressHandler: progressHandler,
  254. completion: completion
  255. )
  256. downloadTask.resumeModelDownload()
  257. /// Local model info is the latest model info.
  258. case .notModified:
  259. guard let localModel = self.getLocalModel(modelName: modelName) else {
  260. DispatchQueue.main.async {
  261. /// This can only happen if either local model info or the model file was suddenly wiped out in the middle of model info request and server response
  262. // TODO: Consider handling: if model file is deleted after local model info is retrieved but before model info network call
  263. completion(
  264. .failure(
  265. .internalError(description: "Model unavailable due to deleted local model info.")
  266. )
  267. )
  268. }
  269. return
  270. }
  271. DispatchQueue.main.async {
  272. completion(.success(localModel))
  273. }
  274. }
  275. case let .failure(downloadError):
  276. DispatchQueue.main.async {
  277. completion(.failure(downloadError))
  278. }
  279. }
  280. }
  281. }
  282. }
  283. /// Model downloader extension for testing.
  284. extension ModelDownloader {
  285. /// Model downloader instance for testing.
  286. // TODO: Consider using protocols
  287. static func modelDownloaderWithDefaults(_ defaults: UserDefaults,
  288. app: FirebaseApp) -> ModelDownloader {
  289. if let downloader = modelDownloaderDictionary[app.name] {
  290. return downloader
  291. } else {
  292. let downloader = ModelDownloader(app: app, defaults: defaults)
  293. modelDownloaderDictionary[app.name] = downloader
  294. return downloader
  295. }
  296. }
  297. }