ModelDownloadTask.swift 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  /// Stages a model download passes through: it begins as `notStarted`, moves to
  /// `inProgress` once the download task is started, and ends as either
  /// `successful` or `failed`.
  enum ModelDownloadStatus {
    case notStarted, inProgress, successful, failed
  }
  /// Holds the callbacks a caller supplies for one model download: an optional
  /// progress callback and a required completion callback.
  class DownloadHandlers {
    /// Callback receiving download progress as a `Float`.
    typealias ProgressHandler = (Float) -> Void
    /// Callback receiving either the downloaded model or the error that prevented it.
    typealias Completion = (Result<CustomModel, DownloadError>) -> Void

    /// Progress callback; `nil` when the caller did not request progress updates.
    var progressHandler: ProgressHandler?
    /// Callback invoked when the download finishes, successfully or not.
    var completion: Completion

    init(progressHandler: ProgressHandler?, completion: @escaping Completion) {
      self.completion = completion
      self.progressHandler = progressHandler
    }
  }
  /// Downloads one model file to the device and, on success, records the
  /// resulting local model info in persistent storage (user defaults).
  class ModelDownloadTask: NSObject {
    /// Name of the app associated with this instance of ModelDownloadTask.
    private let appName: String
    /// Model info downloaded from server.
    private(set) var remoteModelInfo: RemoteModelInfo
    /// User defaults to which local model info should ultimately be written.
    private let defaults: UserDefaults
    /// Task to handle model file download.
    private var downloadTask: URLSessionDownloadTask?
    /// Progress and completion handlers associated with this model download task.
    private let downloadHandlers: DownloadHandlers
    /// Keeps track of download associated with this model download task.
    private(set) var downloadStatus: ModelDownloadStatus = .notStarted
    /// Ephemeral URLSession that performs the download, with `self` as its delegate.
    /// Because `delegateQueue` is `nil`, the session delivers delegate callbacks on
    /// its own serial background queue, not on the caller's thread.
    private lazy var downloadSession = URLSession(configuration: .ephemeral,
                                                  delegate: self,
                                                  delegateQueue: nil)
    /// Telemetry logger.
    private let telemetryLogger: TelemetryLogger?

    /// - Parameters:
    ///   - remoteModelInfo: Server-provided info describing the model to download.
    ///   - appName: Name of the app this download belongs to.
    ///   - defaults: User defaults that receive the local model info on success.
    ///   - telemetryLogger: Optional logger for download telemetry events.
    ///   - progressHandler: Optional callback for download progress updates.
    ///   - completion: Callback receiving the downloaded model or an error.
    init(remoteModelInfo: RemoteModelInfo, appName: String, defaults: UserDefaults,
         telemetryLogger: TelemetryLogger? = nil,
         progressHandler: DownloadHandlers.ProgressHandler? = nil,
         completion: @escaping DownloadHandlers.Completion) {
      self.remoteModelInfo = remoteModelInfo
      self.appName = appName
      self.telemetryLogger = telemetryLogger
      self.defaults = defaults
      downloadHandlers = DownloadHandlers(
        progressHandler: progressHandler,
        completion: completion
      )
    }

    /// Asynchronously download model file to device.
    ///
    /// Only the first call has an effect. Later calls return immediately because
    /// the status is no longer `.notStarted`.
    func resumeModelDownload() {
      guard downloadStatus == .notStarted else { return }
      let downloadTask = downloadSession.downloadTask(with: remoteModelInfo.downloadURL)
      // Record the task and status *before* resuming. Delegate callbacks run on a
      // background queue and may fire as soon as `resume()` is called, and they
      // compare against `self.downloadTask` and read `downloadStatus`.
      self.downloadTask = downloadTask
      downloadStatus = .inProgress
      downloadTask.resume()
    }
  }
  /// Extension to handle delegate methods.
  extension ModelDownloadTask: URLSessionDownloadDelegate {
    /// Name for model file stored on device.
    var downloadedModelFileName: String {
      return "fbml_model__\(appName)__\(remoteModelInfo.name)"
    }

    /// Final callback for the download task.
    ///
    /// Reports transport-level failures, which never reach `didFinishDownloadingTo`,
    /// so the caller's completion handler is not left waiting forever. It also
    /// invalidates the session.
    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
      // A URLSession strongly retains its delegate (`self`) until it is invalidated,
      // while `self` holds the session. This is the task's last callback and the
      // session only ever runs this one task, so invalidate it to break the cycle.
      defer { session.finishTasksAndInvalidate() }
      // TODO: Handle model download url expiry.
      guard let error = error else { return }
      // Only report if no outcome has been delivered yet, so the completion
      // handler is never invoked twice.
      switch downloadStatus {
      case .successful, .failed:
        return
      case .notStarted, .inProgress:
        break
      }
      reportDownloadFailure(
        .internalError(description: error.localizedDescription),
        logMessage: "Unable to download remote model file."
      )
    }

    func urlSession(_ session: URLSession,
                    downloadTask: URLSessionDownloadTask,
                    didFinishDownloadingTo location: URL) {
      assert(downloadTask == self.downloadTask)
      let modelFileURL = ModelFileManager.getDownloadedModelFilePath(
        appName: appName,
        modelName: remoteModelInfo.name
      )
      do {
        try ModelFileManager.moveFile(at: location, to: modelFileURL)
      } catch {
        // Pass DownloadErrors through unchanged; wrap anything else as an internal error.
        let downloadError: DownloadError = (error as? DownloadError)
          ?? .internalError(description: error.localizedDescription)
        reportDownloadFailure(downloadError,
                              logMessage: "Unable to save downloaded remote model file.")
        // Stop here. Falling through would also report success, invoking the
        // completion handler a second time for a file that was never saved.
        return
      }
      // Generate local model info.
      let localModelInfo = LocalModelInfo(from: remoteModelInfo, path: modelFileURL.absoluteString)
      // Write model to user defaults.
      localModelInfo.writeToDefaults(defaults, appName: appName)
      // Build model from model info.
      let model = CustomModel(localModelInfo: localModelInfo)
      downloadStatus = .successful
      telemetryLogger?.logModelDownloadEvent(
        eventName: .modelDownload,
        status: downloadStatus,
        model: model
      )
      DispatchQueue.main.async {
        self.downloadHandlers.completion(.success(model))
      }
    }

    func urlSession(_ session: URLSession,
                    downloadTask: URLSessionDownloadTask,
                    didWriteData bytesWritten: Int64,
                    totalBytesWritten: Int64,
                    totalBytesExpectedToWrite: Int64) {
      assert(downloadTask == self.downloadTask)
      // Nothing to do if the caller did not ask for progress updates.
      guard let progressHandler = downloadHandlers.progressHandler else { return }
      // The expected size is NSURLSessionTransferSizeUnknown (-1) when the total
      // length is not known, and 0 would divide by zero. No meaningful fraction exists then.
      guard totalBytesExpectedToWrite > 0 else { return }
      let calculatedProgress = Float(totalBytesWritten) / Float(totalBytesExpectedToWrite)
      DispatchQueue.main.async {
        progressHandler(calculatedProgress)
      }
    }

    /// Marks the download as failed, logs it, and delivers `error` to the
    /// completion handler on the main queue.
    private func reportDownloadFailure(_ error: DownloadError, logMessage: String) {
      downloadStatus = .failed
      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload, status: downloadStatus)
      DeviceLogger.logEvent(
        level: .info,
        category: .modelDownload,
        message: logMessage,
        messageCode: .modelDownloaded
      )
      DispatchQueue.main.async {
        self.downloadHandlers.completion(.failure(error))
      }
    }
  }