ModelDownloadTask.swift 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  /// Task that downloads one model file to the device and reports progress and completion.
  class ModelDownloadTask {
    /// Callback receiving the fraction of the model file downloaded so far.
    typealias ProgressHandler = (Float) -> Void
    /// Callback receiving either the downloaded model or the error that ended the download.
    typealias Completion = (Result<CustomModel, DownloadError>) -> Void

    // MARK: - Configuration

    /// Name of the app associated with this instance of ModelDownloadTask.
    private let appName: String
    /// User defaults to which local model info should ultimately be written.
    private let defaults: UserDefaults
    /// Performs the actual file transfer.
    private let downloader: FileDownloader
    /// Optional sink for download telemetry events.
    private let telemetryLogger: TelemetryLogger?

    // MARK: - State

    /// Model info downloaded from server.
    private(set) var remoteModelInfo: RemoteModelInfo
    /// Where this task is in its lifecycle; a new task starts out `.ready`.
    private(set) var downloadStatus: ModelDownloadStatus = .ready
    /// Progress callback; declared `var` so duplicate requests can be merged into it.
    private var progressHandler: ProgressHandler?
    /// Completion callback; declared `var` so duplicate requests can be merged into it.
    private var completion: Completion

    /// Creates a task for `remoteModelInfo`; the download does not start until `resume()`.
    init(remoteModelInfo: RemoteModelInfo,
         appName: String,
         defaults: UserDefaults,
         downloader: FileDownloader,
         progressHandler: ProgressHandler? = nil,
         completion: @escaping Completion,
         telemetryLogger: TelemetryLogger? = nil) {
      self.appName = appName
      self.defaults = defaults
      self.downloader = downloader
      self.telemetryLogger = telemetryLogger
      self.remoteModelInfo = remoteModelInfo
      self.progressHandler = progressHandler
      self.completion = completion
    }
  }
  extension ModelDownloadTask {
    /// Returns whether a duplicate request for the same model can be merged into this task.
    ///
    /// Merging is allowed until the task reaches `.complete`, which `resume()` sets as soon as
    /// the downloader reports any result (success or failure).
    func canMergeRequests() -> Bool {
      return downloadStatus != .complete
    }

    /// Merges a duplicate request into this task so that both requesters are notified.
    ///
    /// The current handlers are wrapped: on each event the previously registered handler runs
    /// first, then the newly supplied one. This method is not thread-safe; it reassigns
    /// `progressHandler` and `completion` without any synchronization.
    /// - Parameters:
    ///   - newProgressHandler: Progress callback of the merged request, if any.
    ///   - newCompletion: Completion callback of the merged request.
    func merge(newProgressHandler: ProgressHandler? = nil, newCompletion: @escaping Completion) {
      let originalProgressHandler = progressHandler
      progressHandler = { progress in
        originalProgressHandler?(progress)
        newProgressHandler?(progress)
      }
      let originalCompletion = completion
      completion = { result in
        originalCompletion(result)
        newCompletion(result)
      }
    }

    /// Returns whether the download can be started, i.e. it has neither started nor finished.
    func canResume() -> Bool {
      return downloadStatus == .ready
    }

    /// Starts downloading the model file.
    ///
    /// If a download is already in flight, the call is logged and ignored. Otherwise the status
    /// becomes `.downloading` and progress is forwarded to `progressHandler` as a fraction.
    /// When the downloader reports a result, the status becomes `.complete`. A failure is mapped
    /// to a `DownloadError` and passed to `completion`. A server response is handed to
    /// `handleResponse`.
    func resume() {
      // Prevent multiple concurrent downloads.
      guard downloadStatus != .downloading else {
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.anotherDownloadInProgress,
                              messageCode: .anotherDownloadInProgressError)
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .failed,
                                               model: telemetryModel,
                                               downloadErrorCode: .downloadFailed)
        return
      }
      downloadStatus = .downloading
      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                             status: .downloading,
                                             model: telemetryModel,
                                             downloadErrorCode: .noError)
      downloader.downloadFile(with: remoteModelInfo.downloadURL,
                              progressHandler: { downloadedBytes, totalBytes in
                                // The expected total can be unknown (zero or a negative
                                // sentinel). No meaningful fraction exists then, so skip the
                                // update instead of reporting NaN, infinity or a negative value.
                                guard totalBytes > 0 else { return }
                                // Fraction of model file downloaded.
                                let calculatedProgress = Float(downloadedBytes) / Float(totalBytes)
                                self.progressHandler?(calculatedProgress)
                              }) { result in
        self.downloadStatus = .complete
        switch result {
        case let .success(response):
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.DebugDescription.receivedServerResponse,
                                messageCode: .validHTTPResponse)
          self.handleResponse(response: response.urlResponse,
                              tempURL: response.fileURL,
                              completion: self.completion)
        case let .failure(error):
          let downloadError: DownloadError
          switch error {
          case let FileDownloaderError.networkError(networkError):
            let description = ModelDownloadTask.ErrorDescription
              .invalidHostName(networkError.localizedDescription)
            downloadError = .failedPrecondition
            DeviceLogger.logEvent(level: .debug,
                                  message: description,
                                  messageCode: .hostnameError)
            self.telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                                        status: .failed,
                                                        model: self.telemetryModel,
                                                        downloadErrorCode: .noConnection)
          case FileDownloaderError.unexpectedResponseType:
            let description = ModelDownloadTask.ErrorDescription.invalidHTTPResponse
            downloadError = .internalError(description: description)
            DeviceLogger.logEvent(level: .debug,
                                  message: description,
                                  messageCode: .invalidHTTPResponse)
            self.telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                                        status: .failed,
                                                        model: self.telemetryModel,
                                                        downloadErrorCode: .downloadFailed)
          default:
            let description = ModelDownloadTask.ErrorDescription.unknownDownloadError
            downloadError = .internalError(description: description)
            DeviceLogger.logEvent(level: .debug,
                                  message: description,
                                  messageCode: .modelDownloadError)
            self.telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                                        status: .failed,
                                                        model: self.telemetryModel,
                                                        downloadErrorCode: .downloadFailed)
          }
          self.completion(.failure(downloadError))
        }
      }
    }

    /// Handles the server response of a finished download.
    ///
    /// A non-2xx status is mapped to a `DownloadError` and reported through `completion`. On
    /// success, the temporary file is moved into the app's model directory. It is then excluded
    /// from backup on a best-effort basis. Its model info is written to `defaults`, and the
    /// resulting `CustomModel` is passed to `completion`.
    /// - Parameters:
    ///   - response: HTTP response returned by the download server.
    ///   - tempURL: Location of the downloaded file before it is moved into place.
    ///   - completion: Called exactly once, with either the model or the error.
    func handleResponse(response: HTTPURLResponse, tempURL: URL, completion: @escaping Completion) {
      // Every 2xx status counts as success; the half-open range `200 ..< 300` covers 200...299.
      guard (200 ..< 300).contains(response.statusCode) else {
        switch response.statusCode {
        case 400:
          // Possible failure due to download URL expiry. Check if download URL has expired.
          guard remoteModelInfo.urlExpiryTime < Date() else {
            DeviceLogger.logEvent(level: .debug,
                                  message: ModelDownloadTask.ErrorDescription
                                    .invalidArgument(remoteModelInfo.name),
                                  messageCode: .invalidArgument)
            telemetryLogger?.logModelDownloadEvent(
              eventName: .modelDownload,
              status: .failed,
              model: telemetryModel,
              downloadErrorCode: .httpError(code: response.statusCode)
            )
            completion(.failure(.invalidArgument))
            return
          }
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.expiredModelInfo,
                                messageCode: .expiredModelInfo)
          telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                                 status: .failed,
                                                 model: telemetryModel,
                                                 downloadErrorCode: .urlExpired)
          completion(.failure(.expiredDownloadURL))
        case 401, 403:
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.permissionDenied,
                                messageCode: .permissionDenied)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            model: telemetryModel,
            downloadErrorCode: .httpError(code: response.statusCode)
          )
          completion(.failure(.permissionDenied))
        case 404:
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription
                                  .modelNotFound(remoteModelInfo.name),
                                messageCode: .modelNotFound)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            model: telemetryModel,
            downloadErrorCode: .httpError(code: response.statusCode)
          )
          completion(.failure(.notFound))
        default:
          let description = ModelDownloadTask.ErrorDescription
            .modelDownloadFailed(response.statusCode)
          DeviceLogger.logEvent(level: .debug,
                                message: description,
                                messageCode: .modelDownloadError)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            model: telemetryModel,
            downloadErrorCode: .httpError(code: response.statusCode)
          )
          completion(.failure(.internalError(description: description)))
        }
        return
      }
      // Construct local model file URL.
      guard var modelFileURL = ModelFileManager.getDownloadedModelFileURL(
        appName: appName,
        modelName: remoteModelInfo.name
      ) else {
        // Could not create Application Support directory to store model files.
        let description = ModelDownloadTask.ErrorDescription.noModelsDirectory
        DeviceLogger.logEvent(level: .debug,
                              message: description,
                              messageCode: .downloadedModelSaveError)
        // Downloading the file succeeded but saving failed.
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .succeeded,
                                               model: telemetryModel,
                                               downloadErrorCode: .downloadFailed)
        completion(.failure(.internalError(description: description)))
        return
      }
      do {
        // Save model file to device.
        try ModelFileManager.moveFile(
          at: tempURL,
          to: modelFileURL,
          size: Int64(remoteModelInfo.size)
        )
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.savedModelFile,
                              messageCode: .downloadedModelFileSaved)
        // Try disabling iCloud backup for the model file. The original rationale is that
        // UserDefaults, where the model info lives, is not backed up. This must run after the
        // move: resource values cannot be set on a path that does not exist yet, and this way
        // the flag lands on the file that was actually moved into place. Failure is logged but
        // not fatal.
        var resourceValues = URLResourceValues()
        resourceValues.isExcludedFromBackup = true
        do {
          try modelFileURL.setResourceValues(resourceValues)
        } catch {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.disableBackupError,
                                messageCode: .disableBackupError)
        }
        // Generate local model info.
        let localModelInfo = LocalModelInfo(from: remoteModelInfo)
        // Write model info to user defaults.
        localModelInfo.writeToDefaults(defaults, appName: appName)
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.savedLocalModelInfo,
                              messageCode: .downloadedModelInfoSaved)
        // Build model from model info and local path.
        let model = CustomModel(localModelInfo: localModelInfo, path: modelFileURL.path)
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.modelDownloaded,
                              messageCode: .modelDownloaded)
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .succeeded,
                                               model: model,
                                               downloadErrorCode: .noError)
        completion(.success(model))
      } catch let error as DownloadError {
        if error == .notEnoughSpace {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.notEnoughSpace,
                                messageCode: .notEnoughSpace)
        } else {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.modelSaveError,
                                messageCode: .downloadedModelSaveError)
        }
        // Downloading the file succeeded but saving failed.
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .succeeded,
                                               model: telemetryModel,
                                               downloadErrorCode: .downloadFailed)
        completion(.failure(error))
        return
      } catch {
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.modelSaveError,
                              messageCode: .downloadedModelSaveError)
        // Downloading the file succeeded but saving failed.
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .succeeded,
                                               model: telemetryModel,
                                               downloadErrorCode: .downloadFailed)
        completion(.failure(.internalError(description: error.localizedDescription)))
        return
      }
    }

    /// Model description reported to telemetry before a local file exists (empty path).
    ///
    /// Only evaluated when `telemetryLogger` is non-nil, because every use is an argument of an
    /// optional-chained call.
    private var telemetryModel: CustomModel {
      return CustomModel(name: remoteModelInfo.name,
                         size: remoteModelInfo.size,
                         path: "",
                         hash: remoteModelInfo.modelHash)
    }
  }
  /// Lifecycle of a model download task: not yet started, in flight, or finished.
  enum ModelDownloadStatus {
    case ready, downloading, complete
  }
  /// Error codes describing the outcome of a model download, e.g. for telemetry reporting.
  enum ModelDownloadErrorCode {
    case noError, urlExpired, noConnection, downloadFailed
    /// The server answered with the given (non-success) HTTP status code.
    case httpError(code: Int)
  }
  /// Text of the debug and error messages logged while downloading a model.
  extension ModelDownloadTask {
    /// Messages for steps of a successful download.
    private enum DebugDescription {
      static let receivedServerResponse = "Received a valid response from download server."
      static let savedModelFile = "Model file saved successfully to device."
      static let savedLocalModelInfo = "Downloaded model info saved successfully to user defaults."
      static let modelDownloaded = "Model download completed successfully."
    }

    /// Messages describing download and save failures.
    private enum ErrorDescription {
      // Fixed messages.
      static let sessionInvalidated = "Session invalidated due to failed pre-conditions."
      static let invalidHTTPResponse = "Could not get valid HTTP response for model downloading."
      static let anotherDownloadInProgress = "Download already in progress."
      static let expiredModelInfo = "Unable to update expired model info."
      static let permissionDenied = "Invalid or missing permissions to download model."
      static let notEnoughSpace = "Not enough space on device."
      static let disableBackupError = "Unable to disable model file backup."
      static let noModelsDirectory = "Could not create directory for model storage."
      static let modelSaveError = "Unable to save downloaded remote model file."
      static let unknownDownloadError = "Unable to download model due to unknown error."

      // Messages built from a host error description, a model name, or an HTTP status code.
      static let invalidHostName: (String) -> String = { error in
        "Unable to resolve hostname or connect to host: \(error)"
      }
      static let modelNotFound: (String) -> String = { name in
        "No model found with name: \(name)"
      }
      static let invalidArgument: (String) -> String = { name in
        "Invalid argument for model name: \(name)"
      }
      static let modelDownloadFailed: (Int) -> String = { code in
        "Model download failed with HTTP error code: \(code)"
      }
    }
  }