ModelDownloadTask.swift 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. #if SWIFT_PACKAGE
  16. internal import GoogleUtilities_UserDefaults
  17. #else
  18. internal import GoogleUtilities
  19. #endif // SWIFT_PACKAGE
  /// Task to download model file to device.
  ///
  /// Holds the remote model description, the downloader used to fetch the file, and the
  /// handlers that receive progress updates and the final result of the download.
  class ModelDownloadTask {
    /// Receives download progress as a `Float`.
    typealias ProgressHandler = (Float) -> Void
    /// Receives the downloaded model, or the reason the download failed.
    typealias Completion = (Result<CustomModel, DownloadError>) -> Void

    /// Name of the app associated with this task.
    private let appName: String
    /// Model info downloaded from the server.
    private(set) var remoteModelInfo: RemoteModelInfo
    /// User defaults to which the local model info is ultimately written.
    private let defaults: GULUserDefaults
    /// Current state of the download associated with this task; starts as `.ready`.
    private(set) var downloadStatus: ModelDownloadStatus = .ready
    /// Performs the actual file download.
    private let downloader: FileDownloader
    /// Optional telemetry logger.
    private let telemetryLogger: TelemetryLogger?
    /// Handler invoked with progress updates, if one was supplied.
    private var progressHandler: ProgressHandler?
    /// Handler invoked with the final result.
    private var completion: Completion

    /// Creates a task for `remoteModelInfo`; the download does not start until it is resumed.
    init(remoteModelInfo: RemoteModelInfo,
         appName: String,
         defaults: GULUserDefaults,
         downloader: FileDownloader,
         progressHandler: ProgressHandler? = nil,
         completion: @escaping Completion,
         telemetryLogger: TelemetryLogger? = nil) {
      self.appName = appName
      self.defaults = defaults
      self.downloader = downloader
      self.telemetryLogger = telemetryLogger
      self.remoteModelInfo = remoteModelInfo
      self.progressHandler = progressHandler
      self.completion = completion
    }
  }
  extension ModelDownloadTask {
    /// Check if requests can be merged.
    ///
    /// - Returns: `true` while the download has not completed, so a duplicate request can
    ///   attach its handlers to this task instead of starting a new download.
    func canMergeRequests() -> Bool {
      return downloadStatus != .complete
    }

    /// Merge duplicate requests. This method is not thread-safe.
    ///
    /// The current handlers are wrapped so that they run first, followed by the new ones.
    /// - Parameters:
    ///   - newProgressHandler: Additional progress handler; may be `nil`.
    ///   - newCompletion: Additional completion handler.
    func merge(newProgressHandler: ProgressHandler? = nil, newCompletion: @escaping Completion) {
      let originalProgressHandler = progressHandler
      progressHandler = { progress in
        originalProgressHandler?(progress)
        newProgressHandler?(progress)
      }
      let originalCompletion = completion
      completion = { result in
        originalCompletion(result)
        newCompletion(result)
      }
    }

    /// Check if download task can be resumed.
    ///
    /// - Returns: `true` only while the task is still in its initial `.ready` state.
    func canResume() -> Bool {
      return downloadStatus == .ready
    }

    /// Download model file.
    ///
    /// Does nothing except log if a download is already in progress. Otherwise it marks the task
    /// as downloading, forwards progress (downloaded / total bytes) to the progress handler, and
    /// delivers the outcome to the completion handler.
    func resume() {
      // Prevent multiple concurrent downloads.
      guard downloadStatus != .downloading else {
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.anotherDownloadInProgress,
                              messageCode: .anotherDownloadInProgressError)
        logDownloadFailure(errorCode: .downloadFailed)
        return
      }
      downloadStatus = .downloading
      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                             status: .downloading,
                                             model: modelWithoutLocalPath(),
                                             downloadErrorCode: .noError)
      downloader.downloadFile(with: remoteModelInfo.downloadURL,
                              progressHandler: { downloadedBytes, totalBytes in
                                // Without a positive total size the fraction would be NaN,
                                // infinite or negative, so such updates are skipped.
                                guard totalBytes > 0 else { return }
                                // Fraction of model file downloaded.
                                let calculatedProgress = Float(downloadedBytes) / Float(totalBytes)
                                self.progressHandler?(calculatedProgress)
                              }) { result in
        self.downloadStatus = .complete
        switch result {
        case let .success(response):
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.DebugDescription.receivedServerResponse,
                                messageCode: .validHTTPResponse)
          self.handleResponse(
            response: response.urlResponse,
            tempURL: response.fileURL,
            completion: self.completion
          )
        case let .failure(error):
          self.completion(.failure(self.downloadError(for: error)))
        }
      }
    }

    /// Handle model download response.
    ///
    /// For a 2xx status, the downloaded file at `tempURL` is moved to the local model location.
    /// The model info is then written to user defaults and `completion` receives the model.
    /// For any other status, or if saving fails, `completion` receives a `DownloadError`.
    /// - Parameters:
    ///   - response: HTTP response returned by the download server.
    ///   - tempURL: Location of the downloaded file.
    ///   - completion: Receives the model or the failure.
    func handleResponse(response: HTTPURLResponse, tempURL: URL, completion: @escaping Completion) {
      // Every 2xx status is a success; the half-open range must end at 300 to include 299.
      guard (200 ..< 300).contains(response.statusCode) else {
        completion(.failure(downloadError(forStatusCode: response.statusCode)))
        return
      }
      // Construct local model file URL.
      guard var modelFileURL = ModelFileManager.getDownloadedModelFileURL(
        appName: appName,
        modelName: remoteModelInfo.name
      ) else {
        // Could not create Application Support directory to store model files.
        let description = ModelDownloadTask.ErrorDescription.noModelsDirectory
        DeviceLogger.logEvent(level: .debug,
                              message: description,
                              messageCode: .downloadedModelSaveError)
        logSaveFailure()
        completion(.failure(.internalError(description: description)))
        return
      }
      do {
        // Save model file to device.
        try ModelFileManager.moveFile(
          at: tempURL,
          to: modelFileURL,
          size: Int64(remoteModelInfo.size)
        )
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.savedModelFile,
                              messageCode: .downloadedModelFileSaved)
        // Try disabling iCloud backup for model files, because UserDefaults is not backed up.
        // This runs after the move so the attribute lands on the file that now exists at
        // `modelFileURL`; resource values cannot be set on a file that does not exist yet.
        var resourceValues = URLResourceValues()
        resourceValues.isExcludedFromBackup = true
        do {
          try modelFileURL.setResourceValues(resourceValues)
        } catch {
          // Best effort: failing to exclude the file from backup does not fail the download.
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.disableBackupError,
                                messageCode: .disableBackupError)
        }
        // Generate local model info.
        let localModelInfo = LocalModelInfo(from: remoteModelInfo)
        // Write model info to user defaults.
        localModelInfo.writeToDefaults(defaults, appName: appName)
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.savedLocalModelInfo,
                              messageCode: .downloadedModelInfoSaved)
        // Build model from model info and local path.
        let model = CustomModel(localModelInfo: localModelInfo, path: modelFileURL.path)
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.modelDownloaded,
                              messageCode: .modelDownloaded)
        telemetryLogger?.logModelDownloadEvent(
          eventName: .modelDownload,
          status: .succeeded,
          model: model,
          downloadErrorCode: .noError
        )
        completion(.success(model))
      } catch let error as DownloadError {
        if error == .notEnoughSpace {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.notEnoughSpace,
                                messageCode: .notEnoughSpace)
        } else {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.modelSaveError,
                                messageCode: .downloadedModelSaveError)
        }
        logSaveFailure()
        completion(.failure(error))
      } catch {
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.modelSaveError,
                              messageCode: .downloadedModelSaveError)
        logSaveFailure()
        completion(.failure(.internalError(description: error.localizedDescription)))
      }
    }

    /// Maps a transport-level download failure to a `DownloadError`, logging it along the way.
    private func downloadError(for error: Error) -> DownloadError {
      switch error {
      case let FileDownloaderError.networkError(networkError):
        let description = ModelDownloadTask.ErrorDescription
          .invalidHostName(networkError.localizedDescription)
        DeviceLogger.logEvent(level: .debug,
                              message: description,
                              messageCode: .hostnameError)
        logDownloadFailure(errorCode: .noConnection)
        return .failedPrecondition
      case FileDownloaderError.unexpectedResponseType:
        let description = ModelDownloadTask.ErrorDescription.invalidHTTPResponse
        DeviceLogger.logEvent(level: .debug,
                              message: description,
                              messageCode: .invalidHTTPResponse)
        logDownloadFailure(errorCode: .downloadFailed)
        return .internalError(description: description)
      default:
        let description = ModelDownloadTask.ErrorDescription.unknownDownloadError
        DeviceLogger.logEvent(level: .debug,
                              message: description,
                              messageCode: .modelDownloadError)
        logDownloadFailure(errorCode: .downloadFailed)
        return .internalError(description: description)
      }
    }

    /// Maps a non-2xx HTTP status code to a `DownloadError`, logging it along the way.
    private func downloadError(forStatusCode statusCode: Int) -> DownloadError {
      switch statusCode {
      case 400:
        // Possible failure due to download URL expiry. Check if download URL has expired.
        guard remoteModelInfo.urlExpiryTime < Date() else {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription
                                  .invalidArgument(remoteModelInfo.name),
                                messageCode: .invalidArgument)
          logDownloadFailure(errorCode: .httpError(code: statusCode))
          return .invalidArgument
        }
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.expiredModelInfo,
                              messageCode: .expiredModelInfo)
        logDownloadFailure(errorCode: .urlExpired)
        return .expiredDownloadURL
      case 401, 403:
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.permissionDenied,
                              messageCode: .permissionDenied)
        logDownloadFailure(errorCode: .httpError(code: statusCode))
        return .permissionDenied
      case 404:
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription
                                .modelNotFound(remoteModelInfo.name),
                              messageCode: .modelNotFound)
        logDownloadFailure(errorCode: .httpError(code: statusCode))
        return .notFound
      default:
        let description = ModelDownloadTask.ErrorDescription.modelDownloadFailed(statusCode)
        DeviceLogger.logEvent(level: .debug,
                              message: description,
                              messageCode: .modelDownloadError)
        logDownloadFailure(errorCode: .httpError(code: statusCode))
        return .internalError(description: description)
      }
    }

    /// Describes the remote model for telemetry events that have no on-device file path.
    private func modelWithoutLocalPath() -> CustomModel {
      return CustomModel(name: remoteModelInfo.name,
                         size: remoteModelInfo.size,
                         path: "",
                         hash: remoteModelInfo.modelHash)
    }

    /// Logs a `.failed` model download telemetry event with the given error code.
    private func logDownloadFailure(errorCode: ModelDownloadErrorCode) {
      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                             status: .failed,
                                             model: modelWithoutLocalPath(),
                                             downloadErrorCode: errorCode)
    }

    /// Logs telemetry for a file that was downloaded but could not be saved.
    private func logSaveFailure() {
      // Downloading the file succeeded but saving failed.
      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                             status: .succeeded,
                                             model: modelWithoutLocalPath(),
                                             downloadErrorCode: .downloadFailed)
    }
  }
  /// Lifecycle of a model download task: not started, in flight, or finished.
  /// A finished download is `.complete` whether it succeeded or failed.
  enum ModelDownloadStatus {
    case ready, downloading, complete
  }
  /// Error codes attached to model download telemetry events.
  enum ModelDownloadErrorCode {
    case noError, urlExpired, noConnection, downloadFailed
    /// Failure carrying the HTTP status code returned by the server.
    case httpError(code: Int)
  }
  /// Possible debug and error messages for model downloading.
  extension ModelDownloadTask {
    /// Messages logged along the successful download path.
    private enum DebugDescription {
      static let receivedServerResponse = "Received a valid response from download server."
      static let savedModelFile = "Model file saved successfully to device."
      static let savedLocalModelInfo = "Downloaded model info saved successfully to user defaults."
      static let modelDownloaded = "Model download completed successfully."
    }

    /// Messages logged when a download or save step fails.
    private enum ErrorDescription {
      static let sessionInvalidated = "Session invalidated due to failed pre-conditions."
      static let invalidHTTPResponse = "Could not get valid HTTP response for model downloading."
      static let anotherDownloadInProgress = "Download already in progress."
      static let expiredModelInfo = "Unable to update expired model info."
      static let permissionDenied = "Invalid or missing permissions to download model."
      static let notEnoughSpace = "Not enough space on device."
      static let disableBackupError = "Unable to disable model file backup."
      static let noModelsDirectory = "Could not create directory for model storage."
      static let modelSaveError = "Unable to save downloaded remote model file."
      static let unknownDownloadError = "Unable to download model due to unknown error."

      /// Message for a connection/hostname failure, embedding the underlying error text.
      static func invalidHostName(_ error: String) -> String {
        return "Unable to resolve hostname or connect to host: \(error)"
      }

      /// Message for a model name the server does not know.
      static func modelNotFound(_ name: String) -> String {
        return "No model found with name: \(name)"
      }

      /// Message for a request the server rejected as an invalid argument.
      static func invalidArgument(_ name: String) -> String {
        return "Invalid argument for model name: \(name)"
      }

      /// Message for any other HTTP failure, embedding the status code.
      static func modelDownloadFailed(_ code: Int) -> String {
        return "Model download failed with HTTP error code: \(code)"
      }
    }
  }