ModelDownloadTask.swift 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. // Copyright 2020 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseCore
  /// Lifecycle stages a model download moves through.
  enum DownloadStatus {
    /// No download has been kicked off yet.
    case notStarted

    /// A download has been started and has not finished.
    case inProgress

    /// The download has finished.
    case completed
  }
  /// Bundles the callbacks a caller supplies for one model download: an optional progress
  /// callback and a mandatory completion callback.
  class DownloadHandlers {
    /// Receives download progress as a `Float`.
    typealias ProgressHandler = (Float) -> Void

    /// Receives the final outcome: the downloaded `CustomModel` or a `DownloadError`.
    typealias Completion = (Result<CustomModel, DownloadError>) -> Void

    /// Callback for progress updates; `nil` when the caller is not interested in progress.
    var progressHandler: ProgressHandler?

    /// Callback invoked with the result of the download.
    var completion: Completion

    /// Stores both callbacks for later use.
    ///
    /// - Parameters:
    ///   - progressHandler: Optional progress callback.
    ///   - completion: Callback invoked with the download result.
    init(progressHandler: ProgressHandler?, completion: @escaping Completion) {
      self.completion = completion
      self.progressHandler = progressHandler
    }
  }
  /// Manages downloading a single model file to the device and storing the downloaded model's
  /// info to persistent storage.
  class ModelDownloadTask: NSObject {
    /// Name of the app associated with this instance of ModelDownloadTask.
    private let appName: String

    /// Model info downloaded from server.
    private(set) var remoteModelInfo: RemoteModelInfo

    /// User defaults to which local model info should ultimately be written.
    private let defaults: UserDefaults

    /// Task to handle model file download; `nil` until `resumeModelDownload()` creates it.
    private var downloadTask: URLSessionDownloadTask?

    /// Progress and completion handlers associated with this model download task.
    private let downloadHandlers: DownloadHandlers

    /// Keeps track of download associated with this model download task.
    private(set) var downloadStatus: DownloadStatus = .notStarted

    /// URLSession to handle model downloads. Created lazily because it needs `self` as its
    /// delegate. `delegateQueue: nil` lets the session pick its own queue for delegate
    /// callbacks, so those callbacks do not run on the caller's thread.
    private lazy var downloadSession = URLSession(configuration: .ephemeral,
                                                  delegate: self,
                                                  delegateQueue: nil)

    /// Creates a download task for the given remote model.
    ///
    /// - Parameters:
    ///   - remoteModelInfo: Model info downloaded from server, including the download URL.
    ///   - appName: Name of the app this download belongs to.
    ///   - defaults: User defaults to which local model info is written after the download.
    ///   - progressHandler: Optional callback for download progress.
    ///   - completion: Callback invoked with the result of the download.
    init(remoteModelInfo: RemoteModelInfo, appName: String, defaults: UserDefaults,
         progressHandler: DownloadHandlers.ProgressHandler? = nil,
         completion: @escaping DownloadHandlers.Completion) {
      self.remoteModelInfo = remoteModelInfo
      self.appName = appName
      self.defaults = defaults
      downloadHandlers = DownloadHandlers(
        progressHandler: progressHandler,
        completion: completion
      )
    }

    /// Asynchronously download model file to device.
    ///
    /// Does nothing unless the download status is `.notStarted`, so calling it more than once
    /// starts at most one download.
    func resumeModelDownload() {
      guard downloadStatus == .notStarted else { return }
      let downloadTask = downloadSession.downloadTask(with: remoteModelInfo.downloadURL)
      // Record the task and the new status BEFORE starting it. Delegate callbacks are delivered
      // on the session's own queue and compare against `self.downloadTask` / update
      // `downloadStatus`; starting the task first would let them race with these writes.
      self.downloadTask = downloadTask
      downloadStatus = .inProgress
      downloadTask.resume()
    }
  }
  /// URLSession delegate callbacks for `ModelDownloadTask`: stores the downloaded file, records
  /// the local model info, and forwards progress and results to the registered handlers on the
  /// main queue.
  extension ModelDownloadTask: URLSessionDownloadDelegate {
    /// Name for model file stored on device.
    var downloadedModelFileName: String {
      return "fbml_model__\(appName)__\(remoteModelInfo.name)"
    }

    /// Invokes the completion handler with `result` on the main queue.
    private func deliverOnMainQueue(_ result: Result<CustomModel, DownloadError>) {
      DispatchQueue.main.async {
        self.downloadHandlers.completion(result)
      }
    }

    /// Called when the file has been downloaded to a temporary `location`.
    ///
    /// Moves the file into the models directory, writes the local model info to user defaults
    /// and reports the resulting model. If the move fails, reports the failure and stops: the
    /// completion handler is invoked exactly once per call.
    func urlSession(_ session: URLSession,
                    downloadTask: URLSessionDownloadTask,
                    didFinishDownloadingTo location: URL) {
      assert(downloadTask == self.downloadTask)
      downloadStatus = .completed
      let savedURL = ModelFileManager.modelsDirectory
        .appendingPathComponent(downloadedModelFileName)
      do {
        try ModelFileManager.moveFile(at: location, to: savedURL)
      } catch let error as DownloadError {
        // Must return here: falling through would also report success and persist info for a
        // file that was never stored.
        deliverOnMainQueue(.failure(error))
        return
      } catch {
        deliverOnMainQueue(.failure(.internalError(description: error.localizedDescription)))
        return
      }

      // Generate local model info.
      let localModelInfo = LocalModelInfo(from: remoteModelInfo, path: savedURL.absoluteString)
      // Write model to user defaults.
      localModelInfo.writeToDefaults(defaults, appName: appName)
      // Build model from model info.
      let model = CustomModel(localModelInfo: localModelInfo)
      deliverOnMainQueue(.success(model))
    }

    /// Called when the task finishes, successfully (`error == nil`) or not.
    ///
    /// Successful downloads are already reported from `didFinishDownloadingTo`; this handles
    /// only transport-level failures, which would otherwise never reach the completion handler.
    /// `downloadStatus` is left unchanged on failure because `DownloadStatus` has no failure
    /// case.
    func urlSession(_ session: URLSession,
                    task: URLSessionTask,
                    didCompleteWithError error: Error?) {
      // The session keeps a strong reference to its delegate until it is invalidated. This
      // object starts at most one task per session, so release the session once that task ends.
      session.finishTasksAndInvalidate()
      guard let error = error else { return }
      deliverOnMainQueue(.failure(.internalError(description: error.localizedDescription)))
    }

    /// Called periodically as data is written; forwards the completed fraction to the progress
    /// handler on the main queue.
    func urlSession(_ session: URLSession,
                    downloadTask: URLSessionDownloadTask,
                    didWriteData bytesWritten: Int64,
                    totalBytesWritten: Int64,
                    totalBytesExpectedToWrite: Int64) {
      assert(downloadTask == self.downloadTask)
      // Check if progress handler is unspecified.
      guard let progressHandler = downloadHandlers.progressHandler else { return }
      // The expected size is NSURLSessionTransferSizeUnknown (-1) when the server does not
      // announce a length; dividing by it (or by 0) would yield a negative or non-finite value.
      guard totalBytesExpectedToWrite > 0 else { return }
      let calculatedProgress = Float(totalBytesWritten) / Float(totalBytesExpectedToWrite)
      DispatchQueue.main.async {
        progressHandler(calculatedProgress)
      }
    }
  }