ModelDownloadTask.swift 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  /// Stage a single model download has reached.
  ///
  /// Starts as `notStarted`, becomes `inProgress` when the download task is resumed, and is set to
  /// `successful` or `failed` when the download finishes.
  enum ModelDownloadStatus {
    case notStarted, inProgress, successful, failed
  }
  /// Bundles the callbacks supplied for one model download: an optional progress observer and a
  /// required completion handler.
  class DownloadHandlers {
    /// Receives download progress as a fraction of the expected bytes written so far.
    typealias ProgressHandler = (_ progress: Float) -> Void
    /// Receives the outcome of the download: the downloaded model or the error that ended it.
    typealias Completion = (_ result: Result<CustomModel, DownloadError>) -> Void

    /// Called as data is written; `nil` when the caller does not observe progress.
    var progressHandler: ProgressHandler?
    /// Called with the result of the download.
    var completion: Completion

    init(progressHandler: ProgressHandler?, completion: @escaping Completion) {
      self.completion = completion
      self.progressHandler = progressHandler
    }
  }
  /// Downloads a model file to the device and, once it is saved, writes the resulting local model
  /// info to `UserDefaults`.
  class ModelDownloadTask: NSObject {
    // MARK: Dependencies

    /// App name used when naming the downloaded model file and when writing model info to defaults.
    private let appName: String
    /// Storage that receives the local model info after a successful download.
    private let defaults: UserDefaults
    /// Source of fresh model info when the download URL has expired and a retry is attempted.
    private let modelInfoRetriever: ModelInfoRetriever
    /// Optional destination for download telemetry events.
    private let telemetryLogger: TelemetryLogger?
    /// Caller-supplied progress and completion callbacks.
    private let downloadHandlers: DownloadHandlers

    // MARK: Download state

    /// Server-side description of the model; replaced when a retry fetches fresh info.
    private(set) var remoteModelInfo: RemoteModelInfo
    /// Stage this download has reached.
    private(set) var downloadStatus: ModelDownloadStatus = .notStarted
    /// How many more times the download may be retried after its URL expires.
    private var numberOfRetries: Int = 1
    /// The file download currently being tracked, once one has been started.
    private var downloadTask: URLSessionDownloadTask?
    /// Ephemeral session that performs the download and delivers its events to `self`.
    ///
    /// NOTE(review): `URLSession` keeps a strong reference to its delegate until the session is
    /// invalidated, and no invalidation is visible in the code shown; confirm this task is released.
    private lazy var downloadSession: URLSession = URLSession(
      configuration: .ephemeral,
      delegate: self,
      delegateQueue: nil
    )

    /// Creates a download task for `remoteModelInfo`; nothing is downloaded until
    /// `resumeModelDownload()` is called.
    /// - Parameters:
    ///   - remoteModelInfo: Server-provided info (name, download URL, URL expiry) for the model.
    ///   - appName: Name of the app this download belongs to.
    ///   - defaults: Where local model info is written after a successful download.
    ///   - modelInfoRetriever: Used to refresh model info when a retry is needed.
    ///   - telemetryLogger: Optional telemetry sink.
    ///   - progressHandler: Optional progress callback.
    ///   - completion: Receives the downloaded model or an error.
    init(remoteModelInfo: RemoteModelInfo,
         appName: String,
         defaults: UserDefaults,
         modelInfoRetriever: ModelInfoRetriever,
         telemetryLogger: TelemetryLogger? = nil,
         progressHandler: DownloadHandlers.ProgressHandler? = nil,
         completion: @escaping DownloadHandlers.Completion) {
      self.appName = appName
      self.defaults = defaults
      self.modelInfoRetriever = modelInfoRetriever
      self.telemetryLogger = telemetryLogger
      self.downloadHandlers = DownloadHandlers(progressHandler: progressHandler,
                                               completion: completion)
      self.remoteModelInfo = remoteModelInfo
      super.init()
    }
  }
  extension ModelDownloadTask {
    /// Starts downloading the model file at `remoteModelInfo.downloadURL`.
    ///
    /// Does nothing unless `downloadStatus` is `.notStarted`. Progress and the final result are
    /// reported through the session's delegate callbacks.
    func resumeModelDownload() {
      guard downloadStatus == .notStarted else { return }
      let task = downloadSession.downloadTask(with: remoteModelInfo.downloadURL)
      // Record the task and status before resuming. Delegate callbacks are delivered on the
      // session's own queue and are matched against `self.downloadTask`. A callback that ran
      // before this assignment (e.g. a very fast failure) would otherwise be ignored.
      self.downloadTask = task
      downloadStatus = .inProgress
      task.resume()
    }
  }
  /// Extension to handle delegate methods.
  extension ModelDownloadTask: URLSessionDownloadDelegate {
    /// Name for model file stored on device.
    var downloadedModelFileName: String {
      return "fbml_model__\(appName)__\(remoteModelInfo.name)"
    }

    /// Handles a client-side failure of the current download task (e.g. the host could not be
    /// resolved or reached). Completion without an error is ignored here; the successful path is
    /// handled in `urlSession(_:downloadTask:didFinishDownloadingTo:)`.
    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
      // TODO: Log this.
      guard task == downloadTask, let error = error else { return }
      failDownload(with: .internalError(description: error.localizedDescription))
    }

    /// Handles a finished file download.
    ///
    /// Validates the HTTP response. On a 400 after the URL's expiry time, it retries with
    /// refreshed model info. Otherwise it moves the file into place, writes its local model info
    /// to `defaults`, and reports the resulting model. Except for the retry path, which restarts
    /// the download, every path reports exactly one outcome through the completion handler on
    /// the main queue.
    func urlSession(_ session: URLSession,
                    downloadTask: URLSessionDownloadTask,
                    didFinishDownloadingTo location: URL) {
      // TODO: Log this.
      guard downloadTask == self.downloadTask else { return }
      guard let response = downloadTask.response as? HTTPURLResponse else {
        failDownload(with: .internalError(
          description: ModelDownloadTask.ErrorDescription.invalidServerResponse
        ))
        return
      }
      // Accept the whole 2xx range, including 299.
      guard (200 ... 299).contains(response.statusCode) else {
        handleUnsuccessfulResponse(response)
        return
      }
      // Per the URLSessionDownloadDelegate docs, `location` is a temporary file that must be moved
      // before this method returns, so the move happens synchronously here.
      let modelFileURL = ModelFileManager.getDownloadedModelFilePath(
        appName: appName,
        modelName: remoteModelInfo.name
      )
      do {
        try ModelFileManager.moveFile(at: location, to: modelFileURL)
      } catch let error as DownloadError {
        // A failed move must end the download here. Continuing would record model info for a file
        // that was never saved and report success as well.
        reportModelSaveFailure(error)
        return
      } catch {
        reportModelSaveFailure(.internalError(description: error.localizedDescription))
        return
      }
      // Generate local model info, persist it, and build the model from it.
      let localModelInfo = LocalModelInfo(from: remoteModelInfo, path: modelFileURL.absoluteString)
      localModelInfo.writeToDefaults(defaults, appName: appName)
      let model = CustomModel(localModelInfo: localModelInfo)
      downloadStatus = .successful
      telemetryLogger?.logModelDownloadEvent(
        eventName: .modelDownload,
        status: downloadStatus,
        model: model
      )
      deliverResult(.success(model))
    }

    /// Reports progress of the current download as a fraction of the expected size.
    /// Progress is skipped when the expected size is unknown or zero.
    func urlSession(_ session: URLSession,
                    downloadTask: URLSessionDownloadTask,
                    didWriteData bytesWritten: Int64,
                    totalBytesWritten: Int64,
                    totalBytesExpectedToWrite: Int64) {
      // TODO: Log this.
      guard downloadTask == self.downloadTask,
            let progressHandler = downloadHandlers.progressHandler,
            // An unknown size (NSURLSessionTransferSizeUnknown, -1) or 0 would produce a
            // negative, infinite, or NaN fraction.
            totalBytesExpectedToWrite > 0 else { return }
      let calculatedProgress = Float(totalBytesWritten) / Float(totalBytesExpectedToWrite)
      DispatchQueue.main.async {
        progressHandler(calculatedProgress)
      }
    }

    /// Handles a non-2xx response. A 400 received after the download URL's expiry time triggers a
    /// retry with refreshed model info while `numberOfRetries` is positive. Every other case fails
    /// the download.
    private func handleUnsuccessfulResponse(_ response: HTTPURLResponse) {
      guard response.statusCode == 400 else {
        failDownload(with: .internalError(
          description: "Model file download failed with HTTP status code \(response.statusCode)."
        ))
        return
      }
      // A 400 may mean the download URL has expired; retry only if it has and retries remain.
      guard Date() > remoteModelInfo.urlExpiryTime, numberOfRetries > 0 else {
        failDownload(with: .internalError(
          description: ModelDownloadTask.ErrorDescription.expiredModelInfo
        ))
        return
      }
      numberOfRetries -= 1
      retryDownloadWithFreshModelInfo()
    }

    /// Fetches fresh model info and restarts the download with it.
    private func retryDownloadWithFreshModelInfo() {
      modelInfoRetriever.downloadModelInfo { result in
        switch result {
        case let .success(.modelInfo(remoteModelInfo)):
          self.remoteModelInfo = remoteModelInfo
          // `resumeModelDownload()` only starts a download from `.notStarted`. Without this reset,
          // the retry would be a no-op and the caller would never receive a result.
          self.downloadStatus = .notStarted
          self.resumeModelDownload()
        case .success(.notModified):
          // Not expected: model info cannot be unmodified within ModelDownloadTask.
          self.failDownload(with: .internalError(
            description: ModelDownloadTask.ErrorDescription.expiredModelInfo
          ))
        case let .failure(downloadError):
          self.failDownload(with: downloadError)
        }
      }
    }

    /// Records a failure to move the downloaded file into place and reports `error` to the caller.
    private func reportModelSaveFailure(_ error: DownloadError) {
      downloadStatus = .failed
      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload, status: downloadStatus)
      // NOTE(review): this failure is logged at `.info` level with the `.modelDownloaded` message
      // code; confirm whether an error level or a dedicated code fits better.
      DeviceLogger.logEvent(
        level: .info,
        category: .modelDownload,
        message: ModelDownloadTask.ErrorDescription.saveModel,
        messageCode: .modelDownloaded
      )
      deliverResult(.failure(error))
    }

    /// Marks the download as failed and reports `error` to the caller.
    private func failDownload(with error: DownloadError) {
      downloadStatus = .failed
      deliverResult(.failure(error))
    }

    /// Hands `result` to the caller's completion handler on the main queue.
    private func deliverResult(_ result: Result<CustomModel, DownloadError>) {
      DispatchQueue.main.async {
        self.downloadHandlers.completion(result)
      }
    }
  }
  // MARK: - Error messages

  /// Possible error messages for model downloading.
  extension ModelDownloadTask {
    /// Text attached to model download errors and log events.
    private enum ErrorDescription {
      /// Used when the finished download task has no HTTP response.
      static let invalidServerResponse: String = "Could not get server response for model downloading."
      /// Used when an expired download URL could not be refreshed and retried.
      static let expiredModelInfo: String = "Unable to update expired model info."
      /// Used when the downloaded file could not be saved. Declared as `StaticString`, presumably
      /// because `DeviceLogger.logEvent` takes a static message — confirm.
      static let saveModel: StaticString = "Unable to save downloaded remote model file."
    }
  }