ModelDownloadTask.swift 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  /// Stage of a model download task: not started yet, in flight, or finished.
  enum ModelDownloadStatus {
    case ready, downloading, complete
  }
  /// Error codes attached to model download telemetry events.
  enum ModelDownloadErrorCode {
    case noError, urlExpired, noConnection, downloadFailed
    /// The download server answered with a non-success HTTP status code.
    case httpError(code: Int)
  }
  /// Downloads a model file to the device and records information about the downloaded model in
  /// persistent storage (user defaults).
  class ModelDownloadTask: NSObject {
    /// Receives download progress as a `Float`.
    typealias ProgressHandler = (Float) -> Void
    /// Receives the outcome of the download: the downloaded model or the error that occurred.
    typealias Completion = (Result<CustomModel, DownloadError>) -> Void

    // MARK: - Configuration

    /// Name of the app associated with this download task.
    private let appName: String
    /// User defaults that the downloaded model's local info is ultimately written to.
    private let defaults: UserDefaults
    /// Performs the actual file download.
    private let downloader: FileDownloader
    /// Optional logger for download telemetry events.
    private let telemetryLogger: TelemetryLogger?

    // MARK: - State

    /// Model info received from the server.
    private(set) var remoteModelInfo: RemoteModelInfo
    /// Current stage of this task's download; starts out as `.ready`.
    private(set) var downloadStatus: ModelDownloadStatus = .ready
    /// Handler notified of download progress, if any.
    private var progressHandler: ProgressHandler?
    /// Handler notified with the download outcome.
    private var completion: Completion

    init(remoteModelInfo: RemoteModelInfo,
         appName: String,
         defaults: UserDefaults,
         downloader: FileDownloader,
         progressHandler: ProgressHandler? = nil,
         completion: @escaping Completion,
         telemetryLogger: TelemetryLogger? = nil) {
      self.appName = appName
      self.defaults = defaults
      self.downloader = downloader
      self.telemetryLogger = telemetryLogger
      self.remoteModelInfo = remoteModelInfo
      self.progressHandler = progressHandler
      self.completion = completion
    }
  }
  extension ModelDownloadTask {
    /// Name for the model file stored on device, built from the app name and the remote model name.
    var downloadedModelFileName: String {
      return "fbml_model__\(appName)__\(remoteModelInfo.name).tflite"
    }

    /// Returns `true` until the download has completed, i.e. while duplicate requests can still be
    /// merged into this task.
    func canMergeRequests() -> Bool {
      return downloadStatus != .complete
    }

    /// Returns `true` only if the download has not been started yet.
    func canResume() -> Bool {
      return downloadStatus == .ready
    }

    /// Merge a duplicate request into this task. The new handlers are chained after the existing
    /// ones, so every merged caller receives the same progress updates and result.
    /// This method is not thread-safe.
    /// - Parameters:
    ///   - newProgressHandler: Optional progress handler invoked after the existing one.
    ///   - newCompletion: Completion invoked after the existing one.
    func merge(newProgressHandler: ProgressHandler? = nil, newCompletion: @escaping Completion) {
      let originalProgressHandler = progressHandler
      progressHandler = { progress in
        originalProgressHandler?(progress)
        newProgressHandler?(progress)
      }
      let originalCompletion = completion
      completion = { result in
        originalCompletion(result)
        newCompletion(result)
      }
    }

    /// Start downloading the model file from `remoteModelInfo.downloadURL`.
    ///
    /// If a download is already in progress, this logs the attempt and returns without starting a
    /// new one. When the downloader finishes, the task is marked `.complete` and the outcome is
    /// delivered to `completion`. Successful server responses are passed to
    /// `handleResponse(response:tempURL:completion:)`, and downloader failures are mapped to a
    /// `DownloadError`.
    func resume() {
      // Prevent multiple concurrent downloads.
      guard downloadStatus != .downloading else {
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.anotherDownloadInProgress,
                              messageCode: .anotherDownloadInProgressError)
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .failed,
                                               downloadErrorCode: .downloadFailed)
        return
      }
      downloadStatus = .downloading
      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                             status: .downloading,
                                             downloadErrorCode: .noError)
      // NOTE: both closures capture `self` strongly, keeping this task alive until the download
      // finishes.
      downloader.downloadFile(with: remoteModelInfo.downloadURL,
                              progressHandler: { downloadedBytes, totalBytes in
                                // The total size may be unknown (URLSession uses -1,
                                // `NSURLSessionTransferSizeUnknown`) or zero. Dividing by it
                                // would report NaN, infinite or negative progress, so skip
                                // the update instead.
                                guard totalBytes > 0 else { return }
                                // Fraction of model file downloaded.
                                let calculatedProgress = Float(downloadedBytes) / Float(totalBytes)
                                self.progressHandler?(calculatedProgress)
                              }) { result in
        self.downloadStatus = .complete
        switch result {
        case let .success(response):
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.DebugDescription
                                  .receivedServerResponse,
                                messageCode: .validHTTPResponse)
          self.handleResponse(
            response: response.urlResponse,
            tempURL: response.fileURL,
            completion: self.completion
          )
        case let .failure(error):
          let downloadError: DownloadError
          switch error {
          case let FileDownloaderError.networkError(error):
            let description = ModelDownloadTask.ErrorDescription
              .invalidHostName(error.localizedDescription)
            downloadError = .internalError(description: description)
            DeviceLogger.logEvent(level: .debug,
                                  message: description,
                                  messageCode: .hostnameError)
            self.telemetryLogger?.logModelDownloadEvent(
              eventName: .modelDownload,
              status: .failed,
              downloadErrorCode: .noConnection
            )
          case FileDownloaderError.unexpectedResponseType:
            let description = ModelDownloadTask.ErrorDescription.invalidHTTPResponse
            downloadError = .internalError(description: description)
            DeviceLogger.logEvent(level: .debug,
                                  message: description,
                                  messageCode: .invalidHTTPResponse)
            self.telemetryLogger?.logModelDownloadEvent(
              eventName: .modelDownload,
              status: .failed,
              downloadErrorCode: .downloadFailed
            )
          default:
            let description = ModelDownloadTask.ErrorDescription.unknownDownloadError
            downloadError = .internalError(description: description)
            DeviceLogger.logEvent(level: .debug,
                                  message: description,
                                  messageCode: .modelDownloadError)
            self.telemetryLogger?.logModelDownloadEvent(
              eventName: .modelDownload,
              status: .failed,
              downloadErrorCode: .downloadFailed
            )
          }
          self.completion(.failure(downloadError))
        }
      }
    }

    /// Handle the HTTP response of a finished model download and deliver the result.
    ///
    /// For a 2xx status, the downloaded file is moved from `tempURL` to the model file location, the
    /// resulting local model info is written to `defaults`, and `completion` receives the built
    /// `CustomModel`. Any other status is mapped to a `DownloadError`:
    /// - 400: `.expiredDownloadURL` if the download URL has expired, otherwise `.invalidArgument`.
    /// - 401 / 403: `.permissionDenied`.
    /// - 404: `.notFound`.
    /// - anything else: `.internalError`.
    /// - Parameters:
    ///   - response: HTTP response returned by the download server.
    ///   - tempURL: Location of the downloaded temporary file.
    ///   - completion: Receives the downloaded model or the error.
    func handleResponse(response: HTTPURLResponse, tempURL: URL, completion: @escaping Completion) {
      // Accept the whole 2xx range. The upper bound of `..<` is exclusive, so it must be 300 for
      // 299 to count as success.
      guard (200 ..< 300).contains(response.statusCode) else {
        switch response.statusCode {
        // Possible failure due to download URL expiry.
        case 400:
          let currentDateTime = Date()
          // If the download URL has not expired yet, the 400 is attributed to an invalid model name.
          guard currentDateTime > remoteModelInfo.urlExpiryTime else {
            DeviceLogger.logEvent(level: .debug,
                                  message: ModelDownloadTask.ErrorDescription
                                    .invalidModelName(remoteModelInfo.name),
                                  messageCode: .invalidModelName)
            telemetryLogger?.logModelDownloadEvent(
              eventName: .modelDownload,
              status: .failed,
              downloadErrorCode: .httpError(code: response.statusCode)
            )
            completion(.failure(.invalidArgument))
            return
          }
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.expiredModelInfo,
                                messageCode: .expiredModelInfo)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            downloadErrorCode: .urlExpired
          )
          completion(.failure(.expiredDownloadURL))
        case 401, 403:
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.permissionDenied,
                                messageCode: .permissionDenied)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            downloadErrorCode: .httpError(code: response.statusCode)
          )
          completion(.failure(.permissionDenied))
        case 404:
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription
                                  .modelNotFound(remoteModelInfo.name),
                                messageCode: .modelNotFound)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            downloadErrorCode: .httpError(code: response.statusCode)
          )
          completion(.failure(.notFound))
        default:
          let description = ModelDownloadTask.ErrorDescription
            .modelDownloadFailed(response.statusCode)
          DeviceLogger.logEvent(level: .debug,
                                message: description,
                                messageCode: .modelDownloadError)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            downloadErrorCode: .httpError(code: response.statusCode)
          )
          completion(.failure(.internalError(description: description)))
        }
        return
      }
      let modelFileURL = ModelFileManager.getDownloadedModelFilePath(
        appName: appName,
        modelName: remoteModelInfo.name
      )
      do {
        try ModelFileManager.moveFile(
          at: tempURL,
          to: modelFileURL,
          size: Int64(remoteModelInfo.size)
        )
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.savedModelFile,
                              messageCode: .downloadedModelFileSaved)
        // Generate local model info.
        let localModelInfo = LocalModelInfo(from: remoteModelInfo, path: modelFileURL.absoluteString)
        // Write model to user defaults.
        localModelInfo.writeToDefaults(defaults, appName: appName)
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.savedLocalModelInfo,
                              messageCode: .downloadedModelInfoSaved)
        // Build model from model info.
        let model = CustomModel(localModelInfo: localModelInfo)
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.modelDownloaded,
                              messageCode: .modelDownloaded)
        telemetryLogger?.logModelDownloadEvent(
          eventName: .modelDownload,
          status: .succeeded,
          model: model,
          downloadErrorCode: .noError
        )
        completion(.success(model))
      } catch let error as DownloadError {
        if error == .notEnoughSpace {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.notEnoughSpace,
                                messageCode: .notEnoughSpace)
        } else {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.saveModel,
                                messageCode: .downloadedModelSaveError)
        }
        // The caller receives a failure, so telemetry reports a failure too, consistent with the
        // error code.
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .failed,
                                               downloadErrorCode: .downloadFailed)
        completion(.failure(error))
      } catch {
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.saveModel,
                              messageCode: .downloadedModelSaveError)
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .failed,
                                               downloadErrorCode: .downloadFailed)
        completion(.failure(.internalError(description: error.localizedDescription)))
      }
    }
  }
  /// Log messages emitted by `ModelDownloadTask`.
  extension ModelDownloadTask {
    /// Messages for successful steps of a model download.
    private enum DebugDescription {
      static let receivedServerResponse = "Received a valid response from download server."
      static let savedModelFile = "Model file saved successfully to device."
      static let savedLocalModelInfo = "Downloaded model info saved successfully to user defaults."
      static let modelDownloaded = "Model download completed successfully."
    }

    /// Messages describing why a model download failed.
    private enum ErrorDescription {
      // MARK: Messages built from a runtime value

      static let invalidHostName: (String) -> String = { reason in
        "Unable to resolve hostname or connect to host: \(reason)"
      }
      static let modelDownloadFailed: (Int) -> String = { statusCode in
        "Model download failed with HTTP error code: \(statusCode)"
      }
      static let modelNotFound: (String) -> String = { modelName in
        "No model found with name: \(modelName)"
      }
      static let invalidModelName: (String) -> String = { modelName in
        "Invalid model name: \(modelName)"
      }

      // MARK: Fixed messages

      static let anotherDownloadInProgress = "Download already in progress."
      static let sessionInvalidated = "Session invalidated due to failed pre-conditions."
      static let invalidHTTPResponse = "Could not get valid HTTP response for model downloading."
      static let unknownDownloadError = "Unable to download model due to unknown error."
      static let permissionDenied = "Invalid or missing permissions to download model."
      static let expiredModelInfo = "Unable to update expired model info."
      static let saveModel = "Unable to save downloaded remote model file."
      static let notEnoughSpace = "Not enough space on device."
    }
  }