ModelDownloader.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  16. import FirebaseInstallations
  17. /// Possible errors with model downloading.
  18. public enum DownloadError: Error, Equatable {
  19. /// No model with this name found on server.
  20. case notFound
  21. /// Caller does not have necessary permissions for this operation.
  22. case permissionDenied
  23. /// Conditions not met to perform download.
  24. case failedPrecondition
  25. /// Not enough space for model on device.
  26. case notEnoughSpace
  27. /// Malformed model name.
  28. case invalidArgument
  29. /// Other errors with description.
  30. case internalError(description: String)
  31. }
  32. /// Possible errors with locating model on device.
  33. public enum DownloadedModelError: Error, Equatable {
  34. /// File system error.
  35. case fileIOError(description: String)
  36. /// Model not found on device.
  37. case notFound
  38. /// Other errors with description.
  39. case internalError(description: String)
  40. }
  41. /// Possible ways to get a custom model.
  42. public enum ModelDownloadType {
  43. /// Get local model stored on device.
  44. case localModel
  45. /// Get local model on device and update to latest model from server in the background.
  46. case localModelUpdateInBackground
  47. /// Get latest model from server.
  48. case latestModel
  49. }
  50. /// Downloader to manage custom model downloads.
  51. public class ModelDownloader {
  52. /// Name of the app associated with this instance of ModelDownloader.
  53. private let appName: String
  54. /// Current Firebase app options.
  55. private let options: FirebaseOptions
  56. /// Installations instance for current Firebase app.
  57. private let installations: Installations
  58. /// User defaults for model info.
  59. private let userDefaults: UserDefaults
  60. /// Telemetry logger tied to this instance of model downloader.
  61. let telemetryLogger: TelemetryLogger?
  62. /// Shared dictionary mapping app name to a specific instance of model downloader.
  63. // TODO: Switch to using Firebase components.
  64. private static var modelDownloaderDictionary: [String: ModelDownloader] = [:]
  65. /// Private init for downloader.
  66. private init(app: FirebaseApp, defaults: UserDefaults = .firebaseMLDefaults) {
  67. appName = app.name
  68. options = app.options
  69. installations = Installations.installations(app: app)
  70. /// Respect Firebase-wide data collection setting.
  71. telemetryLogger = TelemetryLogger(app: app)
  72. userDefaults = defaults
  73. NotificationCenter.default.addObserver(
  74. self,
  75. selector: #selector(deleteModelDownloader),
  76. name: Notification.Name("FIRAppDeleteNotification"),
  77. object: nil
  78. )
  79. }
  80. /// Handles app deletion notification.
  81. @objc private func deleteModelDownloader(notification: Notification) {
  82. if let userInfo = notification.userInfo,
  83. let appName = userInfo["FIRAppNameKey"] as? String {
  84. ModelDownloader.modelDownloaderDictionary.removeValue(forKey: appName)
  85. // TODO: Clean up user defaults
  86. // TODO: Clean up local instances of app
  87. }
  88. }
  89. /// Model downloader with default app.
  90. public static func modelDownloader() -> ModelDownloader {
  91. guard let defaultApp = FirebaseApp.app() else {
  92. fatalError("Default Firebase app not configured.")
  93. }
  94. return modelDownloader(app: defaultApp)
  95. }
  96. /// Model Downloader with custom app.
  97. public static func modelDownloader(app: FirebaseApp) -> ModelDownloader {
  98. if let downloader = modelDownloaderDictionary[app.name] {
  99. return downloader
  100. } else {
  101. let downloader = ModelDownloader(app: app)
  102. modelDownloaderDictionary[app.name] = downloader
  103. return downloader
  104. }
  105. }
  106. /// Downloads a custom model to device or gets a custom model already on device, w/ optional handler for progress.
  107. public func getModel(name modelName: String, downloadType: ModelDownloadType,
  108. conditions: ModelDownloadConditions,
  109. progressHandler: ((Float) -> Void)? = nil,
  110. completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
  111. switch downloadType {
  112. case .localModel:
  113. if let localModel = getLocalModel(modelName: modelName) {
  114. DispatchQueue.main.async {
  115. completion(.success(localModel))
  116. }
  117. } else {
  118. getRemoteModel(
  119. modelName: modelName,
  120. progressHandler: progressHandler,
  121. completion: completion
  122. )
  123. }
  124. case .localModelUpdateInBackground:
  125. if let localModel = getLocalModel(modelName: modelName) {
  126. DispatchQueue.main.async {
  127. completion(.success(localModel))
  128. }
  129. DispatchQueue.global(qos: .utility).async { [weak self] in
  130. self?.getRemoteModel(
  131. modelName: modelName,
  132. progressHandler: nil,
  133. completion: { result in
  134. switch result {
  135. // TODO: Log outcome of background download
  136. case .success: break
  137. case .failure: break
  138. }
  139. }
  140. )
  141. }
  142. } else {
  143. getRemoteModel(
  144. modelName: modelName,
  145. progressHandler: progressHandler,
  146. completion: completion
  147. )
  148. }
  149. case .latestModel:
  150. getRemoteModel(
  151. modelName: modelName,
  152. progressHandler: progressHandler,
  153. completion: completion
  154. )
  155. }
  156. }
  157. /// Gets all downloaded models.
  158. public func listDownloadedModels(completion: @escaping (Result<Set<CustomModel>,
  159. DownloadedModelError>) -> Void) {
  160. do {
  161. let modelPaths = try ModelFileManager.contentsOfModelsDirectory()
  162. var customModels = Set<CustomModel>()
  163. for path in modelPaths {
  164. guard let modelName = ModelFileManager.getModelNameFromFilePath(path) else {
  165. completion(.failure(.internalError(description: "Invalid model file name.")))
  166. return
  167. }
  168. guard let modelInfo = getLocalModelInfo(modelName: modelName) else {
  169. completion(
  170. .failure(.internalError(description: "Failed to get model info for model file."))
  171. )
  172. return
  173. }
  174. guard modelInfo.path == path.absoluteString else {
  175. completion(
  176. .failure(.internalError(description: "Outdated model paths in local storage."))
  177. )
  178. return
  179. }
  180. let model = CustomModel(localModelInfo: modelInfo)
  181. customModels.insert(model)
  182. }
  183. completion(.success(customModels))
  184. } catch let error as DownloadedModelError {
  185. completion(.failure(error))
  186. } catch {
  187. completion(.failure(.internalError(description: error.localizedDescription)))
  188. }
  189. }
  190. /// Deletes a custom model from device.
  191. public func deleteDownloadedModel(name modelName: String,
  192. completion: @escaping (Result<Void, DownloadedModelError>)
  193. -> Void) {
  194. // TODO: Delete previously downloaded model
  195. guard let localModelInfo = getLocalModelInfo(modelName: modelName),
  196. let localPath = URL(string: localModelInfo.path)
  197. else {
  198. completion(.failure(.notFound))
  199. return
  200. }
  201. do {
  202. try ModelFileManager.removeFile(at: localPath)
  203. completion(.success(()))
  204. } catch let error as DownloadedModelError {
  205. completion(.failure(error))
  206. } catch {
  207. completion(.failure(.internalError(description: error.localizedDescription)))
  208. }
  209. }
  210. }
  211. extension ModelDownloader {
  212. /// Return local model info only if the model info is available and the corresponding model file is already on device.
  213. private func getLocalModelInfo(modelName: String) -> LocalModelInfo? {
  214. guard let localModelInfo = LocalModelInfo(
  215. fromDefaults: userDefaults,
  216. name: modelName,
  217. appName: appName
  218. ) else {
  219. return nil
  220. }
  221. /// There is local model info on device, but no model file at the expected path.
  222. guard let localPath = URL(string: localModelInfo.path),
  223. ModelFileManager.isFileReachable(at: localPath) else {
  224. // TODO: Consider deleting local model info in user defaults.
  225. return nil
  226. }
  227. return localModelInfo
  228. }
  229. /// Get model saved on device if available.
  230. private func getLocalModel(modelName: String) -> CustomModel? {
  231. guard let localModelInfo = getLocalModelInfo(modelName: modelName) else { return nil }
  232. let model = CustomModel(localModelInfo: localModelInfo)
  233. return model
  234. }
  235. /// Download and get model from server, unless the latest model is already available on device.
  236. private func getRemoteModel(modelName: String,
  237. progressHandler: ((Float) -> Void)? = nil,
  238. completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
  239. let localModelInfo = getLocalModelInfo(modelName: modelName)
  240. let modelInfoRetriever = ModelInfoRetriever(
  241. modelName: modelName,
  242. options: options,
  243. installations: installations,
  244. appName: appName,
  245. localModelInfo: localModelInfo
  246. )
  247. modelInfoRetriever.downloadModelInfo { result in
  248. switch result {
  249. case let .success(downloadModelInfoResult):
  250. switch downloadModelInfoResult {
  251. /// New model info was downloaded from server.
  252. case let .modelInfo(remoteModelInfo):
  253. let downloadTask = ModelDownloadTask(
  254. remoteModelInfo: remoteModelInfo,
  255. appName: self.appName,
  256. defaults: self.userDefaults,
  257. telemetryLogger: self.telemetryLogger,
  258. progressHandler: progressHandler,
  259. completion: completion
  260. )
  261. downloadTask.resumeModelDownload()
  262. /// Local model info is the latest model info.
  263. case .notModified:
  264. guard let localModel = self.getLocalModel(modelName: modelName) else {
  265. DispatchQueue.main.async {
  266. /// This can only happen if either local model info or the model file was suddenly wiped out in the middle of model info request and server response
  267. // TODO: Consider handling: if model file is deleted after local model info is retrieved but before model info network call
  268. completion(
  269. .failure(
  270. .internalError(description: "Model unavailable due to deleted local model info.")
  271. )
  272. )
  273. }
  274. return
  275. }
  276. DispatchQueue.main.async {
  277. completion(.success(localModel))
  278. }
  279. }
  280. case let .failure(downloadError):
  281. DispatchQueue.main.async {
  282. completion(.failure(downloadError))
  283. }
  284. }
  285. }
  286. }
  287. }
  288. /// Model downloader extension for testing.
  289. extension ModelDownloader {
  290. /// Model downloader instance for testing.
  291. // TODO: Consider using protocols
  292. static func modelDownloaderWithDefaults(_ defaults: UserDefaults,
  293. app: FirebaseApp) -> ModelDownloader {
  294. if let downloader = modelDownloaderDictionary[app.name] {
  295. return downloader
  296. } else {
  297. let downloader = ModelDownloader(app: app, defaults: defaults)
  298. modelDownloaderDictionary[app.name] = downloader
  299. return downloader
  300. }
  301. }
  302. }