ModelDownloadTask.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  /// Lifecycle of a model download.
  ///
  /// `.notStarted` before any download is attempted, `.inProgress` while a transfer is
  /// running, and `.complete` once the transfer has finished (successfully or not).
  enum ModelDownloadStatus {
    case notStarted, inProgress, complete
  }
  /// Error codes reported as the `downloadErrorCode` of model download telemetry events.
  enum ModelDownloadErrorCode {
    case noError, urlExpired, noConnection, downloadFailed
    /// Failure signalled by the server, carrying the HTTP status code it returned.
    case httpError(code: Int)
  }
  /// Downloads a model file to the device and records the downloaded model's info in
  /// persistent storage (`UserDefaults`).
  class ModelDownloadTask: NSObject {
    typealias ProgressHandler = (Float) -> Void
    typealias Completion = (Result<CustomModel, DownloadError>) -> Void

    /// Model info received from the server; describes the model this task downloads.
    private(set) var remoteModelInfo: RemoteModelInfo
    /// State of the download owned by this task; starts out as `.notStarted`.
    private(set) var downloadStatus: ModelDownloadStatus = .notStarted

    /// Name of the app this task belongs to.
    private let appName: String
    /// Store that the local model info is ultimately written to.
    private let defaults: UserDefaults
    /// Performs the actual file transfer.
    private let downloader: FileDownloader
    /// Optional sink for download telemetry events.
    private let telemetryLogger: TelemetryLogger?

    init(remoteModelInfo: RemoteModelInfo,
         appName: String,
         defaults: UserDefaults,
         downloader: FileDownloader,
         telemetryLogger: TelemetryLogger? = nil) {
      self.appName = appName
      self.defaults = defaults
      self.remoteModelInfo = remoteModelInfo
      self.downloader = downloader
      self.telemetryLogger = telemetryLogger
      super.init()
    }
  }
  extension ModelDownloadTask {
    /// Name for the model file stored on device, built from the app name and the remote
    /// model name: `fbml_model__<appName>__<modelName>.tflite`.
    var downloadedModelFileName: String {
      return "fbml_model__\(appName)__\(remoteModelInfo.name).tflite"
    }

    /// Downloads the model described by `remoteModelInfo` and saves it on device.
    ///
    /// Only one download may run per task at a time. A call made while `downloadStatus` is
    /// `.inProgress` fails immediately with `.internalError` and starts no transfer. When the
    /// transfer finishes, successfully or not, `downloadStatus` becomes `.complete`, so a later
    /// call may start a new download.
    ///
    /// - Parameters:
    ///   - progressHandler: Optional callback receiving the downloaded fraction of the file.
    ///     It is only invoked when the downloader reports a positive total size.
    ///   - completion: Receives the saved `CustomModel`, or the `DownloadError` describing
    ///     why the download or the save failed.
    func download(progressHandler: ProgressHandler?, completion: @escaping Completion) {
      // Prevent multiple concurrent downloads.
      guard downloadStatus != .inProgress else {
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.anotherDownloadInProgress,
                              messageCode: .anotherDownloadInProgressError)
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .failed,
                                               downloadErrorCode: .downloadFailed)
        completion(.failure(.internalError(description: ModelDownloadTask.ErrorDescription
            .anotherDownloadInProgress)))
        return
      }
      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                             status: .downloading,
                                             downloadErrorCode: .noError)
      downloadStatus = .inProgress
      downloader.downloadFile(with: remoteModelInfo.downloadURL,
                              progressHandler: { downloadedBytes, totalBytes in
                                // The total size can be unknown. URLSession, which the downloader
                                // presumably wraps, reports -1 in that case. Dividing by a
                                // non-positive total would produce NaN, inf or negative progress,
                                // so skip the update.
                                guard totalBytes > 0 else { return }
                                // Fraction of model file downloaded.
                                let calculatedProgress = Float(downloadedBytes) / Float(totalBytes)
                                progressHandler?(calculatedProgress)
                              }) { result in
        // The transfer is over either way; allow a future call to start a new download.
        self.downloadStatus = .complete
        switch result {
        case let .success(response):
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.DebugDescription.receivedServerResponse,
                                messageCode: .validHTTPResponse)
          self.handleResponse(
            response: response.urlResponse,
            tempURL: response.fileURL,
            completion: completion
          )
        case let .failure(error):
          // Definitely initialized by every branch of the switch below.
          let downloadError: DownloadError
          switch error {
          case let FileDownloaderError.networkError(networkError):
            let description = ModelDownloadTask.ErrorDescription
              .invalidHostName(networkError.localizedDescription)
            downloadError = .internalError(description: description)
            DeviceLogger.logEvent(level: .debug,
                                  message: description,
                                  messageCode: .hostnameError)
            self.telemetryLogger?.logModelDownloadEvent(
              eventName: .modelDownload,
              status: .failed,
              downloadErrorCode: .noConnection
            )
          case FileDownloaderError.unexpectedResponseType:
            let description = ModelDownloadTask.ErrorDescription.invalidHTTPResponse
            downloadError = .internalError(description: description)
            DeviceLogger.logEvent(level: .debug,
                                  message: description,
                                  messageCode: .invalidHTTPResponse)
            self.telemetryLogger?.logModelDownloadEvent(
              eventName: .modelDownload,
              status: .failed,
              downloadErrorCode: .downloadFailed
            )
          default:
            let description = ModelDownloadTask.ErrorDescription.unknownDownloadError
            downloadError = .internalError(description: description)
            DeviceLogger.logEvent(level: .debug,
                                  message: description,
                                  messageCode: .modelDownloadError)
            self.telemetryLogger?.logModelDownloadEvent(
              eventName: .modelDownload,
              status: .failed,
              downloadErrorCode: .downloadFailed
            )
          }
          completion(.failure(downloadError))
        }
      }
    }

    /// Handles the server response for a finished download.
    ///
    /// A 2xx status moves the downloaded file into the model's storage path, writes the
    /// resulting local model info to `defaults`, and completes with the new `CustomModel`.
    /// Any other status completes with an error:
    /// - 400: `.expiredDownloadURL` if `remoteModelInfo.urlExpiryTime` has passed, otherwise
    ///   `.invalidArgument` (logged as an invalid model name).
    /// - 401 / 403: `.permissionDenied`.
    /// - 404: `.notFound`.
    /// - any other status: `.internalError` describing the status code.
    ///
    /// - Parameters:
    ///   - response: HTTP response returned by the download server.
    ///   - tempURL: Local file the downloader wrote the payload to. It is handed to
    ///     `ModelFileManager.moveFile` to relocate it to the model's storage path.
    ///   - completion: Receives the saved model, or the error the response or save mapped to.
    func handleResponse(response: HTTPURLResponse, tempURL: URL, completion: @escaping Completion) {
      // Accept the whole 2xx range (200...299); an upper bound of 299 exclusive would wrongly
      // reject status 299.
      guard (200 ..< 300).contains(response.statusCode) else {
        switch response.statusCode {
        // Possible failure due to download URL expiry.
        case 400:
          // Before the URL's expiry time, a 400 means the request itself was invalid.
          guard Date() > remoteModelInfo.urlExpiryTime else {
            DeviceLogger.logEvent(level: .debug,
                                  message: ModelDownloadTask.ErrorDescription
                                    .invalidModelName(remoteModelInfo.name),
                                  messageCode: .invalidModelName)
            telemetryLogger?.logModelDownloadEvent(
              eventName: .modelDownload,
              status: .failed,
              downloadErrorCode: .httpError(code: response.statusCode)
            )
            completion(.failure(.invalidArgument))
            return
          }
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.expiredModelInfo,
                                messageCode: .expiredModelInfo)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            downloadErrorCode: .urlExpired
          )
          completion(.failure(.expiredDownloadURL))
        case 401, 403:
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.permissionDenied,
                                messageCode: .permissionDenied)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            downloadErrorCode: .httpError(code: response.statusCode)
          )
          completion(.failure(.permissionDenied))
        case 404:
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription
                                  .modelNotFound(remoteModelInfo.name),
                                messageCode: .modelNotFound)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            downloadErrorCode: .httpError(code: response.statusCode)
          )
          completion(.failure(.notFound))
        default:
          let description = ModelDownloadTask.ErrorDescription
            .modelDownloadFailed(response.statusCode)
          DeviceLogger.logEvent(level: .debug,
                                message: description,
                                messageCode: .modelDownloadError)
          telemetryLogger?.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .failed,
            downloadErrorCode: .httpError(code: response.statusCode)
          )
          completion(.failure(.internalError(description: description)))
        }
        return
      }
      let modelFileURL = ModelFileManager.getDownloadedModelFilePath(
        appName: appName,
        modelName: remoteModelInfo.name
      )
      do {
        try ModelFileManager.moveFile(
          at: tempURL,
          to: modelFileURL,
          size: Int64(remoteModelInfo.size)
        )
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.savedModelFile,
                              messageCode: .downloadedModelFileSaved)
        // Generate local model info and persist it.
        let localModelInfo = LocalModelInfo(from: remoteModelInfo, path: modelFileURL.absoluteString)
        localModelInfo.writeToDefaults(defaults, appName: appName)
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.savedLocalModelInfo,
                              messageCode: .downloadedModelInfoSaved)
        // Build model from model info.
        let model = CustomModel(localModelInfo: localModelInfo)
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.DebugDescription.modelDownloaded,
                              messageCode: .modelDownloaded)
        telemetryLogger?.logModelDownloadEvent(
          eventName: .modelDownload,
          status: .succeeded,
          model: model,
          downloadErrorCode: .noError
        )
        completion(.success(model))
      } catch let error as DownloadError {
        if error == .notEnoughSpace {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.notEnoughSpace,
                                messageCode: .notEnoughSpace)
        } else {
          DeviceLogger.logEvent(level: .debug,
                                message: ModelDownloadTask.ErrorDescription.saveModel,
                                messageCode: .downloadedModelSaveError)
        }
        // The overall download failed, so report `.failed` like every other failure path.
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .failed,
                                               downloadErrorCode: .downloadFailed)
        completion(.failure(error))
      } catch {
        DeviceLogger.logEvent(level: .debug,
                              message: ModelDownloadTask.ErrorDescription.saveModel,
                              messageCode: .downloadedModelSaveError)
        telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                               status: .failed,
                                               downloadErrorCode: .downloadFailed)
        completion(.failure(.internalError(description: error.localizedDescription)))
      }
    }
  }
  /// Log and error message text used while downloading a model.
  extension ModelDownloadTask {
    /// Messages logged at debug level when a download step succeeds.
    private enum DebugDescription {
      static let receivedServerResponse = "Received a valid response from download server."
      static let savedModelFile = "Model file saved successfully to device."
      static let savedLocalModelInfo = "Downloaded model info saved successfully to user defaults."
      static let modelDownloaded = "Model download completed successfully."
    }

    /// Messages describing download failures. Messages that embed a value are functions.
    private enum ErrorDescription {
      static func invalidHostName(_ error: String) -> String {
        return "Unable to resolve hostname or connect to host: \(error)"
      }

      static func modelDownloadFailed(_ code: Int) -> String {
        return "Model download failed with HTTP error code: \(code)"
      }

      static func modelNotFound(_ name: String) -> String {
        return "No model found with name: \(name)"
      }

      static func invalidModelName(_ name: String) -> String {
        return "Invalid model name: \(name)"
      }

      static let anotherDownloadInProgress = "Download already in progress."
      static let permissionDenied = "Invalid or missing permissions to download model."
      static let expiredModelInfo = "Unable to update expired model info."
      static let invalidHTTPResponse = "Could not get valid HTTP response for model downloading."
      static let unknownDownloadError = "Unable to download model due to unknown error."
      static let saveModel = "Unable to save downloaded remote model file."
      static let notEnoughSpace = "Not enough space on device."
      // NOTE(review): not referenced by the download code in this file.
      static let sessionInvalidated = "Session invalidated due to failed pre-conditions."
    }
  }