Browse Source

Removes saved path on device (#7439)

* Restores ML Pods after M77.

* Fix Package.swift

* Re-add catalyst to GHA workflow.

* Add tests for list models.

* Failable emptying of model directory in tests.

* Remove path from local model info and user defaults.

* Clean up user defaults on delete + add interpreter to test app.

* Replace .absoluteString with .path. This makes model.path work directly with TFLite interpreter.

* Fix paths in tests.

* Unit test for successful get model.

* Rename modelPath to modelURL.

* Use defaults.integer instead of casting to Int.

* Remove force casting.

* Mostly minor fixes that came out of the bugbash (#7462)

* Use appropriate errors for network offline and api disabled.

* Error + test for resourceExhausted.

* Surface backend error message in logs whenever available.

* Add disclaimer to empty directories for testing only.

* Add Firelog events for model info retrieval.

* Add documentation + other minor fixes. (#7465)

* Add docs + clean up ModelDownloader.swift.

* Add docs + clean up ModelDownloadTask.swift.

* Add docs + clean up FileDownloader.swift.

* Add docs + clean up LocalModelInfo.swift. Handle TODO for userdefaults init failure.

* Add docs + clean up RemoteModelInfo.swift, DownloaderUserDefaults.swift.

* Add docs + clean up ModelInfoRetriever.swift.

* Add docs + clean up ModelDownloadConditions.swift.

* Add docs + clean up DeviceLogger.swift.

* Add docs + clean up TelemetryLogger.swift.

* Add docs + clean up ModelFileManager.swift.

* Remove resolved TODOs + other minor fixes.

* Update model download conditions init.

* Fix comment formatting.

* Minor newline formatting.

* Minor comment formatting.

* Add docs for ModelDownloadConditions.

* Basic fix for user defaults race condition.

* Minor updates to tests.

* Minor updates to tests.

* Final fixes (#7471)

* Refactor error parsing in model info response.

* Add Firelog for delete event + minor fixes.

* Temporary - test to check telemetry for delete event.

* Disable backup of model files; handle inability to locate model storage dir.

* Minor fixes.

* Add tflite path extension to model file.
manjanac 5 years ago
parent
commit
9597ffe51b

+ 18 - 0
FirebaseMLModelDownloader/Apps/Sample/MLDownloaderTestApp/Downloader.swift

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 import Foundation
+import TensorFlowLite
 import FirebaseMLModelDownloader
 
 class Downloader: ObservableObject {
@@ -57,6 +58,23 @@ class Downloader: ObservableObject {
       case let .success(model):
         self.isDownloaded = true
         self.filePath = model.path
+        let fileURL = URL(fileURLWithPath: self.filePath)
+        do {
+          _ = try fileURL.checkResourceIsReachable()
+          let attr = try FileManager.default.attributesOfItem(atPath: self.filePath)
+          if let size = attr[FileAttributeKey.size] {
+            print("File size: \(size)")
+          } else {
+            print("Error - could not get file size.")
+          }
+        } catch {
+          print("File access error - \(error)")
+        }
+        do {
+          _ = try Interpreter(modelPath: self.filePath)
+        } catch {
+          print("Tensorflow error - \(error)")
+        }
       case let .failure(error):
         self.isDownloaded = false
         self.isError = true

+ 1 - 0
FirebaseMLModelDownloader/Apps/Sample/Podfile

@@ -2,6 +2,7 @@ platform :ios, '10.0'
 use_frameworks!
 
 pod 'FirebaseMLModelDownloader', :path => '../../../'
+pod 'TensorFlowLiteSwift'
 
 target 'MLDownloaderTestApp' do
     target 'MLDownloaderTestAppTests' do

+ 5 - 2
FirebaseMLModelDownloader/Sources/CustomModel.swift

@@ -18,10 +18,13 @@ import Foundation
 public struct CustomModel: Hashable {
   /// Name of the model.
   public let name: String
+
   /// Size of the custom model, provided by the server.
   public let size: Int
+
   /// Path where the model is stored on device.
   public let path: String
+
   /// Hash for the model, used for model verification.
   public let hash: String
 
@@ -33,11 +36,11 @@ public struct CustomModel: Hashable {
   }
 
   /// Convenience init to create model from local model info.
-  init(localModelInfo: LocalModelInfo) {
+  init(localModelInfo: LocalModelInfo, path: String) {
     self.init(
       name: localModelInfo.name,
       size: localModelInfo.size,
-      path: localModelInfo.path,
+      path: path,
       hash: localModelInfo.modelHash
     )
   }

+ 30 - 25
FirebaseMLModelDownloader/Sources/DeviceLogger.swift

@@ -19,58 +19,63 @@ import Foundation
   import GoogleUtilities
 #endif
 
-/// Enum of debug messages.
-// TODO: Create list of all possible messages with code - according to format.
+/// Enum of log messages.
 enum LoggerMessageCode: Int {
-  case modelDownloaded = 1
-  case downloadedModelFileSaved
-  case downloadedModelInfoSaved
-  case downloaderInstanceCreated
+  case downloaderInstanceCreated = 1
   case downloaderInstanceRetrieved
   case downloaderInstanceDeleted
-  case localModelFound
+  case modelDownloaded
+  case modelDownloadError
+  case retryDownload
   case backgroundModelDownloaded
+  case backgroundDownloadError
+  case disableBackupError
+  case downloadedModelFileSaved
+  case downloadedModelSaveError
+  case anotherDownloadInProgressError
+  case invalidDownloadSessionError
+  case mergeRequests
+  case localModelFound
+  case allLocalModelsFound
+  case listModelsError
   case modelNameParseError
   case noLocalModelInfo
+  case noLocalModelFile
   case outdatedModelPathError
-  case allLocalModelsFound
-  case listModelsError
-  case modelNotFound
   case modelDeleted
-  case invalidOptions
-  case retryDownload
-  case anotherDownloadInProgressError
-  case invalidModelName
-  case permissionDenied
-  case notEnoughSpace
+  case modelDeletionFailed
   case validHTTPResponse
-  case invalidModelInfoFetchURL
-  case modelInfoRetrievalError
   case validAuthToken
-  case hostnameError
-  case invalidDownloadSessionError
-  case invalidHTTPResponse
+  case invalidOptions
+  case invalidModelInfoFetchURL
+  case downloadedModelInfoSaved
   case missingModelHash
   case invalidModelInfoJSON
   case modelInfoDeleted
   case modelInfoDownloaded
   case modelInfoUnmodified
-  case mergeRequests
   case authTokenError
   case expiredModelInfo
-  case modelDownloadError
-  case downloadedModelSaveError
   case modelHashMismatchError
   case noModelHash
+  case modelInfoRetrievalError
+  case modelNotFound
+  case invalidArgument
+  case permissionDenied
+  case resourceExhausted
+  case notEnoughSpace
+  case hostnameError
+  case invalidHTTPResponse
   case analyticsEventEncodeError
   case telemetryInitError
-  case backgroundDownloadError
   case testError
 }
 
 /// On-device logger.
 class DeviceLogger {
+  /// Log identifier.
   static let service = "[Firebase/MLModelDownloader]"
+
   static func logEvent(level: GoogleLoggerLevel, message: String, messageCode: LoggerMessageCode) {
     let code = String(format: "I-MLM%06d", messageCode.rawValue)
     let args: [CVarArg] = []

+ 2 - 1
FirebaseMLModelDownloader/Sources/DownloaderUserDefaults.swift

@@ -14,7 +14,8 @@
 
 import Foundation
 
-/// Protocol to write info to user defaults.
+/// Protocol to save or delete model info in user defaults.
 protocol DownloaderUserDefaultsWriteable {
   func writeToDefaults(_ defaults: UserDefaults, appName: String)
+  func removeFromDefaults(_ defaults: UserDefaults, appName: String)
 }

+ 19 - 10
FirebaseMLModelDownloader/Sources/FileDownloader.swift

@@ -14,17 +14,19 @@
 
 import Foundation
 
-enum FileDownloaderError: Error {
-  case unexpectedResponseType
-  case networkError(Error)
-}
-
+/// File downloader response.
 struct FileDownloaderResponse {
   var urlResponse: HTTPURLResponse
   var fileURL: URL
 }
 
-/// URL Session to use while retrieving model info.
+/// Possible file downloader errors.
+enum FileDownloaderError: Error {
+  case unexpectedResponseType
+  case networkError(Error)
+}
+
+/// Protocol to download a file from server.
 protocol FileDownloader {
   typealias CompletionHandler = (Result<FileDownloaderResponse, Error>) -> Void
   typealias ProgressHandler = (_ bytesWritten: Int64, _ bytesExpectedToWrite: Int64) -> Void
@@ -38,10 +40,13 @@ protocol FileDownloader {
 class ModelFileDownloader: NSObject, FileDownloader {
   /// Model conditions for download.
   private let conditions: ModelDownloadConditions
+
   /// URL session configuration.
   private let configuration: URLSessionConfiguration
+
   /// Task to handle model file download.
   private var downloadTask: URLSessionDownloadTask?
+
   /// URLSession to handle model downloads.
   private lazy var downloadSession: URLSession = URLSession(
     configuration: configuration,
@@ -50,13 +55,14 @@ class ModelFileDownloader: NSObject, FileDownloader {
   )
   /// Successful download completion handler.
   private var completion: FileDownloader.CompletionHandler?
+
   /// Download progress handler.
   private var progressHandler: FileDownloader.ProgressHandler?
 
   init(conditions: ModelDownloadConditions) {
     self.conditions = conditions
     configuration = URLSessionConfiguration.ephemeral
-    /// Wait for network connectivity, if unavailable.
+    /// Wait for network connectivity.
     if #available(iOS 11.0, macOS 10.13, macCatalyst 13.0, tvOS 11.0, watchOS 4.0, *) {
       self.configuration.waitsForConnectivity = true
       /// Wait for 10 minutes.
@@ -68,10 +74,11 @@ class ModelFileDownloader: NSObject, FileDownloader {
   func downloadFile(with url: URL,
                     progressHandler: @escaping (Int64, Int64) -> Void,
                     completion: @escaping (Result<FileDownloaderResponse, Error>) -> Void) {
-    // TODO: Fail if download already in progress
+    // TODO: Fail if download already in progress.
     self.completion = completion
     self.progressHandler = progressHandler
     let downloadTask = downloadSession.downloadTask(with: url)
+    // Begin or resume model download.
     downloadTask.resume()
     self.downloadTask = downloadTask
   }
@@ -79,14 +86,15 @@ class ModelFileDownloader: NSObject, FileDownloader {
 
 /// Extension to handle delegate methods.
 extension ModelFileDownloader: URLSessionDownloadDelegate {
-  /// Handle client-side errors.
+  // Handle client-side errors.
   func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
     guard let error = error else { return }
     session.finishTasksAndInvalidate()
-    /// Unable to resolve hostname or connect to host.
+    // Unable to resolve hostname or connect to host.
     completion?(.failure(FileDownloaderError.networkError(error)))
   }
 
+  /// Download completion.
   func urlSession(_ session: URLSession,
                   downloadTask: URLSessionDownloadTask,
                   didFinishDownloadingTo location: URL) {
@@ -100,6 +108,7 @@ extension ModelFileDownloader: URLSessionDownloadDelegate {
     completion?(.success(downloaderResponse))
   }
 
+  /// Download progress.
   func urlSession(_ session: URLSession,
                   downloadTask: URLSessionDownloadTask,
                   didWriteData bytesWritten: Int64,

+ 17 - 19
FirebaseMLModelDownloader/Sources/LocalModelInfo.swift

@@ -13,10 +13,8 @@
 // limitations under the License.
 
 import Foundation
-import FirebaseCore
 
 /// Model info object with details about downloaded and locally available model.
-// TODO: Can this be backed by user defaults property wrappers?
 class LocalModelInfo {
   /// Model name.
   let name: String
@@ -27,41 +25,35 @@ class LocalModelInfo {
   /// Size of the model, as returned by server.
   let size: Int
 
-  /// Local path of the model.
-  let path: String
-
-  init(name: String, modelHash: String, size: Int, path: String) {
+  init(name: String, modelHash: String, size: Int) {
     self.name = name
     self.modelHash = modelHash
     self.size = size
-    self.path = path
   }
 
-  /// Convenience init to create local model info from remotely downloaded model info and a local model path.
-  convenience init(from remoteModelInfo: RemoteModelInfo, path: String) {
+  /// Convenience init to create local model info from remotely downloaded model info.
+  convenience init(from remoteModelInfo: RemoteModelInfo) {
     self.init(
       name: remoteModelInfo.name,
       modelHash: remoteModelInfo.modelHash,
-      size: remoteModelInfo.size,
-      path: path
+      size: remoteModelInfo.size
     )
   }
 
   /// Convenience init to create local model info from stored info in user defaults.
   convenience init?(fromDefaults defaults: UserDefaults, name: String, appName: String) {
     let defaultsPrefix = LocalModelInfo.getUserDefaultsKeyPrefix(appName: appName, modelName: name)
-    guard let modelHash = defaults.value(forKey: "\(defaultsPrefix).model-hash") as? String,
-      let size = defaults.value(forKey: "\(defaultsPrefix).model-size") as? Int,
-      let path = defaults.value(forKey: "\(defaultsPrefix).model-path") as? String else {
+    guard let modelHash = defaults.string(forKey: "\(defaultsPrefix).model-hash") else {
       return nil
     }
-    self.init(name: name, modelHash: modelHash, size: size, path: path)
+    let size = defaults.integer(forKey: "\(defaultsPrefix).model-size")
+    self.init(name: name, modelHash: modelHash, size: size)
   }
 }
 
 /// Extension to write local model info to user defaults.
 extension LocalModelInfo: DownloaderUserDefaultsWriteable {
-  /// Get user defaults key prefix.
+  // Get user defaults key prefix.
   private static func getUserDefaultsKeyPrefix(appName: String, modelName: String) -> String {
     let bundleID = Bundle.main.bundleIdentifier ?? ""
     return "\(bundleID).\(appName).\(modelName)"
@@ -72,7 +64,12 @@ extension LocalModelInfo: DownloaderUserDefaultsWriteable {
     let defaultsPrefix = LocalModelInfo.getUserDefaultsKeyPrefix(appName: appName, modelName: name)
     defaults.setValue(modelHash, forKey: "\(defaultsPrefix).model-hash")
     defaults.setValue(size, forKey: "\(defaultsPrefix).model-size")
-    defaults.setValue(path, forKey: "\(defaultsPrefix).model-path")
+  }
+
+  func removeFromDefaults(_ defaults: UserDefaults, appName: String) {
+    let defaultsPrefix = LocalModelInfo.getUserDefaultsKeyPrefix(appName: appName, modelName: name)
+    defaults.removeObject(forKey: "\(defaultsPrefix).model-hash")
+    defaults.removeObject(forKey: "\(defaultsPrefix).model-size")
   }
 }
 
@@ -80,8 +77,9 @@ extension LocalModelInfo: DownloaderUserDefaultsWriteable {
 extension UserDefaults {
   static var firebaseMLDefaults: UserDefaults {
     let suiteName = "com.google.firebase.ml"
-    // TODO: reconsider force unwrapping
-    let defaults = UserDefaults(suiteName: suiteName)!
+    guard let defaults = UserDefaults(suiteName: suiteName) else {
+      return UserDefaults.standard
+    }
     return defaults
   }
 }

+ 7 - 4
FirebaseMLModelDownloader/Sources/ModelDownloadConditions.swift

@@ -16,8 +16,11 @@ import Foundation
 
 /// Model download conditions.
 public struct ModelDownloadConditions {
-  /// Allow model downloading on a cellular connection. Default is `YES`.
-  public var allowsCellularAccess = true
-  /// Public init.
-  public init() {}
+  let allowsCellularAccess: Bool
+
+  /// Conditions that need to be met to start a model file download.
+  /// - Parameter allowsCellularAccess: Allow model downloading on a cellular connection. Default is `true`.
+  public init(allowsCellularAccess: Bool = true) {
+    self.allowsCellularAccess = allowsCellularAccess
+  }
 }

+ 93 - 68
FirebaseMLModelDownloader/Sources/ModelDownloadTask.swift

@@ -13,44 +13,33 @@
 // limitations under the License.
 
 import Foundation
-import FirebaseCore
-
-/// Possible states of model downloading.
-enum ModelDownloadStatus {
-  case ready
-  case downloading
-  case complete
-}
-
-/// Download error codes.
-enum ModelDownloadErrorCode {
-  case noError
-  case urlExpired
-  case noConnection
-  case downloadFailed
-  case httpError(code: Int)
-}
-
-/// Manager to handle model downloading device and storing downloaded model info to persistent storage.
-class ModelDownloadTask: NSObject {
-  typealias ProgressHandler = (Float) -> Void
-  typealias Completion = (Result<CustomModel, DownloadError>) -> Void
 
+/// Task to download model file to device.
+class ModelDownloadTask {
   /// Name of the app associated with this instance of ModelDownloadTask.
   private let appName: String
+
   /// Model info downloaded from server.
   private(set) var remoteModelInfo: RemoteModelInfo
+
   /// User defaults to which local model info should ultimately be written.
   private let defaults: UserDefaults
+
   /// Keeps track of download associated with this model download task.
   private(set) var downloadStatus: ModelDownloadStatus = .ready
+
   /// Downloader instance.
   private let downloader: FileDownloader
+
   /// Telemetry logger.
   private let telemetryLogger: TelemetryLogger?
-  /// Progress handler.
+
+  /// Download progress handler.
+  typealias ProgressHandler = (Float) -> Void
   private var progressHandler: ProgressHandler?
-  /// Completion.
+
+  /// Download completion handler.
+  typealias Completion = (Result<CustomModel, DownloadError>) -> Void
   private var completion: Completion
 
   init(remoteModelInfo: RemoteModelInfo,
@@ -62,30 +51,20 @@ class ModelDownloadTask: NSObject {
        telemetryLogger: TelemetryLogger? = nil) {
     self.remoteModelInfo = remoteModelInfo
     self.appName = appName
+    self.defaults = defaults
     self.downloader = downloader
     self.progressHandler = progressHandler
     self.completion = completion
     self.telemetryLogger = telemetryLogger
-    self.defaults = defaults
   }
 }
 
 extension ModelDownloadTask {
-  /// Name for model file stored on device.
-  var downloadedModelFileName: String {
-    return "fbml_model__\(appName)__\(remoteModelInfo.name).tflite"
-  }
-
-  /// Check if downloading is not complete for merging requests.
+  /// Check if requests can be merged.
   func canMergeRequests() -> Bool {
     return downloadStatus != .complete
   }
 
-  /// Check if download task can be resumed.
-  func canResume() -> Bool {
-    return downloadStatus == .ready
-  }
-
   /// Merge duplicate requests. This method is not thread-safe.
   func merge(newProgressHandler: ProgressHandler? = nil, newCompletion: @escaping Completion) {
     let originalProgressHandler = progressHandler
@@ -100,8 +79,14 @@ extension ModelDownloadTask {
     }
   }
 
+  /// Check if download task can be resumed.
+  func canResume() -> Bool {
+    return downloadStatus == .ready
+  }
+
+  /// Download model file.
   func resume() {
-    /// Prevent multiple concurrent downloads.
+    // Prevent multiple concurrent downloads.
     guard downloadStatus != .downloading else {
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloadTask.ErrorDescription.anotherDownloadInProgress,
@@ -139,7 +124,7 @@ extension ModelDownloadTask {
         case let FileDownloaderError.networkError(error):
           let description = ModelDownloadTask.ErrorDescription
             .invalidHostName(error.localizedDescription)
-          downloadError = .internalError(description: description)
+          downloadError = .failedPrecondition
           DeviceLogger.logEvent(level: .debug,
                                 message: description,
                                 messageCode: .hostnameError)
@@ -180,15 +165,13 @@ extension ModelDownloadTask {
   func handleResponse(response: HTTPURLResponse, tempURL: URL, completion: @escaping Completion) {
     guard (200 ..< 299).contains(response.statusCode) else {
       switch response.statusCode {
-      /// Possible failure due to download URL expiry.
       case 400:
-        let currentDateTime = Date()
-        /// Check if download url has expired.
-        guard currentDateTime > remoteModelInfo.urlExpiryTime else {
+        // Possible failure due to download URL expiry. Check if download URL has expired.
+        guard remoteModelInfo.urlExpiryTime < Date() else {
           DeviceLogger.logEvent(level: .debug,
                                 message: ModelDownloadTask.ErrorDescription
-                                  .invalidModelName(remoteModelInfo.name),
-                                messageCode: .invalidModelName)
+                                  .invalidArgument(remoteModelInfo.name),
+                                messageCode: .invalidArgument)
           telemetryLogger?.logModelDownloadEvent(
             eventName: .modelDownload,
             status: .failed,
@@ -243,12 +226,36 @@ extension ModelDownloadTask {
       return
     }
 
-    let modelFileURL = ModelFileManager.getDownloadedModelFilePath(
+    /// Construct local model file URL.
+    guard var modelFileURL = ModelFileManager.getDownloadedModelFileURL(
       appName: appName,
       modelName: remoteModelInfo.name
-    )
+    ) else {
+      // Could not create Application Support directory to store model files.
+      let description = ModelDownloadTask.ErrorDescription.noModelsDirectory
+      DeviceLogger.logEvent(level: .debug,
+                            message: description,
+                            messageCode: .downloadedModelSaveError)
+      // Downloading the file succeeding but saving failed.
+      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
+                                             status: .succeeded,
+                                             downloadErrorCode: .downloadFailed)
+      completion(.failure(.internalError(description: description)))
+      return
+    }
 
     do {
+      // Try disabling iCloud backup for model files, because UserDefaults is not backed up.
+      var resourceValue = URLResourceValues()
+      resourceValue.isExcludedFromBackup = true
+      do {
+        try modelFileURL.setResourceValues(resourceValue)
+      } catch {
+        DeviceLogger.logEvent(level: .debug,
+                              message: ModelDownloadTask.ErrorDescription.disableBackupError,
+                              messageCode: .disableBackupError)
+      }
+      // Save model file to device.
       try ModelFileManager.moveFile(
         at: tempURL,
         to: modelFileURL,
@@ -257,15 +264,15 @@ extension ModelDownloadTask {
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloadTask.DebugDescription.savedModelFile,
                             messageCode: .downloadedModelFileSaved)
-      /// Generate local model info.
-      let localModelInfo = LocalModelInfo(from: remoteModelInfo, path: modelFileURL.absoluteString)
-      /// Write model to user defaults.
+      // Generate local model info.
+      let localModelInfo = LocalModelInfo(from: remoteModelInfo)
+      // Write model info to user defaults.
       localModelInfo.writeToDefaults(defaults, appName: appName)
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloadTask.DebugDescription.savedLocalModelInfo,
                             messageCode: .downloadedModelInfoSaved)
-      /// Build model from model info.
-      let model = CustomModel(localModelInfo: localModelInfo)
+      // Build model from model info and local path.
+      let model = CustomModel(localModelInfo: localModelInfo, path: modelFileURL.path)
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloadTask.DebugDescription.modelDownloaded,
                             messageCode: .modelDownloaded)
@@ -283,9 +290,10 @@ extension ModelDownloadTask {
                               messageCode: .notEnoughSpace)
       } else {
         DeviceLogger.logEvent(level: .debug,
-                              message: ModelDownloadTask.ErrorDescription.saveModel,
+                              message: ModelDownloadTask.ErrorDescription.modelSaveError,
                               messageCode: .downloadedModelSaveError)
       }
+      // Downloading the file succeeding but saving failed.
       telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                              status: .succeeded,
                                              downloadErrorCode: .downloadFailed)
@@ -293,8 +301,9 @@ extension ModelDownloadTask {
       return
     } catch {
       DeviceLogger.logEvent(level: .debug,
-                            message: ModelDownloadTask.ErrorDescription.saveModel,
+                            message: ModelDownloadTask.ErrorDescription.modelSaveError,
                             messageCode: .downloadedModelSaveError)
+      // Downloading the file succeeding but saving failed.
       telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                              status: .succeeded,
                                              downloadErrorCode: .downloadFailed)
@@ -304,13 +313,29 @@ extension ModelDownloadTask {
   }
 }
 
-/// Possible error messages for model downloading.
+/// Possible states of model downloading.
+enum ModelDownloadStatus {
+  case ready
+  case downloading
+  case complete
+}
+
+/// Download error codes.
+enum ModelDownloadErrorCode {
+  case noError
+  case urlExpired
+  case noConnection
+  case downloadFailed
+  case httpError(code: Int)
+}
+
+/// Possible debug and error messages for model downloading.
 extension ModelDownloadTask {
   /// Debug descriptions.
   private enum DebugDescription {
+    static let receivedServerResponse = "Received a valid response from download server."
     static let savedModelFile = "Model file saved successfully to device."
     static let savedLocalModelInfo = "Downloaded model info saved successfully to user defaults."
-    static let receivedServerResponse = "Received a valid response from download server."
     static let modelDownloaded = "Model download completed successfully."
   }
 
@@ -320,26 +345,26 @@ extension ModelDownloadTask {
       "Unable to resolve hostname or connect to host: \(error)"
     }
 
-    static let modelDownloadFailed = { (code: Int) in
-      "Model download failed with HTTP error code: \(code)"
-    }
-
+    static let sessionInvalidated = "Session invalidated due to failed pre-conditions."
+    static let invalidHTTPResponse = "Could not get valid HTTP response for model downloading."
+    static let anotherDownloadInProgress = "Download already in progress."
     static let modelNotFound = { (name: String) in
       "No model found with name: \(name)"
     }
 
-    static let invalidModelName = { (name: String) in
-      "Invalid model name: \(name)"
+    static let invalidArgument = { (name: String) in
+      "Invalid argument for model name: \(name)"
     }
 
-    static let sessionInvalidated = "Session invalidated due to failed pre-conditions."
-    static let invalidHTTPResponse =
-      "Could not get valid HTTP response for model downloading."
-    static let unknownDownloadError = "Unable to download model due to unknown error."
-    static let saveModel = "Unable to save downloaded remote model file."
-    static let notEnoughSpace = "Not enough space on device."
     static let expiredModelInfo = "Unable to update expired model info."
-    static let anotherDownloadInProgress = "Download already in progress."
     static let permissionDenied = "Invalid or missing permissions to download model."
+    static let notEnoughSpace = "Not enough space on device."
+    static let disableBackupError = "Unable to disable model file backup."
+    static let noModelsDirectory = "Could not create directory for model storage."
+    static let modelSaveError = "Unable to save downloaded remote model file."
+    static let unknownDownloadError = "Unable to download model due to unknown error."
+    static let modelDownloadFailed = { (code: Int) in
+      "Model download failed with HTTP error code: \(code)"
+    }
   }
 }

+ 191 - 129
FirebaseMLModelDownloader/Sources/ModelDownloader.swift

@@ -16,46 +16,13 @@ import Foundation
 import FirebaseCore
 import FirebaseInstallations
 
-/// Possible errors with model downloading.
-public enum DownloadError: Error, Equatable {
-  /// No model with this name found on server.
-  case notFound
-  /// Caller does not have necessary permissions for this operation.
-  case permissionDenied
-  /// Conditions not met to perform download.
-  case failedPrecondition
-  /// Not enough space for model on device.
-  case notEnoughSpace
-  /// Malformed model name.
-  case invalidArgument
-  /// Other errors with description.
-  case internalError(description: String)
-}
-
-/// Possible errors with locating model on device.
-public enum DownloadedModelError: Error {
-  /// File system error.
-  case fileIOError(description: String)
-  /// Model not found on device.
-  case notFound
-  /// Other errors with description.
-  case internalError(description: String)
-}
-
-/// Extension to handle internally meaningful errors.
-extension DownloadError {
-  static let expiredDownloadURL: DownloadError = {
-    DownloadError.internalError(description: "Expired model download URL.")
-  }()
-}
-
 /// Possible ways to get a custom model.
 public enum ModelDownloadType {
-  /// Get local model stored on device.
+  /// Get local model stored on device if available. If no local model on device, this is the same as `latestModel`.
   case localModel
-  /// Get local model on device and update to latest model from server in the background.
+  /// Get local model on device if available and update to latest model from server in the background. If no local model on device, this is the same as `latestModel`.
   case localModelUpdateInBackground
-  /// Get latest model from server.
+  /// Get latest model from server. Does not make a network call for model file download if local model matches the latest version on server.
   case latestModel
 }
 
@@ -63,41 +30,48 @@ public enum ModelDownloadType {
 public class ModelDownloader {
   /// Name of the app associated with this instance of ModelDownloader.
   private let appName: String
+
   /// Current Firebase app options.
   private let options: FirebaseOptions
+
   /// Installations instance for current Firebase app.
   private let installations: Installations
+
   /// User defaults for model info.
   private let userDefaults: UserDefaults
+
   /// Telemetry logger tied to this instance of model downloader.
   let telemetryLogger: TelemetryLogger?
+
   /// Number of retries in case of model download URL expiry.
   var numberOfRetries: Int = 1
-  /// Handler that always runs on the main thread
-  let mainQueueHandler = { handler in
-    DispatchQueue.main.async {
-      handler
-    }
-  }
+
+  /// Shared dictionary mapping app name to a specific instance of model downloader.
+  // TODO: Switch to using Firebase components.
+  private static var modelDownloaderDictionary: [String: ModelDownloader] = [:]
 
   /// Download task associated with the model currently being downloaded.
   private var currentDownloadTask: [String: ModelDownloadTask] = [:]
+
   /// DispatchQueue to manage download task dictionary.
   let taskSerialQueue = DispatchQueue(label: "downloadtask.serial.queue")
 
-  /// Shared dictionary mapping app name to a specific instance of model downloader.
-  // TODO: Switch to using Firebase components.
-  private static var modelDownloaderDictionary: [String: ModelDownloader] = [:]
+  /// Handler that always runs on the main thread
+  let mainQueueHandler = { handler in
+    DispatchQueue.main.async {
+      handler
+    }
+  }
 
-  /// Private init for downloader.
+  /// Private init for model downloader.
   private init(app: FirebaseApp, defaults: UserDefaults = .firebaseMLDefaults) {
     appName = app.name
     options = app.options
     installations = Installations.installations(app: app)
-    /// Respect Firebase-wide data collection setting.
-    telemetryLogger = TelemetryLogger(app: app)
     userDefaults = defaults
-
+    // Respect Firebase-wide data collection setting.
+    telemetryLogger = TelemetryLogger(app: app)
+    // Notification of app deletion.
     let notificationName = "FIRAppDeleteNotification"
     NotificationCenter.default.addObserver(
       self,
@@ -113,9 +87,8 @@ public class ModelDownloader {
     if let userInfo = notification.userInfo,
       let appName = userInfo[userInfoKey] as? String {
       ModelDownloader.modelDownloaderDictionary.removeValue(forKey: appName)
-      // TODO: Do we need to force deinit downloader instance?
-      // TODO: Clean up user defaults
-      // TODO: Clean up local instances of app
+      // TODO: Clean up user defaults.
+      // TODO: Clean up local instances of app.
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloader.DebugDescription.deleteModelDownloader,
                             messageCode: .downloaderInstanceDeleted)
@@ -147,7 +120,13 @@ public class ModelDownloader {
     }
   }
 
-  /// Downloads a custom model to device or gets a custom model already on device, w/ optional handler for progress.
+  /// Downloads a custom model to device or gets a custom model already on device, with an optional handler for progress.
+  /// - Parameters:
+  ///   - modelName: The name of the model, matching Firebase console.
+  ///   - downloadType: ModelDownloadType used to get the model.
+  ///   - conditions: Conditions needed to perform a model download.
+  ///   - progressHandler: Optional. Returns a float in [0.0, 1.0] that can be used to monitor model download progress.
+  ///   - completion: Returns either a `CustomModel` on success, or a `DownloadError` on failure, at the end of a model download.
   public func getModel(name modelName: String,
                        downloadType: ModelDownloadType,
                        conditions: ModelDownloadConditions,
@@ -168,6 +147,7 @@ public class ModelDownloader {
           completion: completion
         )
       }
+
     case .localModelUpdateInBackground:
       if let localModel = getLocalModel(modelName: modelName) {
         DeviceLogger.logEvent(level: .debug,
@@ -179,6 +159,7 @@ public class ModelDownloader {
           status: .scheduled,
           downloadErrorCode: .noError
         )
+        // Update local model in the background.
         DispatchQueue.global(qos: .utility).async { [weak self] in
           self?.getRemoteModel(
             modelName: modelName,
@@ -230,21 +211,24 @@ public class ModelDownloader {
     }
   }
 
-  /// Gets all downloaded models.
+  /// Gets the set of all downloaded models saved on device.
+  /// - Parameter completion: Returns either a set of `CustomModel` models on success, or a `DownloadedModelError` on failure.
   public func listDownloadedModels(completion: @escaping (Result<Set<CustomModel>,
     DownloadedModelError>) -> Void) {
     do {
-      let modelPaths = try ModelFileManager.contentsOfModelsDirectory()
+      let modelURLs = try ModelFileManager.contentsOfModelsDirectory()
       var customModels = Set<CustomModel>()
-      for path in modelPaths {
-        guard let modelName = ModelFileManager.getModelNameFromFilePath(path) else {
-          let description = ModelDownloader.ErrorDescription.parseModelName(path.absoluteString)
+      // Retrieve model name from URL.
+      for url in modelURLs {
+        guard let modelName = ModelFileManager.getModelNameFromFilePath(url) else {
+          let description = ModelDownloader.ErrorDescription.parseModelName(url.path)
           DeviceLogger.logEvent(level: .debug,
                                 message: description,
                                 messageCode: .modelNameParseError)
           mainQueueHandler(completion(.failure(.internalError(description: description))))
           return
         }
+        // Check if model information corresponding to model is stored in UserDefaults.
         guard let modelInfo = getLocalModelInfo(modelName: modelName) else {
           let description = ModelDownloader.ErrorDescription.noLocalModelInfo(modelName)
           DeviceLogger.logEvent(level: .debug,
@@ -253,8 +237,12 @@ public class ModelDownloader {
           mainQueueHandler(completion(.failure(.internalError(description: description))))
           return
         }
-        guard let pathURL = URL(string: modelInfo.path)?.standardizedFileURL,
-          pathURL == path.standardizedFileURL else {
+        // Ensure that local model path is as expected, and reachable.
+        guard let modelURL = ModelFileManager.getDownloadedModelFileURL(
+          appName: appName,
+          modelName: modelName
+        ),
+          url == modelURL, ModelFileManager.isFileReachable(at: modelURL) else {
           DeviceLogger.logEvent(level: .debug,
                                 message: ModelDownloader.ErrorDescription.outdatedModelPath,
                                 messageCode: .outdatedModelPathError)
@@ -262,7 +250,8 @@ public class ModelDownloader {
               .ErrorDescription.outdatedModelPath))))
           return
         }
-        let model = CustomModel(localModelInfo: modelInfo)
+        let model = CustomModel(localModelInfo: modelInfo, path: modelURL.path)
+        // Add model to result set.
         customModels.insert(model)
       }
       DeviceLogger.logEvent(level: .debug,
@@ -283,12 +272,20 @@ public class ModelDownloader {
     }
   }
 
-  /// Deletes a custom model from device.
+  /// Deletes a custom model file from device as well as corresponding model information saved in UserDefaults.
+  /// - Parameters:
+  ///   - modelName: The name of the model, matching Firebase console and already downloaded to device.
+  ///   - completion: Returns a `DownloadedModelError` on failure.
   public func deleteDownloadedModel(name modelName: String,
                                     completion: @escaping (Result<Void, DownloadedModelError>)
                                       -> Void) {
-    guard let localModelInfo = getLocalModelInfo(modelName: modelName),
-      let localPath = URL(string: localModelInfo.path)
+    // Ensure that there is a matching model file on device, with corresponding model information in UserDefaults.
+    guard let modelURL = ModelFileManager.getDownloadedModelFileURL(
+      appName: appName,
+      modelName: modelName
+    ),
+      let localModelInfo = getLocalModelInfo(modelName: modelName),
+      ModelFileManager.isFileReachable(at: modelURL)
     else {
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloader.ErrorDescription.modelNotFound(modelName),
@@ -297,14 +294,35 @@ public class ModelDownloader {
       return
     }
     do {
-      try ModelFileManager.removeFile(at: localPath)
+      // Remove model file from device.
+      try ModelFileManager.removeFile(at: modelURL)
+      // Clear out corresponding local model info.
+      localModelInfo.removeFromDefaults(userDefaults, appName: appName)
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloader.DebugDescription.modelDeleted,
                             messageCode: .modelDeleted)
+      telemetryLogger?.logModelDeletedEvent(
+        eventName: .remoteModelDeleteOnDevice,
+        isSuccessful: true
+      )
       mainQueueHandler(completion(.success(())))
     } catch let error as DownloadedModelError {
+      DeviceLogger.logEvent(level: .debug,
+                            message: ModelDownloader.ErrorDescription.modelDeletionFailed(error),
+                            messageCode: .modelDeletionFailed)
+      telemetryLogger?.logModelDeletedEvent(
+        eventName: .remoteModelDeleteOnDevice,
+        isSuccessful: false
+      )
       mainQueueHandler(completion(.failure(error)))
     } catch {
+      DeviceLogger.logEvent(level: .debug,
+                            message: ModelDownloader.ErrorDescription.modelDeletionFailed(error),
+                            messageCode: .modelDeletionFailed)
+      telemetryLogger?.logModelDeletedEvent(
+        eventName: .remoteModelDeleteOnDevice,
+        isSuccessful: false
+      )
       mainQueueHandler(completion(.failure(.internalError(description: error
           .localizedDescription))))
     }
@@ -312,7 +330,7 @@ public class ModelDownloader {
 }
 
 extension ModelDownloader {
-  /// Return local model info only if the model info is available and the corresponding model file is already on device.
+  /// Get model information for model saved on device, if available.
   private func getLocalModelInfo(modelName: String) -> LocalModelInfo? {
     guard let localModelInfo = LocalModelInfo(
       fromDefaults: userDefaults,
@@ -325,19 +343,27 @@ extension ModelDownloader {
                             messageCode: .noLocalModelInfo)
       return nil
     }
-    /// There is local model info on device, but no model file at the expected path.
-    guard let localPath = URL(string: localModelInfo.path),
-      ModelFileManager.isFileReachable(at: localPath) else {
-      // TODO: Consider deleting local model info in user defaults.
+    /// Local model info is only considered valid if there is a corresponding model file on device.
+    guard let modelURL = ModelFileManager.getDownloadedModelFileURL(
+      appName: appName,
+      modelName: modelName
+    ), ModelFileManager.isFileReachable(at: modelURL) else {
+      let description = ModelDownloader.DebugDescription.noLocalModelFile(modelName)
+      DeviceLogger.logEvent(level: .debug,
+                            message: description,
+                            messageCode: .noLocalModelFile)
       return nil
     }
     return localModelInfo
   }
 
-  /// Get model saved on device if available.
+  /// Get model saved on device, if available.
   private func getLocalModel(modelName: String) -> CustomModel? {
-    guard let localModelInfo = getLocalModelInfo(modelName: modelName) else { return nil }
-    let model = CustomModel(localModelInfo: localModelInfo)
+    guard let modelURL = ModelFileManager.getDownloadedModelFileURL(
+      appName: appName,
+      modelName: modelName
+    ), let localModelInfo = getLocalModelInfo(modelName: modelName) else { return nil }
+    let model = CustomModel(localModelInfo: localModelInfo, path: modelURL.path)
     return model
   }
 
@@ -359,8 +385,7 @@ extension ModelDownloader {
       modelName: modelName,
       projectID: projectID,
       apiKey: apiKey,
-      installations: installations,
-      appName: appName,
+      appName: appName, installations: installations,
       localModelInfo: localModelInfo,
       telemetryLogger: telemetryLogger
     )
@@ -375,6 +400,7 @@ extension ModelDownloader {
     )
   }
 
+  /// Get model info and model file from server.
   func downloadInfoAndModel(modelName: String,
                             modelInfoRetriever: ModelInfoRetriever,
                             downloader: FileDownloader,
@@ -386,13 +412,15 @@ extension ModelDownloader {
       switch result {
       case let .success(downloadModelInfoResult):
         switch downloadModelInfoResult {
-        /// New model info was downloaded from server.
+        // New model info was downloaded from server.
         case let .modelInfo(remoteModelInfo):
+          // Progress handler for model file download.
           let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
             if let progressHandler = progressHandler {
               self.mainQueueHandler(progressHandler(progress))
             }
           }
+          // Completion handler for model file download.
           let taskCompletion: ModelDownloadTask.Completion = { result in
             switch result {
             case let .success(model):
@@ -405,9 +433,9 @@ extension ModelDownloader {
                 self.mainQueueHandler(completion(.failure(.invalidArgument)))
               case .permissionDenied:
                 self.mainQueueHandler(completion(.failure(.permissionDenied)))
-              /// This is the error returned when URL expired.
+              // This is the error returned when model download URL has expired.
               case .expiredDownloadURL:
-                /// Check if retries are allowed.
+                // Retry model info and model file download, if allowed.
                 guard self.numberOfRetries > 0 else {
                   self
                     .mainQueueHandler(
@@ -433,15 +461,13 @@ extension ModelDownloader {
                 self.mainQueueHandler(completion(.failure(error)))
               }
             }
-
             self.taskSerialQueue.async {
-              /// Stop keeping track of current download task.
+              // Stop keeping track of current download task.
               self.currentDownloadTask.removeValue(forKey: modelName)
             }
           }
-
           self.taskSerialQueue.sync {
-            /// Merge duplicate requests if there is already a download in progress for the same model.
+            // Merge duplicate requests if there is already a download in progress for the same model.
             if let downloadTask = self.currentDownloadTask[modelName],
               downloadTask.canMergeRequests() {
               downloadTask.merge(
@@ -454,7 +480,9 @@ extension ModelDownloader {
               if downloadTask.canResume() {
                 downloadTask.resume()
               }
+              // TODO: Handle else.
             } else {
+              // Create download task for model file download.
               let downloadTask = ModelDownloadTask(
                 remoteModelInfo: remoteModelInfo,
                 appName: self.appName,
@@ -464,7 +492,7 @@ extension ModelDownloader {
                 completion: taskCompletion,
                 telemetryLogger: self.telemetryLogger
               )
-              /// Keep track of current download task to allow for merging duplicate requests.
+              // Keep track of current download task to allow for merging duplicate requests.
               self.currentDownloadTask[modelName] = downloadTask
               downloadTask.resume()
             }
@@ -472,16 +500,15 @@ extension ModelDownloader {
         /// Local model info is the latest model info.
         case .notModified:
           guard let localModel = self.getLocalModel(modelName: modelName) else {
-            /// This can only happen if either local model info or the model file was suddenly wiped out in the middle of model info request and server response
-            // TODO: Consider handling: If model file is deleted after local model info is retrieved but before model info network call.
+            // This can only happen if either local model info or the model file was wiped out after model info request but before server response.
             self
               .mainQueueHandler(completion(.failure(.internalError(description: ModelDownloader
-                  .ErrorDescription.deletedLocalModelInfo))))
+                  .ErrorDescription.deletedLocalModelInfoOrFile))))
             return
           }
-
           self.mainQueueHandler(completion(.success(localModel)))
         }
+      // Error retrieving model info.
       case let .failure(error):
         self.mainQueueHandler(completion(.failure(error)))
       }
@@ -489,79 +516,114 @@ extension ModelDownloader {
   }
 }
 
-/// Model downloader extension for testing.
-extension ModelDownloader {
-  /// Model downloader instance for testing.
-  // TODO: Consider using protocols
-  static func modelDownloaderWithDefaults(_ defaults: UserDefaults,
-                                          app: FirebaseApp) -> ModelDownloader {
-    if let downloader = modelDownloaderDictionary[app.name] {
-      return downloader
-    } else {
-      let downloader = ModelDownloader(app: app, defaults: defaults)
-      modelDownloaderDictionary[app.name] = downloader
-      return downloader
-    }
-  }
+/// Possible errors with model downloading.
+public enum DownloadError: Error, Equatable {
+  /// No model with this name exists on server.
+  case notFound
+  /// Invalid, incomplete, or missing permissions for model download.
+  case permissionDenied
+  /// Conditions not met to perform download.
+  case failedPrecondition
+  /// Requests quota exhausted.
+  case resourceExhausted
+  /// Not enough space for model on device.
+  case notEnoughSpace
+  /// Malformed model name or Firebase app options.
+  case invalidArgument
+  /// Other errors with description.
+  case internalError(description: String)
+}
+
+/// Possible errors with locating a model file on device.
+public enum DownloadedModelError: Error {
+  /// No model with this name exists on device.
+  case notFound
+  /// File system error.
+  case fileIOError(description: String)
+  /// Other errors with description.
+  case internalError(description: String)
+}
+
+/// Extension to handle internally meaningful errors.
+extension DownloadError {
+  /// Model download URL expired before model download.
+  // Model info retrieval and download is retried `numberOfRetries` times before failing.
+  static let expiredDownloadURL: DownloadError = {
+    DownloadError.internalError(description: "Expired model download URL.")
+  }()
 }
 
-/// Possible error messages while using model downloader.
+/// Possible debug and error messages while using model downloader.
 extension ModelDownloader {
   /// Debug descriptions.
   private enum DebugDescription {
-    static let deleteModelDownloader = "Model downloader instance deleted due to app deletion."
-    static let retrieveModelDownloader =
-      "Initialized with existing downloader instance associated with this app."
     static let createModelDownloader =
       "Initialized with new downloader instance associated with this app."
+    static let retrieveModelDownloader =
+      "Initialized with existing downloader instance associated with this app."
+    static let deleteModelDownloader = "Model downloader instance deleted due to app deletion."
     static let localModelFound = "Found local model on device."
-    static let backgroundModelDownloaded = "Downloaded latest model in the background."
     static let allLocalModelsFound = "Found and listed all local models."
-    static let modelDeleted = "Model deleted successfully."
-    static let retryDownload = "Retrying download."
     static let noLocalModelInfo = { (name: String) in
-      "No local model info for model file named: \(name)."
+      "No local model info for model named: \(name)."
+    }
+
+    static let noLocalModelFile = { (name: String) in
+      "No local model file for model named: \(name)."
     }
 
+    static let backgroundModelDownloaded = "Downloaded latest model in the background."
+    static let modelDeleted = "Model deleted successfully."
     static let mergingRequests = "Merging duplicate download requests."
+    static let retryDownload = "Retrying download."
   }
 
   /// Error descriptions.
   private enum ErrorDescription {
-    static let defaultAppNotConfigured =
-      "Default Firebase app not configured."
-    static let parseModelName = { (path: String) in
-      "List models failed due to unexpected model file name at \(path)."
-    }
-
+    static let defaultAppNotConfigured = "Default Firebase app not configured."
     static let invalidOptions = "Unable to retrieve project ID and/or API key for Firebase app."
-
-    static let listModelsFailed = { (error: Error) in
-      "Unable to list models, failed with error: \(error)"
+    static let modelDownloadFailed = { (error: Error) in
+      "Model download failed with error: \(error)"
     }
 
     static let modelNotFound = { (name: String) in
       "Model deletion failed due to no model found with name: \(name)"
     }
 
-    static let noLocalModelInfo = { (name: String) in
-      "List models failed due to no local model info for model file named: \(name)."
+    static let modelInfoRetrievalFailed = { (error: Error) in
+      "Model info retrieval failed with error: \(error)"
     }
 
-    static let modelDownloadFailed = { (error: Error) in
-      "Model download failed with error: \(error)"
+    static let backgroundModelDownload = "Failed to update model in background."
+    static let expiredModelInfo = "Unable to update expired model info."
+    static let listModelsFailed = { (error: Error) in
+      "Unable to list models, failed with error: \(error)"
     }
 
-    static let modelInfoRetrievalFailed = { (error: Error) in
-      "Model info retrieval failed with error: \(error)"
+    static let parseModelName = { (path: String) in
+      "List models failed due to unexpected model file name at \(path)."
     }
 
+    static let noLocalModelInfo = { (name: String) in
+      "List models failed due to no local model info for model file named: \(name)."
+    }
+
+    static let deletedLocalModelInfoOrFile =
+      "Model unavailable due to deleted local model info or model file."
     static let outdatedModelPath =
       "List models failed due to outdated model paths in local storage."
-    static let deletedLocalModelInfo =
-      "Model unavailable due to deleted local model info."
-    static let backgroundModelDownload =
-      "Failed to update model in background."
-    static let expiredModelInfo = "Unable to update expired model info."
+    static let modelDeletionFailed = { (error: Error) in
+      "Model deletion failed with error: \(error)"
+    }
+  }
+}
+
+/// Model downloader extension for testing.
+extension ModelDownloader {
+  /// Model downloader instance for testing.
+  static func modelDownloaderWithDefaults(_ defaults: UserDefaults,
+                                          app: FirebaseApp) -> ModelDownloader {
+    let downloader = ModelDownloader(app: app, defaults: defaults)
+    return downloader
   }
 }

+ 56 - 14
FirebaseMLModelDownloader/Sources/ModelFileManager.swift

@@ -15,40 +15,49 @@
 import Foundation
 
 /// Manager for common file operations.
-// TODO: Consider mocking this for tests?
 enum ModelFileManager {
-  private static let nameSeparator = "__"
+  /// Separator in model file name components.
+  private static let nameSeparator = "@@"
+
   private static let modelNamePrefix = "fbml_model"
+
+  private static let fileExtension = "tflite"
+
   private static let fileManager = FileManager.default
 
   /// Root directory of model file storage on device.
-  static var modelsDirectory: URL {
-    // TODO: Reconsider force unwrapping.
+  static var modelsDirectory: URL? {
     #if os(tvOS)
-      return fileManager.urls(for: .cachesDirectory, in: .userDomainMask).first!
+      return fileManager.urls(for: .cachesDirectory, in: .userDomainMask).first
     #else
-      return fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
+      return fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first ?? nil
     #endif
   }
 
   /// Name for model file stored on device.
-  private static func getDownloadedModelFileName(appName: String, modelName: String) -> String {
+  static func getDownloadedModelFileName(appName: String, modelName: String) -> String {
     return [modelNamePrefix, appName, modelName].joined(separator: nameSeparator)
   }
 
   /// Model name from file path.
   static func getModelNameFromFilePath(_ path: URL) -> String? {
-    return path.lastPathComponent.components(separatedBy: nameSeparator).last
+    let components = path.deletingPathExtension().lastPathComponent
+      .components(separatedBy: nameSeparator)
+    // The file name should have prefix, app name, and model name.
+    if components.count == 3 {
+      return components.last
+    }
+    return nil
   }
 
   /// Full path of model file stored on device.
-  static func getDownloadedModelFilePath(appName: String, modelName: String) -> URL {
+  static func getDownloadedModelFileURL(appName: String, modelName: String) -> URL? {
     let modelFileName = ModelFileManager.getDownloadedModelFileName(
       appName: appName,
       modelName: modelName
     )
-    return ModelFileManager.modelsDirectory
-      .appendingPathComponent(modelFileName)
+    guard let modelsDir = ModelFileManager.modelsDirectory else { return nil }
+    return modelsDir.appendingPathComponent(modelFileName).appendingPathExtension(fileExtension)
   }
 
   /// Check if file is available at URL.
@@ -56,7 +65,7 @@ enum ModelFileManager {
     do {
       return try fileURL.checkResourceIsReachable()
     } catch {
-      /// File unreachable.
+      // File unreachable.
       return false
     }
   }
@@ -98,15 +107,23 @@ enum ModelFileManager {
     }
   }
 
+  /// Get all model files in models directory.
   static func contentsOfModelsDirectory() throws -> [URL] {
+    guard let modelsDir = modelsDirectory else {
+      throw DownloadedModelError
+        .internalError(
+          description: ModelFileManager.ErrorDescription.noApplicationSupportDirectory
+        )
+    }
     do {
       let directoryContents = try ModelFileManager.fileManager.contentsOfDirectory(
-        at: modelsDirectory,
+        at: modelsDir,
         includingPropertiesForKeys: nil,
         options: .skipsHiddenFiles
       )
       return directoryContents.filter { directoryItem in
-        !directoryItem.hasDirectoryPath
+        directoryItem.path.contains(ModelFileManager.modelNamePrefix) && !directoryItem
+          .hasDirectoryPath
       }
     } catch {
       throw DownloadedModelError
@@ -118,6 +135,24 @@ enum ModelFileManager {
   }
 }
 
+/// NOTE: Use for testing only.
+extension ModelFileManager {
+  static func emptyModelsDirectory() throws {
+    do {
+      let directoryContents = try ModelFileManager.contentsOfModelsDirectory()
+      for path in directoryContents {
+        try ModelFileManager.removeFile(at: path)
+      }
+    } catch {
+      throw DownloadedModelError
+        .internalError(
+          description: ModelFileManager.ErrorDescription
+            .emptyDirectory(error.localizedDescription)
+        )
+    }
+  }
+}
+
 /// Possible error messages during file management.
 extension ModelFileManager {
   /// Error descriptions.
@@ -145,5 +180,12 @@ extension ModelFileManager {
         return "Failed to check storage capacity on device."
       }
     }
+
+    static let noApplicationSupportDirectory =
+      "Could not locate Application Support directory for model storage."
+
+    static let emptyDirectory = { (error: String) in
+      "Could not empty models directory: \(error)"
+    }
   }
 }

+ 146 - 109
FirebaseMLModelDownloader/Sources/ModelInfoRetriever.swift

@@ -33,15 +33,19 @@ extension URLSession: ModelInfoRetrieverSession {
   }
 }
 
-/// Model info response object.
-private struct ModelInfoResponse: Codable {
-  var downloadURL: String
-  var urlExpiryTime: String
-  var size: String
+/// Model info result type.
+/// Downloading model info will return new model info only if it different from local model info.
+enum DownloadModelInfoResult {
+  case modelInfo(RemoteModelInfo)
+  case notModified
 }
 
-/// Properties for server response keys.
-private extension ModelInfoResponse {
+/// Model info response.
+private struct ModelInfoResponse: Codable {
+  let downloadURL: String
+  let urlExpiryTime: String
+  let size: String
+  /// Properties for server response keys.
   enum CodingKeys: String, CodingKey {
     case downloadURL = "downloadUri"
     case urlExpiryTime = "expireTime"
@@ -49,91 +53,80 @@ private extension ModelInfoResponse {
   }
 }
 
-/// Downloading model info will return new model info only if it different from local model info.
-enum DownloadModelInfoResult {
-  case notModified
-  case modelInfo(RemoteModelInfo)
-}
-
-/// Model info retrieval error codes.
-enum ModelInfoErrorCode {
-  case noError
-  case noHash
-  case httpError(code: Int)
-  case connectionFailed
-  case hashMismatch
-}
-
-/// Model info retriever for a model from local user defaults or server.
+/// Fetch model info for a model from server.
 class ModelInfoRetriever {
   /// Model name.
   private let modelName: String
-  /// URL session for model info request.
-  private let session: ModelInfoRetrieverSession
+
   /// Current Firebase app project ID.
   private let projectID: String
+
   /// Current Firebase app API key.
   private let apiKey: String
+
   /// Current Firebase app name.
   private let appName: String
-  /// Local model info to validate model freshness.
-  private let localModelInfo: LocalModelInfo?
-  /// Telemetry logger.
-  private let telemetryLogger: TelemetryLogger?
+
   /// Auth token provider.
   typealias AuthTokenProvider = (_ completion: @escaping (Result<String, DownloadError>) -> Void)
     -> Void
   private let authTokenProvider: AuthTokenProvider
 
+  /// URL session for model info request.
+  private let session: ModelInfoRetrieverSession
+
+  /// Local model info to validate model freshness.
+  private let localModelInfo: LocalModelInfo?
+
+  /// Telemetry logger.
+  private let telemetryLogger: TelemetryLogger?
+
   /// Associate model info retriever with current Firebase app, and model name.
   init(modelName: String,
        projectID: String,
        apiKey: String,
-       authTokenProvider: @escaping AuthTokenProvider,
        appName: String,
-       localModelInfo: LocalModelInfo? = nil,
+       authTokenProvider: @escaping AuthTokenProvider,
        session: ModelInfoRetrieverSession? = nil,
+       localModelInfo: LocalModelInfo? = nil,
        telemetryLogger: TelemetryLogger? = nil) {
     self.modelName = modelName
     self.projectID = projectID
     self.apiKey = apiKey
     self.appName = appName
-    self.localModelInfo = localModelInfo
     self.authTokenProvider = authTokenProvider
+    self.session = session ?? URLSession(configuration: .ephemeral)
+    self.localModelInfo = localModelInfo
     self.telemetryLogger = telemetryLogger
-
-    if let urlSession = session {
-      self.session = urlSession
-    } else {
-      self.session = URLSession(configuration: .ephemeral)
-    }
   }
 
+  /// Convenience init to use FirebaseInstallations as auth token provider.
   convenience init(modelName: String,
                    projectID: String,
                    apiKey: String,
-                   installations: Installations,
                    appName: String,
-                   localModelInfo: LocalModelInfo? = nil,
+                   installations: Installations,
                    session: ModelInfoRetrieverSession? = nil,
+                   localModelInfo: LocalModelInfo? = nil,
                    telemetryLogger: TelemetryLogger? = nil) {
     self.init(modelName: modelName,
               projectID: projectID,
               apiKey: apiKey,
-              authTokenProvider: ModelInfoRetriever.authTokenProvider(installation: installations),
               appName: appName,
-              localModelInfo: localModelInfo,
+              authTokenProvider: ModelInfoRetriever.authTokenProvider(installation: installations),
               session: session,
+              localModelInfo: localModelInfo,
               telemetryLogger: telemetryLogger)
   }
 
+  /// Auth token provider to validate credentials.
   private static func authTokenProvider(installation: Installations) -> AuthTokenProvider {
     return { completion in
       installation.authToken { tokenResult, error in
         guard let result = tokenResult
         else {
           completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
-              .authToken)))
+              .authTokenError)))
           return
         }
         completion(.success(result.authToken))
@@ -146,26 +139,26 @@ class ModelInfoRetriever {
     -> Void) {
     authTokenProvider { result in
       switch result {
-      /// Successfully received FIS token.
+      // Successfully received FIS token.
       case let .success(authToken):
         DeviceLogger.logEvent(level: .debug,
                               message: ModelInfoRetriever.DebugDescription
                                 .receivedAuthToken,
                               messageCode: .validAuthToken)
-        /// Get model info fetch URL with appropriate HTTP headers.
+        // Get model info fetch URL with appropriate HTTP headers.
         guard let request = self.getModelInfoFetchURLRequest(token: authToken) else {
           DeviceLogger.logEvent(level: .debug,
                                 message: ModelInfoRetriever.ErrorDescription
                                   .invalidModelInfoFetchURL,
                                 messageCode: .invalidModelInfoFetchURL)
           self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                           status: .failed,
-                                                           errorCode: .connectionFailed)
+                                                           status: .modelInfoRetrievalFailed,
+                                                           modelInfoErrorCode: .connectionFailed)
           completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
               .invalidModelInfoFetchURL)))
           return
         }
-        /// Download model info.
+        // Download model info.
         self.session.getModelInfo(with: request) {
           data, response, error in
           if let downloadError = error {
@@ -175,8 +168,8 @@ class ModelInfoRetriever {
                                   message: description,
                                   messageCode: .modelInfoRetrievalError)
             self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                             status: .failed,
-                                                             errorCode: .connectionFailed)
+                                                             status: .modelInfoRetrievalFailed,
+                                                             modelInfoErrorCode: .connectionFailed)
             completion(.failure(.internalError(description: description)))
           } else {
             guard let httpResponse = response as? HTTPURLResponse else {
@@ -185,8 +178,8 @@ class ModelInfoRetriever {
                                       .invalidHTTPResponse,
                                     messageCode: .invalidHTTPResponse)
               self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                               status: .failed,
-                                                               errorCode: .connectionFailed)
+                                                               status: .modelInfoRetrievalFailed,
+                                                               modelInfoErrorCode: .connectionFailed)
               completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
                   .invalidHTTPResponse)))
               return
@@ -203,8 +196,8 @@ class ModelInfoRetriever {
                                       message: ModelInfoRetriever.ErrorDescription.missingModelHash,
                                       messageCode: .missingModelHash)
                 self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                                 status: .failed,
-                                                                 errorCode: .noHash)
+                                                                 status: .modelInfoRetrievalFailed,
+                                                                 modelInfoErrorCode: .noHash)
                 completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
                     .missingModelHash)))
                 return
@@ -219,18 +212,19 @@ class ModelInfoRetriever {
                 return
               }
               do {
+                // Parse model info from HTTP response.
                 let modelInfo = try self.getRemoteModelInfoFromResponse(data, modelHash: modelHash)
                 DeviceLogger.logEvent(level: .debug,
                                       message: ModelInfoRetriever.DebugDescription
                                         .modelInfoDownloaded,
                                       messageCode: .modelInfoDownloaded)
                 self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                                 status: .updateAvailable,
-                                                                 errorCode: .noError)
+                                                                 status: .modelInfoRetrievalSucceeded,
+                                                                 modelInfoErrorCode: .noError)
                 completion(.success(.modelInfo(modelInfo)))
               } catch {
                 let description = ModelInfoRetriever.ErrorDescription
-                  .invalidmodelInfoJSON(error.localizedDescription)
+                  .invalidModelInfoJSON(error.localizedDescription)
                 DeviceLogger.logEvent(level: .debug,
                                       message: description,
                                       messageCode: .invalidModelInfoJSON)
@@ -239,8 +233,7 @@ class ModelInfoRetriever {
                 )
               }
             case 304:
-              /// For this case to occur, local model info has to already be available on device.
-              // TODO: Is this needed? Currently handles the case if model info disappears between request and response
+              // For this case to occur, local model info has to already be available on device.
               guard let localInfo = self.localModelInfo else {
                 DeviceLogger.logEvent(level: .debug,
                                       message: ModelInfoRetriever.ErrorDescription
@@ -259,20 +252,21 @@ class ModelInfoRetriever {
                                         .missingModelHash,
                                       messageCode: .noModelHash)
                 self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                                 status: .failed,
-                                                                 errorCode: .noHash)
+                                                                 status: .modelInfoRetrievalFailed,
+                                                                 modelInfoErrorCode: .noHash)
                 completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
                     .missingModelHash)))
                 return
               }
+              // Ensure that there is local model info on device with matching hash.
               guard modelHash == localInfo.modelHash else {
                 DeviceLogger.logEvent(level: .debug,
                                       message: ModelInfoRetriever.ErrorDescription
                                         .modelHashMismatch,
                                       messageCode: .modelHashMismatchError)
                 self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                                 status: .failed,
-                                                                 errorCode: .hashMismatch)
+                                                                 status: .modelInfoRetrievalFailed,
+                                                                 modelInfoErrorCode: .hashMismatch)
                 completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
                     .modelHashMismatch)))
                 return
@@ -281,59 +275,81 @@ class ModelInfoRetriever {
                                     message: ModelInfoRetriever.DebugDescription
                                       .modelInfoUnmodified,
                                     messageCode: .modelInfoUnmodified)
+              self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
+                                                               status: .modelInfoRetrievalSucceeded,
+                                                               modelInfoErrorCode: .noError)
               completion(.success(.notModified))
             case 400:
-              let description = ModelInfoRetriever.ErrorDescription.invalidModelName(self.modelName)
+              let errorMessage = self.getErrorFromResponse(data)
+              let description = errorMessage ?? ModelInfoRetriever.ErrorDescription
+                .invalidArgument(self.modelName)
               DeviceLogger.logEvent(level: .debug,
                                     message: description,
-                                    messageCode: .invalidModelName)
+                                    messageCode: .invalidArgument)
               self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                               status: .failed,
-                                                               errorCode: .httpError(code: httpResponse
+                                                               status: .modelInfoRetrievalFailed,
+                                                               modelInfoErrorCode: .httpError(code: httpResponse
                                                                  .statusCode))
               completion(.failure(.invalidArgument))
             case 401, 403:
+              // Error could be due to FirebaseML API not enabled for project, or invalid permissions.
+              let errorMessage = self.getErrorFromResponse(data)
+              let description = errorMessage ?? ModelInfoRetriever.ErrorDescription.permissionDenied
               DeviceLogger.logEvent(level: .debug,
-                                    message: ModelInfoRetriever.ErrorDescription.permissionDenied,
+                                    message: description,
                                     messageCode: .permissionDenied)
               self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                               status: .failed,
-                                                               errorCode: .httpError(code: httpResponse
+                                                               status: .modelInfoRetrievalFailed,
+                                                               modelInfoErrorCode: .httpError(code: httpResponse
                                                                  .statusCode))
               completion(.failure(.permissionDenied))
             case 404:
-              let description = ModelInfoRetriever.ErrorDescription.modelNotFound(self.modelName)
+              let errorMessage = self.getErrorFromResponse(data)
+              let description = errorMessage ?? ModelInfoRetriever.ErrorDescription
+                .modelNotFound(self.modelName)
               DeviceLogger.logEvent(level: .debug,
                                     message: description,
                                     messageCode: .modelNotFound)
               self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                               status: .failed,
-                                                               errorCode: .httpError(code: httpResponse
+                                                               status: .modelInfoRetrievalFailed,
+                                                               modelInfoErrorCode: .httpError(code: httpResponse
                                                                  .statusCode))
               completion(.failure(.notFound))
-            // TODO: Handle more http status codes
+            case 429:
+              let errorMessage = self.getErrorFromResponse(data)
+              let description = errorMessage ?? ModelInfoRetriever.ErrorDescription
+                .resourceExhausted
+              DeviceLogger.logEvent(level: .debug,
+                                    message: description,
+                                    messageCode: .resourceExhausted)
+              self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
+                                                               status: .modelInfoRetrievalFailed,
+                                                               modelInfoErrorCode: .httpError(code: httpResponse
+                                                                 .statusCode))
+              completion(.failure(.resourceExhausted))
             default:
-              let description = ModelInfoRetriever.ErrorDescription
+              let errorMessage = self.getErrorFromResponse(data)
+              let description = errorMessage ?? ModelInfoRetriever.ErrorDescription
                 .modelInfoRetrievalFailed(httpResponse.statusCode)
               DeviceLogger.logEvent(level: .debug,
                                     message: description,
                                     messageCode: .modelInfoRetrievalError)
               self.telemetryLogger?.logModelInfoRetrievalEvent(eventName: .modelDownload,
-                                                               status: .failed,
-                                                               errorCode: .httpError(code: httpResponse
+                                                               status: .modelInfoRetrievalFailed,
+                                                               modelInfoErrorCode: .httpError(code: httpResponse
                                                                  .statusCode))
               completion(.failure(.internalError(description: description)))
             }
           }
         }
-      /// FIS token error.
+      // Error retrieving auth token.
       case .failure:
         DeviceLogger.logEvent(level: .debug,
                               message: ModelInfoRetriever.ErrorDescription
-                                .authToken,
+                                .authTokenError,
                               messageCode: .authTokenError)
         completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
-            .authToken)))
+            .authTokenError)))
         return
       }
     }
@@ -344,7 +360,9 @@ class ModelInfoRetriever {
 extension ModelInfoRetriever {
   /// HTTP request headers.
   private static let fisTokenHTTPHeader = "x-goog-firebase-installations-auth"
+
   private static let hashMatchHTTPHeader = "if-none-match"
+
   private static let bundleIDHTTPHeader = "x-ios-bundle-identifier"
 
   /// HTTP response headers.
@@ -365,11 +383,10 @@ extension ModelInfoRetriever {
     guard let fetchURL = modelInfoFetchURL else { return nil }
     var request = URLRequest(url: fetchURL)
     request.httpMethod = "GET"
-    // TODO: Check if bundle ID needs to be part of the request header.
     let bundleID = Bundle.main.bundleIdentifier ?? ""
     request.setValue(bundleID, forHTTPHeaderField: ModelInfoRetriever.bundleIDHTTPHeader)
     request.setValue(token, forHTTPHeaderField: ModelInfoRetriever.fisTokenHTTPHeader)
-    /// Get model hash if local model info is available on device.
+    // Get model hash if local model info is available on device.
     if let modelInfo = localModelInfo {
       request.setValue(
         modelInfo.modelHash,
@@ -379,6 +396,18 @@ extension ModelInfoRetriever {
     return request
   }
 
+  /// Parse error message from server response.
+  private func getErrorFromResponse(_ data: Data?) -> String? {
+    if let data = data,
+      let responseJSON = try? JSONSerialization
+      .jsonObject(with: data, options: []) as? [String: Any],
+      let error = responseJSON["error"] as? [String: Any],
+      let errorMessage = error["message"] as? String {
+      return errorMessage
+    }
+    return nil
+  }
+
   /// Parse date from string - used to get download URL expiry time.
   private static func getDateFromString(_ strDate: String) -> Date? {
     if #available(iOS 11, macOS 10.13, macCatalyst 13.0, tvOS 11.0, watchOS 4.0, *) {
@@ -398,19 +427,20 @@ extension ModelInfoRetriever {
   /// Return model info created from server response.
   private func getRemoteModelInfoFromResponse(_ data: Data,
                                               modelHash: String) throws -> RemoteModelInfo {
-    let decoder = JSONDecoder()
-    guard let modelInfoJSON = try? decoder.decode(ModelInfoResponse.self, from: data) else {
+    guard let modelInfoJSON = try? JSONDecoder().decode(ModelInfoResponse.self, from: data) else {
       throw DownloadError
         .internalError(description: ModelInfoRetriever.ErrorDescription.decodeModelInfoResponse)
     }
-    // TODO: Possibly improve handling invalid server responses.
     guard let downloadURL = URL(string: modelInfoJSON.downloadURL) else {
       throw DownloadError
         .internalError(description: ModelInfoRetriever.ErrorDescription
           .invalidModelDownloadURL)
     }
-    let modelHash = modelHash
-    let size = Int(modelInfoJSON.size) ?? 0
+    guard let size = Int(modelInfoJSON.size) else {
+      throw DownloadError
+        .internalError(description: ModelInfoRetriever.ErrorDescription
+          .invalidModelSize)
+    }
     guard let expiryTime = ModelInfoRetriever.getDateFromString(modelInfoJSON.urlExpiryTime) else {
       throw DownloadError
         .internalError(description: ModelInfoRetriever.ErrorDescription
@@ -426,46 +456,48 @@ extension ModelInfoRetriever {
   }
 }
 
+/// Model info retrieval error codes.
+enum ModelInfoErrorCode {
+  case noError
+  case noHash
+  case httpError(code: Int)
+  case connectionFailed
+  case hashMismatch
+}
+
 /// Possible error messages for model info retrieval.
 extension ModelInfoRetriever {
   /// Debug descriptions.
   private enum DebugDescription {
-    static let receivedServerResponse = "Received a valid response from model info server."
     static let receivedAuthToken = "Generated valid auth token."
+    static let receivedServerResponse = "Received a valid response from model info server."
     static let modelInfoDownloaded = "Successfully downloaded model info."
     static let modelInfoUnmodified = "Local model info matches the latest on server."
   }
 
   /// Error descriptions.
   private enum ErrorDescription {
-    static let authToken = "Error retrieving auth token."
-    static let selfDeallocated = "Self deallocated."
-    static let missingModelHash = "Model hash missing in model info server response."
+    static let authTokenError = "Error retrieving auth token."
     static let invalidModelInfoFetchURL = "Unable to create URL to fetch model info."
     static let invalidHTTPResponse =
       "Could not get a valid HTTP response for model info retrieval."
-    static let invalidmodelInfoJSON = { (error: String) in
-      "Failed to parse model info: \(error)"
-    }
-
-    static let failedModelInfoRetrieval = { (error: String) in
-      "Failed to retrieve model info: \(error)"
-    }
-
-    static let unexpectedModelInfoDeletion =
-      "Model info was deleted unexpectedly."
-    static let serverResponse = { (errorCode: Int) in
+    static let serverResponseError = { (errorCode: Int) in
       "Server returned with HTTP error code: \(errorCode)."
     }
 
-    static let invalidModelName = { (name: String) in
-      "Invalid model name: \(name)"
-    }
-
+    static let missingModelHash = "Model hash missing in model info server response."
+    static let modelHashMismatch = "Unexpected model hash value."
+    static let unexpectedModelInfoDeletion = "Model info was deleted unexpectedly."
     static let modelNotFound = { (name: String) in
       "No model found with name: \(name)"
     }
 
+    static let invalidArgument = { (name: String) in
+      "Invalid argument for model name: \(name)"
+    }
+
+    static let permissionDenied = "Invalid or missing permissions to retrieve model info."
+    static let resourceExhausted = "Resource exhausted due to too many requests."
     static let modelInfoRetrievalFailed = { (code: Int) in
       "Model info retrieval failed with HTTP error code: \(code)"
     }
@@ -474,10 +506,15 @@ extension ModelInfoRetriever {
       "Unable to decode model info response from server."
     static let invalidModelDownloadURL =
       "Invalid model download URL from server."
+    static let invalidModelSize = "Invalid model size from server."
     static let invalidModelDownloadURLExpiryTime =
       "Invalid download URL expiry time from server."
-    static let modelHashMismatch = "Unexpected model hash value."
-    static let permissionDenied = "Invalid or missing permissions to retrieve model info."
-    static let anotherDownloadInProgress = "Model info download already in progress."
+    static let invalidModelInfoJSON = { (error: String) in
+      "Failed to parse model info: \(error)"
+    }
+
+    static let failedModelInfoRetrieval = { (error: String) in
+      "Failed to retrieve model info: \(error)"
+    }
   }
 }

+ 0 - 1
FirebaseMLModelDownloader/Sources/RemoteModelInfo.swift

@@ -13,7 +13,6 @@
 // limitations under the License.
 
 import Foundation
-import FirebaseCore
 
 /// Model info object with details about not-yet downloaded model.
 struct RemoteModelInfo {

+ 57 - 4
FirebaseMLModelDownloader/Sources/TelemetryLogger.swift

@@ -40,6 +40,14 @@ extension ModelOptions {
   }
 }
 
+/// Extension to build model delete log event.
+extension DeleteModelLogEvent {
+  mutating func setEvent(modelType: ModelInfo.ModelType = .custom, isSuccessful: Bool) {
+    self.modelType = modelType
+    self.isSuccessful = isSuccessful
+  }
+}
+
 /// Extension to build model download log event.
 extension ModelDownloadLogEvent {
   mutating func setEvent(status: DownloadStatus, errorCode: ErrorCode,
@@ -60,7 +68,7 @@ extension ModelDownloadLogEvent {
   }
 }
 
-/// Extension to build Firebase ML log event.
+/// Extension to build log event.
 extension FirebaseMlLogEvent {
   mutating func setEvent(eventName: EventName, systemInfo: SystemInfo,
                          modelDownloadLogEvent: ModelDownloadLogEvent) {
@@ -68,6 +76,13 @@ extension FirebaseMlLogEvent {
     self.systemInfo = systemInfo
     self.modelDownloadLogEvent = modelDownloadLogEvent
   }
+
+  mutating func setEvent(eventName: EventName, systemInfo: SystemInfo,
+                         deleteModelLogEvent: DeleteModelLogEvent) {
+    self.eventName = eventName
+    self.systemInfo = systemInfo
+    self.deleteModelLogEvent = deleteModelLogEvent
+  }
 }
 
 /// Data object for Firelog event.
@@ -81,7 +96,6 @@ class FBMLDataObject: NSObject, GDTCOREventDataObject {
   /// Encode Firelog event for transport.
   func transportBytes() -> Data {
     do {
-      // TODO: Should this be binary or json serialized?
       let data = try event.serializedData()
       return data
     } catch {
@@ -97,8 +111,10 @@ class FBMLDataObject: NSObject, GDTCOREventDataObject {
 class TelemetryLogger {
   /// Mapping ID for the log source.
   private let mappingID = "1326"
+
   /// Current Firebase app.
   private let app: FirebaseApp
+
   /// Transport for Firelog events.
   private let fllTransport: GDTCORTransport
 
@@ -125,16 +141,53 @@ class TelemetryLogger {
     fllTransport.sendTelemetryEvent(eventForTransport)
   }
 
+  /// Log model deleted event to Firelog.
+  func logModelDeletedEvent(eventName: EventName, isSuccessful: Bool) {
+    guard app.isDataCollectionDefaultEnabled else { return }
+    var systemInfo = SystemInfo()
+    let apiKey = app.options.apiKey
+    let projectID = app.options.projectID
+    systemInfo.setAppInfo(apiKey: apiKey, projectID: projectID)
+    var deleteModelLogEvent = DeleteModelLogEvent()
+    deleteModelLogEvent.setEvent(isSuccessful: isSuccessful)
+    var fbmlEvent = FirebaseMlLogEvent()
+    fbmlEvent.setEvent(
+      eventName: eventName,
+      systemInfo: systemInfo,
+      deleteModelLogEvent: deleteModelLogEvent
+    )
+    logModelEvent(event: fbmlEvent)
+  }
+
   /// Log model info retrieval event to Firelog.
   func logModelInfoRetrievalEvent(eventName: EventName,
                                   status: ModelDownloadLogEvent.DownloadStatus,
-                                  model: CustomModel? = nil, errorCode: ModelInfoErrorCode) {
+                                  modelInfoErrorCode: ModelInfoErrorCode) {
     guard app.isDataCollectionDefaultEnabled else { return }
     var systemInfo = SystemInfo()
     let apiKey = app.options.apiKey
     let projectID = app.options.projectID
     systemInfo.setAppInfo(apiKey: apiKey, projectID: projectID)
-    let modelDownloadLogEvent = ModelDownloadLogEvent()
+    var errorCode = ErrorCode()
+    switch modelInfoErrorCode {
+    case .noError:
+      errorCode = .noError
+    case .noHash:
+      errorCode = .modelInfoDownloadNoHash
+    case .connectionFailed:
+      errorCode = .modelInfoDownloadConnectionFailed
+    case .hashMismatch:
+      errorCode = .modelHashMismatch
+    case let .httpError(code):
+      errorCode = ErrorCode(rawValue: code) ?? .modelInfoDownloadUnsuccessfulHTTPStatus
+    }
+    let modelOptions = ModelOptions()
+    var modelDownloadLogEvent = ModelDownloadLogEvent()
+    modelDownloadLogEvent.setEvent(
+      status: status,
+      errorCode: errorCode,
+      modelOptions: modelOptions
+    )
     var fbmlEvent = FirebaseMlLogEvent()
     fbmlEvent.setEvent(
       eventName: eventName,

+ 145 - 30
FirebaseMLModelDownloader/Sources/proto/firebase_ml_log_sdk.pb.swift

@@ -39,6 +39,7 @@ enum EventName: SwiftProtobuf.Enum {
   case unknownEvent // = 0
   case modelDownload // = 100
   case modelUpdate // = 101
+  case remoteModelDeleteOnDevice // = 252
   case UNRECOGNIZED(Int)
 
   init() {
@@ -50,6 +51,7 @@ enum EventName: SwiftProtobuf.Enum {
     case 0: self = .unknownEvent
     case 100: self = .modelDownload
     case 101: self = .modelUpdate
+    case 252: self = .remoteModelDeleteOnDevice
     default: self = .UNRECOGNIZED(rawValue)
     }
   }
@@ -59,6 +61,7 @@ enum EventName: SwiftProtobuf.Enum {
     case .unknownEvent: return 0
     case .modelDownload: return 100
     case .modelUpdate: return 101
+    case .remoteModelDeleteOnDevice: return 252
     case .UNRECOGNIZED(let i): return i
     }
   }
@@ -73,6 +76,7 @@ extension EventName: CaseIterable {
     .unknownEvent,
     .modelDownload,
     .modelUpdate,
+    .remoteModelDeleteOnDevice,
   ]
 }
 
@@ -445,6 +449,23 @@ extension ModelDownloadLogEvent.DownloadStatus: CaseIterable {
 
 #endif  // swift(>=4.2)
 
+/// Information about deleting a downloaded model on device.
+struct DeleteModelLogEvent {
+  // SwiftProtobuf.Message conformance is added in an extension below. See the
+  // `Message` and `Message+*Additions` files in the SwiftProtobuf library for
+  // methods supported on all messages.
+
+  /// The type of the downloaded model requested to be deleted.
+  var modelType: ModelInfo.ModelType = .typeUnknown
+
+  /// Whether the downloaded model is deleted successfully.
+  var isSuccessful: Bool = false
+
+  var unknownFields = SwiftProtobuf.UnknownStorage()
+
+  init() {}
+}
+
 /// Main log event for FirebaseMl, that contains individual API events, like model
 /// download.
 /// NEXT ID: 44.
@@ -455,34 +476,45 @@ struct FirebaseMlLogEvent {
 
   /// Information about various parts of the system: app, Firebase, SDK.
   var systemInfo: SystemInfo {
-    get {return _systemInfo ?? SystemInfo()}
-    set {_systemInfo = newValue}
+    get {return _storage._systemInfo ?? SystemInfo()}
+    set {_uniqueStorage()._systemInfo = newValue}
   }
   /// Returns true if `systemInfo` has been explicitly set.
-  var hasSystemInfo: Bool {return self._systemInfo != nil}
+  var hasSystemInfo: Bool {return _storage._systemInfo != nil}
   /// Clears the value of `systemInfo`. Subsequent reads from it will return its default value.
-  mutating func clearSystemInfo() {self._systemInfo = nil}
+  mutating func clearSystemInfo() {_uniqueStorage()._systemInfo = nil}
 
   /// The event name.
-  var eventName: EventName = .unknownEvent
+  var eventName: EventName {
+    get {return _storage._eventName}
+    set {_uniqueStorage()._eventName = newValue}
+  }
 
-  /// Model downloading logs.
-  /// ==========================
+  /// Information about model download.
   var modelDownloadLogEvent: ModelDownloadLogEvent {
-    get {return _modelDownloadLogEvent ?? ModelDownloadLogEvent()}
-    set {_modelDownloadLogEvent = newValue}
+    get {return _storage._modelDownloadLogEvent ?? ModelDownloadLogEvent()}
+    set {_uniqueStorage()._modelDownloadLogEvent = newValue}
   }
   /// Returns true if `modelDownloadLogEvent` has been explicitly set.
-  var hasModelDownloadLogEvent: Bool {return self._modelDownloadLogEvent != nil}
+  var hasModelDownloadLogEvent: Bool {return _storage._modelDownloadLogEvent != nil}
   /// Clears the value of `modelDownloadLogEvent`. Subsequent reads from it will return its default value.
-  mutating func clearModelDownloadLogEvent() {self._modelDownloadLogEvent = nil}
+  mutating func clearModelDownloadLogEvent() {_uniqueStorage()._modelDownloadLogEvent = nil}
+
+  /// Information about deleting a downloaded model.
+  var deleteModelLogEvent: DeleteModelLogEvent {
+    get {return _storage._deleteModelLogEvent ?? DeleteModelLogEvent()}
+    set {_uniqueStorage()._deleteModelLogEvent = newValue}
+  }
+  /// Returns true if `deleteModelLogEvent` has been explicitly set.
+  var hasDeleteModelLogEvent: Bool {return _storage._deleteModelLogEvent != nil}
+  /// Clears the value of `deleteModelLogEvent`. Subsequent reads from it will return its default value.
+  mutating func clearDeleteModelLogEvent() {_uniqueStorage()._deleteModelLogEvent = nil}
 
   var unknownFields = SwiftProtobuf.UnknownStorage()
 
   init() {}
 
-  fileprivate var _systemInfo: SystemInfo? = nil
-  fileprivate var _modelDownloadLogEvent: ModelDownloadLogEvent? = nil
+  fileprivate var _storage = _StorageClass.defaultInstance
 }
 
 // MARK: - Code below here is support for the SwiftProtobuf runtime.
@@ -492,6 +524,7 @@ extension EventName: SwiftProtobuf._ProtoNameProviding {
     0: .same(proto: "UNKNOWN_EVENT"),
     100: .same(proto: "MODEL_DOWNLOAD"),
     101: .same(proto: "MODEL_UPDATE"),
+    252: .same(proto: "REMOTE_MODEL_DELETE_ON_DEVICE"),
   ]
 }
 
@@ -737,12 +770,11 @@ extension ModelDownloadLogEvent.DownloadStatus: SwiftProtobuf._ProtoNameProvidin
   ]
 }
 
-extension FirebaseMlLogEvent: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
-  static let protoMessageName: String = "FirebaseMlLogEvent"
+extension DeleteModelLogEvent: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
+  static let protoMessageName: String = "DeleteModelLogEvent"
   static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
-    1: .standard(proto: "system_info"),
-    2: .standard(proto: "event_name"),
-    3: .standard(proto: "model_download_log_event"),
+    1: .standard(proto: "model_type"),
+    2: .standard(proto: "is_successful"),
   ]
 
   mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@@ -751,31 +783,114 @@ extension FirebaseMlLogEvent: SwiftProtobuf.Message, SwiftProtobuf._MessageImple
       // allocates stack space for every case branch when no optimizations are
       // enabled. https://github.com/apple/swift-protobuf/issues/1034
       switch fieldNumber {
-      case 1: try { try decoder.decodeSingularMessageField(value: &self._systemInfo) }()
-      case 2: try { try decoder.decodeSingularEnumField(value: &self.eventName) }()
-      case 3: try { try decoder.decodeSingularMessageField(value: &self._modelDownloadLogEvent) }()
+      case 1: try { try decoder.decodeSingularEnumField(value: &self.modelType) }()
+      case 2: try { try decoder.decodeSingularBoolField(value: &self.isSuccessful) }()
       default: break
       }
     }
   }
 
   func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
-    if let v = self._systemInfo {
-      try visitor.visitSingularMessageField(value: v, fieldNumber: 1)
+    if self.modelType != .typeUnknown {
+      try visitor.visitSingularEnumField(value: self.modelType, fieldNumber: 1)
+    }
+    if self.isSuccessful != false {
+      try visitor.visitSingularBoolField(value: self.isSuccessful, fieldNumber: 2)
+    }
+    try unknownFields.traverse(visitor: &visitor)
+  }
+
+  static func ==(lhs: DeleteModelLogEvent, rhs: DeleteModelLogEvent) -> Bool {
+    if lhs.modelType != rhs.modelType {return false}
+    if lhs.isSuccessful != rhs.isSuccessful {return false}
+    if lhs.unknownFields != rhs.unknownFields {return false}
+    return true
+  }
+}
+
+extension FirebaseMlLogEvent: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
+  static let protoMessageName: String = "FirebaseMlLogEvent"
+  static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
+    1: .standard(proto: "system_info"),
+    2: .standard(proto: "event_name"),
+    3: .standard(proto: "model_download_log_event"),
+    40: .standard(proto: "delete_model_log_event"),
+  ]
+
+  fileprivate class _StorageClass {
+    var _systemInfo: SystemInfo? = nil
+    var _eventName: EventName = .unknownEvent
+    var _modelDownloadLogEvent: ModelDownloadLogEvent? = nil
+    var _deleteModelLogEvent: DeleteModelLogEvent? = nil
+
+    static let defaultInstance = _StorageClass()
+
+    private init() {}
+
+    init(copying source: _StorageClass) {
+      _systemInfo = source._systemInfo
+      _eventName = source._eventName
+      _modelDownloadLogEvent = source._modelDownloadLogEvent
+      _deleteModelLogEvent = source._deleteModelLogEvent
+    }
+  }
+
+  fileprivate mutating func _uniqueStorage() -> _StorageClass {
+    if !isKnownUniquelyReferenced(&_storage) {
+      _storage = _StorageClass(copying: _storage)
     }
-    if self.eventName != .unknownEvent {
-      try visitor.visitSingularEnumField(value: self.eventName, fieldNumber: 2)
+    return _storage
+  }
+
+  mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
+    _ = _uniqueStorage()
+    try withExtendedLifetime(_storage) { (_storage: _StorageClass) in
+      while let fieldNumber = try decoder.nextFieldNumber() {
+        // The use of inline closures is to circumvent an issue where the compiler
+        // allocates stack space for every case branch when no optimizations are
+        // enabled. https://github.com/apple/swift-protobuf/issues/1034
+        switch fieldNumber {
+        case 1: try { try decoder.decodeSingularMessageField(value: &_storage._systemInfo) }()
+        case 2: try { try decoder.decodeSingularEnumField(value: &_storage._eventName) }()
+        case 3: try { try decoder.decodeSingularMessageField(value: &_storage._modelDownloadLogEvent) }()
+        case 40: try { try decoder.decodeSingularMessageField(value: &_storage._deleteModelLogEvent) }()
+        default: break
+        }
+      }
     }
-    if let v = self._modelDownloadLogEvent {
-      try visitor.visitSingularMessageField(value: v, fieldNumber: 3)
+  }
+
+  func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
+    try withExtendedLifetime(_storage) { (_storage: _StorageClass) in
+      if let v = _storage._systemInfo {
+        try visitor.visitSingularMessageField(value: v, fieldNumber: 1)
+      }
+      if _storage._eventName != .unknownEvent {
+        try visitor.visitSingularEnumField(value: _storage._eventName, fieldNumber: 2)
+      }
+      if let v = _storage._modelDownloadLogEvent {
+        try visitor.visitSingularMessageField(value: v, fieldNumber: 3)
+      }
+      if let v = _storage._deleteModelLogEvent {
+        try visitor.visitSingularMessageField(value: v, fieldNumber: 40)
+      }
     }
     try unknownFields.traverse(visitor: &visitor)
   }
 
   static func ==(lhs: FirebaseMlLogEvent, rhs: FirebaseMlLogEvent) -> Bool {
-    if lhs._systemInfo != rhs._systemInfo {return false}
-    if lhs.eventName != rhs.eventName {return false}
-    if lhs._modelDownloadLogEvent != rhs._modelDownloadLogEvent {return false}
+    if lhs._storage !== rhs._storage {
+      let storagesAreEqual: Bool = withExtendedLifetime((lhs._storage, rhs._storage)) { (_args: (_StorageClass, _StorageClass)) in
+        let _storage = _args.0
+        let rhs_storage = _args.1
+        if _storage._systemInfo != rhs_storage._systemInfo {return false}
+        if _storage._eventName != rhs_storage._eventName {return false}
+        if _storage._modelDownloadLogEvent != rhs_storage._modelDownloadLogEvent {return false}
+        if _storage._deleteModelLogEvent != rhs_storage._deleteModelLogEvent {return false}
+        return true
+      }
+      if !storagesAreEqual {return false}
+    }
     if lhs.unknownFields != rhs.unknownFields {return false}
     return true
   }

+ 12 - 3
FirebaseMLModelDownloader/Sources/proto/firebase_ml_log_sdk.proto

@@ -67,7 +67,7 @@ enum EventName {
   UNKNOWN_EVENT = 0;
   MODEL_DOWNLOAD = 100;
   MODEL_UPDATE = 101;
-
+  REMOTE_MODEL_DELETE_ON_DEVICE = 252;
 }
 
 // A list of error codes for various components of the system. For model downloading, the
@@ -192,6 +192,13 @@ message ModelDownloadLogEvent {
   int64 download_failure_status = 6;
 }
 
+// Information about deleting a downloaded model on device.
+message DeleteModelLogEvent {
+  // The type of the downloaded model requested to be deleted.
+  ModelInfo.ModelType model_type = 1;
+  // Whether the downloaded model is deleted successfully.
+  bool is_successful = 2;
+}
 
 // Main log event for FirebaseMl, that contains individual API events, like model
 // download.
@@ -203,7 +210,9 @@ message FirebaseMlLogEvent {
   // The event name.
   EventName event_name = 2;
 
-  // Model downloading logs.
-  // ==========================
+  // Information about model download.
   ModelDownloadLogEvent model_download_log_event = 3;
+
+  // Information about deleting a downloaded model.
+  DeleteModelLogEvent delete_model_log_event = 40;
 }

+ 89 - 80
FirebaseMLModelDownloader/Tests/Integration/ModelDownloaderIntegrationTests.swift

@@ -21,7 +21,6 @@ extension UserDefaults {
   /// Returns a new cleared instance of user defaults.
   static func createTestInstance(testName: String) -> UserDefaults {
     let suiteName = "com.google.firebase.ml.test.\(testName)"
-    // TODO: reconsider force unwrapping
     let defaults = UserDefaults(suiteName: suiteName)!
     defaults.removePersistentDomain(forName: suiteName)
     return defaults
@@ -30,15 +29,14 @@ extension UserDefaults {
   /// Returns the existing user defaults instance.
   static func getTestInstance(testName: String) -> UserDefaults {
     let suiteName = "com.google.firebase.ml.test.\(testName)"
-    // TODO: reconsider force unwrapping
     return UserDefaults(suiteName: suiteName)!
   }
 }
 
-// TODO: Use FirebaseApp internal init for testApp
 final class ModelDownloaderIntegrationTests: XCTestCase {
   override class func setUp() {
     super.setUp()
+    // TODO: Use FirebaseApp internal for test app.
     let bundle = Bundle(for: self)
     if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
       let options = FirebaseOptions(contentsOfFile: plistPath) {
@@ -49,6 +47,14 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
     FirebaseConfiguration.shared.setLoggerLevel(.debug)
   }
 
+  override func setUp() {
+    do {
+      try ModelFileManager.emptyModelsDirectory()
+    } catch {
+      XCTFail("Could not empty models directory.")
+    }
+  }
+
   /// Test to download model info - makes an actual network call.
   func testDownloadModelInfo() {
     guard let testApp = FirebaseApp.app() else {
@@ -56,17 +62,17 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       return
     }
 
+    let testName = String(#function.dropLast(2))
     let testModelName = "pose-detection"
 
     let modelInfoRetriever = ModelInfoRetriever(
       modelName: testModelName,
       projectID: testApp.options.projectID!,
       apiKey: testApp.options.apiKey!,
-      installations: Installations.installations(app: testApp),
-      appName: testApp.name
+      appName: testApp.name, installations: Installations.installations(app: testApp)
     )
 
-    let downloadExpectation = expectation(description: "Wait for model info to download.")
+    let modelInfoDownloadExpectation = expectation(description: "Wait for model info to download.")
     modelInfoRetriever.downloadModelInfo(completion: { result in
       switch result {
       case let .success(modelInfoResult):
@@ -76,9 +82,9 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
           XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
           XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
           XCTAssertGreaterThan(modelInfo.size, 0)
-          let localModelInfo = LocalModelInfo(from: modelInfo, path: "mock-valid-path")
+          let localModelInfo = LocalModelInfo(from: modelInfo)
           localModelInfo.writeToDefaults(
-            .createTestInstance(testName: #function),
+            .createTestInstance(testName: testName),
             appName: testApp.name
           )
         case .notModified:
@@ -88,13 +94,13 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
         XCTAssertNotNil(error)
         XCTFail("Failed to retrieve model info - \(error)")
       }
-      downloadExpectation.fulfill()
+      modelInfoDownloadExpectation.fulfill()
     })
 
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [modelInfoDownloadExpectation], timeout: 5)
 
     if let localInfo = LocalModelInfo(
-      fromDefaults: .getTestInstance(testName: #function),
+      fromDefaults: .getTestInstance(testName: testName),
       name: testModelName,
       appName: testApp.name
     ) {
@@ -116,12 +122,12 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       modelName: testModelName,
       projectID: testApp.options.projectID!,
       apiKey: testApp.options.apiKey!,
-      installations: Installations.installations(app: testApp),
-      appName: testApp.name,
+      appName: testApp.name, installations: Installations.installations(app: testApp),
       localModelInfo: localInfo
     )
 
-    let retrieveExpectation = expectation(description: "Wait for model info to be retrieved.")
+    let modelInfoRetrieveExpectation =
+      expectation(description: "Wait for model info to be retrieved.")
     modelInfoRetriever.downloadModelInfo(completion: { result in
       switch result {
       case let .success(modelInfoResult):
@@ -134,10 +140,10 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
         XCTAssertNotNil(error)
         XCTFail("Failed to retrieve model info - \(error)")
       }
-      retrieveExpectation.fulfill()
+      modelInfoRetrieveExpectation.fulfill()
     })
 
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [modelInfoRetrieveExpectation], timeout: 5)
   }
 
   /// Test to download model file - makes an actual network call.
@@ -146,7 +152,7 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       XCTFail("Default app was not configured.")
       return
     }
-    let testName = #function.dropLast(2)
+    let testName = String(#function.dropLast(2))
     let testModelName = "\(testName)-test-model"
     let urlString =
       "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
@@ -161,7 +167,7 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
     )
 
     let conditions = ModelDownloadConditions()
-    let expectation = self.expectation(description: "Wait for model to download.")
+    let downloadExpectation = expectation(description: "Wait for model to download.")
     let downloader = ModelFileDownloader(conditions: conditions)
     let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
       XCTAssertLessThanOrEqual(progress, 1)
@@ -170,33 +176,30 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
     let taskCompletion: ModelDownloadTask.Completion = { result in
       switch result {
       case let .success(model):
-        guard let modelPath = URL(string: model.path) else {
-          XCTFail("Invalid or empty model path.")
-          return
-        }
-        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
-        /// Remove downloaded model file.
+        let modelURL = URL(fileURLWithPath: model.path)
+        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
+        // Remove downloaded model file.
         do {
-          try ModelFileManager.removeFile(at: modelPath)
+          try ModelFileManager.removeFile(at: modelURL)
         } catch {
           XCTFail("Model removal failed - \(error)")
         }
       case let .failure(error):
         XCTFail("Error: \(error)")
       }
-      expectation.fulfill()
+      downloadExpectation.fulfill()
     }
     let modelDownloadManager = ModelDownloadTask(
       remoteModelInfo: remoteModelInfo,
       appName: testApp.name,
-      defaults: .createTestInstance(testName: #function),
+      defaults: .createTestInstance(testName: testName),
       downloader: downloader,
       progressHandler: taskProgressHandler,
       completion: taskCompletion
     )
 
     modelDownloadManager.resume()
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [downloadExpectation], timeout: 5)
     XCTAssertEqual(modelDownloadManager.downloadStatus, .complete)
   }
 
@@ -230,22 +233,20 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       switch result {
       case let .success(model):
         XCTAssertNotNil(model.path)
-        guard let modelPath = URL(string: model.path) else {
-          XCTFail("Invalid model path.")
-          return
-        }
-        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
+        let modelURL = URL(fileURLWithPath: model.path)
+        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
       case let .failure(error):
         XCTFail("Failed to download model - \(error)")
       }
       latestModelExpectation.fulfill()
     }
 
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [latestModelExpectation], timeout: 5)
 
-    /// Test download type - local model.
-    downloadType = .localModel
-    let localModelExpectation = expectation(description: "Get local model.")
+    /// Test download type - local model update in background.
+    downloadType = .localModelUpdateInBackground
+    let backgroundModelExpectation =
+      expectation(description: "Get local model and update in background.")
 
     modelDownloader.getModel(
       name: testModelName,
@@ -258,22 +259,18 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       switch result {
       case let .success(model):
         XCTAssertNotNil(model.path)
-        guard let modelPath = URL(string: model.path) else {
-          XCTFail("Invalid model path.")
-          return
-        }
-        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
+        let modelURL = URL(fileURLWithPath: model.path)
+        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
       case let .failure(error):
         XCTFail("Failed to download model - \(error)")
       }
-      localModelExpectation.fulfill()
+      backgroundModelExpectation.fulfill()
     }
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [backgroundModelExpectation], timeout: 5)
 
-    /// Test download type - local model update in background.
-    downloadType = .localModelUpdateInBackground
-    let backgroundModelExpectation =
-      expectation(description: "Get local model and update in background.")
+    /// Test download type - local model.
+    downloadType = .localModel
+    let localModelExpectation = expectation(description: "Get local model.")
 
     modelDownloader.getModel(
       name: testModelName,
@@ -286,23 +283,20 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       switch result {
       case let .success(model):
         XCTAssertNotNil(model.path)
-        guard let modelPath = URL(string: model.path) else {
-          XCTFail("Invalid model path.")
-          return
-        }
-        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
-        /// Remove downloaded model file.
+        let modelURL = URL(fileURLWithPath: model.path)
+        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
+        // Remove downloaded model file.
         do {
-          try ModelFileManager.removeFile(at: modelPath)
+          try ModelFileManager.removeFile(at: modelURL)
         } catch {
           XCTFail("Model removal failed - \(error)")
         }
       case let .failure(error):
         XCTFail("Failed to download model - \(error)")
       }
-      backgroundModelExpectation.fulfill()
+      localModelExpectation.fulfill()
     }
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [localModelExpectation], timeout: 5)
   }
 
   /// Delete previously downloaded model.
@@ -336,10 +330,7 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       switch result {
       case let .success(model):
         XCTAssertNotNil(model.path)
-        guard let filePath = URL(string: model.path) else {
-          XCTFail("Invalid model path.")
-          return
-        }
+        let filePath = URL(fileURLWithPath: model.path)
         XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
       case let .failure(error):
         XCTFail("Failed to download model - \(error)")
@@ -347,15 +338,19 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       latestModelExpectation.fulfill()
     }
 
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [latestModelExpectation], timeout: 5)
 
+    let deleteExpectation = expectation(description: "Wait for model deletion.")
     modelDownloader.deleteDownloadedModel(name: testModelName) { result in
+      deleteExpectation.fulfill()
       switch result {
       case .success: break
       case let .failure(error):
         XCTFail("Failed to delete model - \(error)")
       }
     }
+
+    wait(for: [deleteExpectation], timeout: 5)
   }
 
   /// Test listing models in model directory.
@@ -388,10 +383,7 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       switch result {
       case let .success(model):
         XCTAssertNotNil(model.path)
-        guard let filePath = URL(string: model.path) else {
-          XCTFail("Invalid model path.")
-          return
-        }
+        let filePath = URL(fileURLWithPath: model.path)
         XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
       case let .failure(error):
         XCTFail("Failed to download model - \(error)")
@@ -399,9 +391,11 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       latestModelExpectation.fulfill()
     }
 
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [latestModelExpectation], timeout: 5)
 
+    let listExpectation = expectation(description: "Wait for list models.")
     modelDownloader.listDownloadedModels { result in
+      listExpectation.fulfill()
       switch result {
       case let .success(models):
         XCTAssertGreaterThan(models.count, 0)
@@ -409,6 +403,8 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
         XCTFail("Failed to list models - \(error)")
       }
     }
+
+    wait(for: [listExpectation], timeout: 5)
   }
 
   /// Test logging telemetry event.
@@ -419,7 +415,7 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
     }
 
     let testModelName = "image-classification"
-    let testName = #function
+    let testName = String(#function.dropLast(2))
 
     let conditions = ModelDownloadConditions()
     let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
@@ -427,7 +423,7 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       app: testApp
     )
 
-    let latestModelExpectation = expectation(description: "Test telemetry.")
+    let latestModelExpectation = expectation(description: "Test get model telemetry.")
 
     modelDownloader.getModel(
       name: testModelName,
@@ -447,14 +443,31 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
           model: model,
           downloadErrorCode: .noError
         )
-        let modelPath = URL(string: model.path)!
-        try? ModelFileManager.removeFile(at: modelPath)
       case let .failure(error):
         XCTFail("Failed to download model - \(error)")
       }
       latestModelExpectation.fulfill()
     }
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [latestModelExpectation], timeout: 5)
+
+    let deleteModelExpectation = expectation(description: "Test delete model telemetry.")
+    modelDownloader.deleteDownloadedModel(name: testModelName) { result in
+      switch result {
+      case .success(()):
+        guard let telemetryLogger = TelemetryLogger(app: testApp) else {
+          XCTFail("Could not initialize logger.")
+          return
+        }
+        // TODO: Remove actual logging and stub out with mocks.
+        telemetryLogger.logModelDeletedEvent(eventName: .remoteModelDeleteOnDevice,
+                                             isSuccessful: true)
+      case let .failure(error):
+        XCTFail("Failed to delete model - \(error)")
+      }
+      deleteModelExpectation.fulfill()
+    }
+
+    wait(for: [deleteModelExpectation], timeout: 5)
   }
 
   func testGetModelWithConditions() {
@@ -464,11 +477,10 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
     }
 
     let testModelName = "pose-detection"
-    let testName = #function
+    let testName = String(#function.dropLast(2))
 
     // TODO: Figure out a better way to test this.
-    var conditions = ModelDownloadConditions()
-    conditions.allowsCellularAccess = false
+    let conditions = ModelDownloadConditions(allowsCellularAccess: false)
 
     let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
       .createTestInstance(testName: testName),
@@ -489,16 +501,13 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       switch result {
       case let .success(model):
         XCTAssertNotNil(model.path)
-        guard let filePath = URL(string: model.path) else {
-          XCTFail("Invalid model path.")
-          return
-        }
+        let filePath = URL(fileURLWithPath: model.path)
         XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
       case let .failure(error):
         XCTFail("Failed to download model - \(error)")
       }
       latestModelExpectation.fulfill()
     }
-    waitForExpectations(timeout: 5, handler: nil)
+    wait(for: [latestModelExpectation], timeout: 5)
   }
 }

+ 353 - 127
FirebaseMLModelDownloader/Tests/Unit/ModelDownloaderUnitTests.swift

@@ -28,6 +28,7 @@ private enum MockOptions {
 // TODO: Create separate files for each class tested.
 final class ModelDownloaderUnitTests: XCTestCase {
   let fakeModelName = "fakeModelName"
+  let fakemodelPath = "fakemodelPath"
   let fakeModelHash = "fakeModelHash"
   let fakeDownloadURL = URL(string: "www.fake-download-url.com")!
   let fakeFileURL = URL(string: "www.fake-model-file.com")!
@@ -35,6 +36,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
   let fakeModelSize = 20
   let fakeProjectID = "fakeProjectID"
   let fakeAPIKey = "fakeAPIKey"
+  let fakeAppName = "fakeAppName"
   var fakeModelJSON: String {
     """
     {
@@ -61,6 +63,14 @@ final class ModelDownloaderUnitTests: XCTestCase {
     FirebaseConfiguration.shared.setLoggerLevel(.debug)
   }
 
+  override func setUp() {
+    do {
+      try ModelFileManager.emptyModelsDirectory()
+    } catch {
+      XCTFail("Could not empty models directory.")
+    }
+  }
+
   override class func tearDown() {
     try? FileManager.default.removeItem(atPath: NSTemporaryDirectory())
     FirebaseApp.app()?.delete { _ in }
@@ -95,7 +105,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
   /// Test invalid model file deletion.
   func testDeleteInvalidModel() {
     let modelDownloader = ModelDownloader.modelDownloader()
-    modelDownloader.deleteDownloadedModel(name: "fake-model-name") { result in
+    modelDownloader.deleteDownloadedModel(name: fakeModelName) { result in
       switch result {
       case .success: XCTFail("Unexpected successful model deletion.")
       case let .failure(error):
@@ -104,79 +114,116 @@ final class ModelDownloaderUnitTests: XCTestCase {
     }
   }
 
-  /// Test listing models in model directory.
-  // TODO: Add unit test.
+  func testGetModelNameFromFilePath() {
+    var fakeURL = URL(fileURLWithPath: "modelDirectory/@@appName@@\(fakeModelName).tflite")
+    XCTAssertEqual(ModelFileManager.getModelNameFromFilePath(fakeURL), fakeModelName)
+    fakeURL = URL(fileURLWithPath: "modelDirectory/--appName--\(fakeModelName).tflite")
+    XCTAssertNil(ModelFileManager.getModelNameFromFilePath(fakeURL))
+  }
+
+  /// Test listing models.
   func testListModels() {
+    guard let testApp = FirebaseApp.app() else {
+      XCTFail("Default app was not configured.")
+      return
+    }
+    let testName = String(#function.dropLast(2))
+    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
+      .createUnitTestInstance(testName: testName),
+      app: testApp
+    )
+
+    let tempModelFileURL = tempModelFile()
+    let completionExpectation = expectation(description: "Completion handler")
+
+    let localModelInfo = fakeLocalModelInfo()
+    localModelInfo.writeToDefaults(.getUnitTestInstance(testName: testName), appName: testApp.name)
+
+    modelDownloader.listDownloadedModels { result in
+      completionExpectation.fulfill()
+      switch result {
+      case let .success(models):
+        /// Expected success since model info was written to user defaults.
+        XCTAssertEqual(models.first?.path, tempModelFileURL.path)
+      case let .failure(error):
+        XCTFail("Error - \(error)")
+      }
+    }
+    wait(for: [completionExpectation], timeout: 0.5)
+    try? ModelFileManager.removeFile(at: tempModelFileURL)
+  }
+
+  /// Test invalid path in model directory.
+  func testListModelsInvalidPath() {
     let modelDownloader = ModelDownloader.modelDownloader()
+    let invalidTempModelFileURL = tempModelFile(invalid: true)
+    let completionExpectation = expectation(description: "Completion handler")
 
     modelDownloader.listDownloadedModels { result in
+      completionExpectation.fulfill()
       switch result {
-      case .success: break
-      case .failure: break
+      case .success:
+        XCTFail("Unexpected success of list models.")
+      case let .failure(error):
+        switch error {
+        case let .internalError(description):
+          XCTAssertTrue(description
+            .contains("List models failed due to unexpected model file name"))
+        default: XCTFail("Expected failure due to bad model name.")
+        }
       }
     }
+    wait(for: [completionExpectation], timeout: 0.5)
+    try? ModelFileManager.removeFile(at: invalidTempModelFileURL)
   }
 
-  func testGetModel() {
-    // This is an example of a functional test case.
-    // Use XCTAssert and related functions to verify your tests produce the correct
-    // results.
+  /// Test listing models without local info in user defaults.
+  func testListModelsNoLocalInfo() {
     guard let testApp = FirebaseApp.app() else {
       XCTFail("Default app was not configured.")
       return
     }
 
-    let modelDownloader = ModelDownloader.modelDownloader()
-    let modelDownloaderWithApp = ModelDownloader.modelDownloader(app: testApp)
-
-    /// These should point to the same instance.
-    XCTAssert(modelDownloader === modelDownloaderWithApp)
+    let testName = String(#function.dropLast(2))
+    let tempModelFileURL = tempModelFile()
+    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
+      .createUnitTestInstance(testName: testName),
+      app: testApp
+    )
 
-    let conditions = ModelDownloadConditions()
+    let completionExpectation = expectation(description: "Completion handler")
 
-    // Download model w/ progress handler
-    modelDownloader.getModel(
-      name: "your_model_name",
-      downloadType: .latestModel,
-      conditions: conditions,
-      progressHandler: { progress in
-        // Handle progress
-      }
-    ) { result in
+    modelDownloader.listDownloadedModels { result in
+      completionExpectation.fulfill()
       switch result {
       case .success:
-        // Use model with your inference API
-        // let interpreter = Interpreter(modelPath: customModel.modelPath)
-        break
-      case .failure:
-        // Handle download error
-        break
+        XCTFail("Unexpected success of list models.")
+      case let .failure(error):
+        switch error {
+        /// Expected failure since no model info was written to user defaults.
+        case let .internalError(description):
+          XCTAssertTrue(description.contains("List models failed due to no local model info"))
+        default: XCTFail("Expected failure due to bad model name.")
+        }
       }
     }
+    wait(for: [completionExpectation], timeout: 0.5)
+    try? ModelFileManager.removeFile(at: tempModelFileURL)
+  }
 
-    // Access array of downloaded models
-    modelDownloaderWithApp.listDownloadedModels { result in
-      switch result {
-      case .success:
-        // Pick model(s) for further use
-        break
-      case .failure:
-        // Handle failure
-        break
-      }
-    }
+  /// Test how URL conversion behaves if there are spaces in the path.
+  func testURLConversion() {
+    /// Spaces in the string only convert to URL when using URL(fileURLWithPath: ).
+    let fakeURLWithSpace = URL(string: "file:///fakeDir1/fake%20Dir2/fakeFile")!
 
-    // Delete downloaded model
-    modelDownloader.deleteDownloadedModel(name: "your_model_name") { result in
-      switch result {
-      case .success():
-        // Apply any other clean up
-        break
-      case .failure:
-        // Handle failure
-        break
-      }
-    }
+    XCTAssertEqual(fakeURLWithSpace, URL(string: fakeURLWithSpace.absoluteString))
+    XCTAssertEqual(fakeURLWithSpace, URL(fileURLWithPath: fakeURLWithSpace.path))
+    XCTAssertNil(URL(string: fakeURLWithSpace.path))
+
+    /// Strings without spaces should work fine either way.
+    let fakeURLWithoutSpace = URL(string: "fakeDir1/fakeDir2/fakeFile")!
+    XCTAssertEqual(fakeURLWithoutSpace, URL(string: fakeURLWithoutSpace.absoluteString)!)
+    XCTAssertEqual(fakeURLWithoutSpace, URL(string: fakeURLWithoutSpace.path)!)
   }
 
   /// Test on-device logging.
@@ -189,10 +236,10 @@ final class ModelDownloaderUnitTests: XCTestCase {
   /// Compare proto serialization methods.
   func testTelemetryEncoding() {
     let fakeModel = CustomModel(
-      name: "fakeModelName",
+      name: fakeModelName,
       size: 10,
-      path: "fakeModelPath",
-      hash: "fakeModelHash"
+      path: fakemodelPath,
+      hash: fakeModelHash
     )
     var modelOptions = ModelOptions()
     modelOptions.setModelOptions(model: fakeModel)
@@ -338,6 +385,56 @@ final class ModelDownloaderUnitTests: XCTestCase {
     wait(for: [completionExpectation], timeout: 0.5)
   }
 
+  /// Get model info if model name does not exist.
+  func testGetModelInfoWith404() {
+    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 404)
+    let localModelInfo = fakeLocalModelInfo(mismatch: false)
+    let modelInfoRetriever = fakeModelRetriever(
+      fakeSession: session,
+      fakeLocalModelInfo: localModelInfo
+    )
+    let completionExpectation = expectation(description: "Completion handler")
+
+    modelInfoRetriever.downloadModelInfo { result in
+      completionExpectation.fulfill()
+      switch result {
+      case let .success(fakemodelInfoResult):
+        switch fakemodelInfoResult {
+        case .modelInfo: XCTFail("Unexpected new model info from server.")
+        case .notModified: XCTFail("Expected failure since model name is not availabl.")
+        }
+      case let .failure(error):
+        XCTAssertEqual(error, .notFound)
+      }
+    }
+    wait(for: [completionExpectation], timeout: 0.5)
+  }
+
+  /// Get model info if resource exhausted due to too many requests.
+  func testGetModelInfoWith429() {
+    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 429)
+    let localModelInfo = fakeLocalModelInfo(mismatch: false)
+    let modelInfoRetriever = fakeModelRetriever(
+      fakeSession: session,
+      fakeLocalModelInfo: localModelInfo
+    )
+    let completionExpectation = expectation(description: "Completion handler")
+
+    modelInfoRetriever.downloadModelInfo { result in
+      completionExpectation.fulfill()
+      switch result {
+      case let .success(fakemodelInfoResult):
+        switch fakemodelInfoResult {
+        case .modelInfo: XCTFail("Unexpected new model info from server.")
+        case .notModified: XCTFail("Expected failure since resource exhausted.")
+        }
+      case let .failure(error):
+        XCTAssertEqual(error, .resourceExhausted)
+      }
+    }
+    wait(for: [completionExpectation], timeout: 0.5)
+  }
+
   /// Get model if server returns a model file.
   func testModelDownloadWith200() {
     let tempFileURL = tempFile()
@@ -365,19 +462,18 @@ final class ModelDownloaderUnitTests: XCTestCase {
       completionExpectation.fulfill()
       switch result {
       case let .success(model):
-        XCTAssertEqual(model.name, self.fakeModelName)
-        XCTAssertEqual(model.size, self.fakeModelSize)
-        XCTAssertEqual(model.hash, self.fakeModelHash)
-        let modelPath = URL(string: model.path)!
-        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
-        try? self.deleteFile(at: modelPath)
+        self.checkModel(model: model)
+        self.deleteFile(at: model.path)
       case let .failure(error):
         XCTFail("Error - \(error)")
       }
     }
+    let testName = String(#function.dropLast(2))
     let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                               appName: "fakeAppName",
-                                              defaults: .createTestInstance(testName: #function),
+                                              defaults: .createUnitTestInstance(
+                                                testName: testName
+                                              ),
                                               downloader: downloader,
                                               progressHandler: taskProgressHandler,
                                               completion: taskCompletion)
@@ -421,9 +517,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
         XCTAssertEqual(error, .expiredDownloadURL)
       }
     }
+    let testName = String(#function.dropLast(2))
     let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                               appName: "fakeAppName",
-                                              defaults: .createTestInstance(testName: #function),
+                                              defaults: .createUnitTestInstance(
+                                                testName: testName
+                                              ),
                                               downloader: downloader,
                                               completion: taskCompletion)
     modelDownloadTask.resume()
@@ -461,9 +560,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
         XCTAssertEqual(error, .invalidArgument)
       }
     }
+    let testName = String(#function.dropLast(2))
     let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                               appName: "fakeAppName",
-                                              defaults: .createTestInstance(testName: #function),
+                                              defaults: .createUnitTestInstance(
+                                                testName: testName
+                                              ),
                                               downloader: downloader,
                                               completion: taskCompletion)
     modelDownloadTask.resume()
@@ -477,6 +579,41 @@ final class ModelDownloaderUnitTests: XCTestCase {
     try? ModelFileManager.removeFile(at: tempFileURL)
   }
 
+  /// Get model info if Firebase ML API is not enabled or permission denied.
+  func testGetModelInfoWith403() {
+    let errorMessage =
+      """
+      {
+        "error":
+          {
+            "message":"API Disabled."
+          }
+      }
+      """
+    let session = fakeModelInfoSessionWithURL(
+      fakeDownloadURL,
+      statusCode: 403,
+      customJSON: errorMessage
+    )
+    let localModelInfo = fakeLocalModelInfo(mismatch: false)
+    let modelInfoRetriever = fakeModelRetriever(
+      fakeSession: session,
+      fakeLocalModelInfo: localModelInfo
+    )
+    let completionExpectation = expectation(description: "Completion handler")
+
+    modelInfoRetriever.downloadModelInfo { result in
+      completionExpectation.fulfill()
+      switch result {
+      case .success:
+        XCTFail("Unexpected new model info from server.")
+      case let .failure(error):
+        XCTAssertEqual(error, .permissionDenied)
+      }
+    }
+    wait(for: [completionExpectation], timeout: 0.5)
+  }
+
   /// Get model if multiple duplicate requests are made.
   func testModelDownloadWithMergeRequests() {
     let tempFileURL = tempFile()
@@ -501,19 +638,18 @@ final class ModelDownloaderUnitTests: XCTestCase {
       completionExpectation.fulfill()
       switch result {
       case let .success(model):
-        XCTAssertEqual(model.name, self.fakeModelName)
-        XCTAssertEqual(model.size, self.fakeModelSize)
-        XCTAssertEqual(model.hash, self.fakeModelHash)
-        let modelPath = URL(string: model.path)!
-        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
+        self.checkModel(model: model)
       case let .failure(error):
         XCTFail("Error - \(error)")
       }
     }
 
+    let testName = String(#function.dropLast(2))
     let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
-                                              appName: "fakeAppName",
-                                              defaults: .createTestInstance(testName: #function),
+                                              appName: fakeAppName,
+                                              defaults: .createUnitTestInstance(
+                                                testName: testName
+                                              ),
                                               downloader: downloader,
                                               completion: taskCompletion)
 
@@ -522,13 +658,9 @@ final class ModelDownloaderUnitTests: XCTestCase {
         completionExpectation.fulfill()
         switch result {
         case let .success(model):
-          XCTAssertEqual(model.name, self.fakeModelName)
-          XCTAssertEqual(model.size, self.fakeModelSize)
-          XCTAssertEqual(model.hash, self.fakeModelHash)
-          let modelPath = URL(string: model.path)!
-          XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
+          self.checkModel(model: model)
           if i == numberOfRequests {
-            try? self.deleteFile(at: modelPath)
+            self.deleteFile(at: model.path)
           }
         case let .failure(error):
           XCTFail("Error - \(error)")
@@ -578,13 +710,9 @@ final class ModelDownloaderUnitTests: XCTestCase {
         completionExpectation.fulfill()
         switch result {
         case let .success(model):
-          XCTAssertEqual(model.name, self.fakeModelName)
-          XCTAssertEqual(model.size, self.fakeModelSize)
-          XCTAssertEqual(model.hash, self.fakeModelHash)
-          let modelPath = URL(string: model.path)!
-          XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
+          self.checkModel(model: model)
           if i == numberOfRequests {
-            try? self.deleteFile(at: modelPath)
+            self.deleteFile(at: model.path)
           }
         case let .failure(error):
           XCTFail("Error - \(error)")
@@ -626,9 +754,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
         XCTAssertEqual(error, .notFound)
       }
     }
+    let testName = String(#function.dropLast(2))
     let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                               appName: "fakeAppName",
-                                              defaults: .createTestInstance(testName: #function),
+                                              defaults: .createUnitTestInstance(
+                                                testName: testName
+                                              ),
                                               downloader: downloader,
                                               completion: taskCompletion)
     modelDownloadTask.resume()
@@ -642,6 +773,91 @@ final class ModelDownloaderUnitTests: XCTestCase {
     try? ModelFileManager.removeFile(at: tempFileURL)
   }
 
+  /// Get model if server returns an error due to model not found.
+  func testModelDownloadNetworkError() {
+    let tempFileURL = tempFile()
+    let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
+    let downloader = MockModelFileDownloader()
+    let fileDownloaderExpectation = expectation(description: "File downloader")
+    downloader.downloadFileHandler = { url in
+      fileDownloaderExpectation.fulfill()
+      XCTAssertEqual(url, self.fakeDownloadURL)
+    }
+
+    let completionExpectation = expectation(description: "Completion handler")
+    let taskCompletion: ModelDownloadTask.Completion = { result in
+      completionExpectation.fulfill()
+      switch result {
+      case .success: XCTFail("Unexpected successful model download.")
+      case let .failure(error):
+        XCTAssertEqual(error, .failedPrecondition)
+      }
+    }
+    let testName = String(#function.dropLast(2))
+    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
+                                              appName: "fakeAppName",
+                                              defaults: .createUnitTestInstance(
+                                                testName: testName
+                                              ),
+                                              downloader: downloader,
+                                              completion: taskCompletion)
+    modelDownloadTask.resume()
+    wait(for: [fileDownloaderExpectation], timeout: 0.1)
+
+    let offlineError = DownloadError.internalError(description: "Network appears to be offline.")
+    downloader
+      .completion?(.failure(FileDownloaderError.networkError(offlineError)))
+    wait(for: [completionExpectation], timeout: 0.5)
+
+    try? ModelFileManager.removeFile(at: tempFileURL)
+  }
+
+  /// Get model with a valid server response.
+  func testGetModelSuccessful() {
+    let modelDownloader = ModelDownloader.modelDownloader()
+    let conditions = ModelDownloadConditions()
+    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
+    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
+    let downloader = MockModelFileDownloader()
+    let tempFileURL = tempFile()
+
+    let fakeResponseNotFound = HTTPURLResponse(url: fakeFileURL,
+                                               statusCode: 200,
+                                               httpVersion: nil,
+                                               headerFields: nil)!
+
+    let fileDownloaderExpectation = expectation(description: "File downloader")
+    let completionExpectation = expectation(description: "Completion handler")
+
+    // Handles initial response.
+    downloader.downloadFileHandler = { url in
+      fileDownloaderExpectation.fulfill()
+      XCTAssertEqual(url, self.fakeDownloadURL)
+    }
+
+    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
+                                         modelInfoRetriever: modelInfoRetriever,
+                                         downloader: downloader,
+                                         conditions: conditions) { result in
+      completionExpectation.fulfill()
+      switch result {
+      case let .success(model):
+        self.checkModel(model: model)
+        self.deleteFile(at: model.path)
+      case let .failure(error):
+        XCTFail("Error - \(error)")
+      }
+    }
+    wait(for: [fileDownloaderExpectation], timeout: 0.5)
+
+    downloader
+      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseNotFound,
+                                                   fileURL: tempFileURL)))
+    wait(for: [completionExpectation], timeout: 0.5)
+
+    try? ModelFileManager.removeFile(at: tempFileURL)
+  }
+
   /// Get model if download url expired but retry was successful.
   func testGetModelSuccessfulRetry() {
     let modelDownloader = ModelDownloader.modelDownloader()
@@ -678,12 +894,8 @@ final class ModelDownloaderUnitTests: XCTestCase {
       switch result {
       // Retry successful - model returned.
       case let .success(model):
-        XCTAssertEqual(model.name, self.fakeModelName)
-        XCTAssertEqual(model.size, self.fakeModelSize)
-        XCTAssertEqual(model.hash, self.fakeModelHash)
-        let modelPath = URL(string: model.path)!
-        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
-        try? self.deleteFile(at: modelPath)
+        self.checkModel(model: model)
+        self.deleteFile(at: model.path)
       case let .failure(error):
         XCTFail("Error - \(error)")
       }
@@ -771,7 +983,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
                                                httpVersion: nil,
                                                headerFields: nil)!
 
-    let fileDownloaderExpectation = expectation(description: "File downloader before retry")
+    let fileDownloaderExpectation = expectation(description: "File downloader")
     let completionExpectation = expectation(description: "Completion handler")
 
     // Handles initial response.
@@ -786,7 +998,6 @@ final class ModelDownloaderUnitTests: XCTestCase {
                                          conditions: conditions) { result in
       completionExpectation.fulfill()
       switch result {
-      // Retry failed due to another expired url - error returned.
       case .success: XCTFail("Unexpected successful model download.")
       case let .failure(error):
         switch error {
@@ -974,12 +1185,8 @@ final class ModelDownloaderUnitTests: XCTestCase {
       switch result {
       // One retry failed but the next retry succeeded - model returned.
       case let .success(model):
-        XCTAssertEqual(model.name, self.fakeModelName)
-        XCTAssertEqual(model.size, self.fakeModelSize)
-        XCTAssertEqual(model.hash, self.fakeModelHash)
-        let modelPath = URL(string: model.path)!
-        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
-        try? self.deleteFile(at: modelPath)
+        self.checkModel(model: model)
+        self.deleteFile(at: model.path)
       case let .failure(error):
         XCTFail("Error - \(error)")
       }
@@ -1044,21 +1251,28 @@ class MockModelFileDownloader: FileDownloader {
 
 extension UserDefaults {
   /// Returns a new cleared instance of user defaults.
-  static func createTestInstance(testName: String) -> UserDefaults {
+  static func createUnitTestInstance(testName: String) -> UserDefaults {
     let suiteName = "com.google.firebase.ml.test.\(testName)"
-    // TODO: reconsider force unwrapping
     let defaults = UserDefaults(suiteName: suiteName)!
     defaults.removePersistentDomain(forName: suiteName)
     return defaults
   }
+
+  /// Returns the existing user defaults instance.
+  static func getUnitTestInstance(testName: String) -> UserDefaults {
+    let suiteName = "com.google.firebase.ml.test.\(testName)"
+    return UserDefaults(suiteName: suiteName)!
+  }
 }
 
 // MARK: - Helpers
 
 extension ModelDownloaderUnitTests {
-  func fakeModelInfoSessionWithURL(_ url: URL, statusCode: Int) -> MockModelInfoRetrieverSession {
+  func fakeModelInfoSessionWithURL(_ url: URL, statusCode: Int,
+                                   customJSON: String? = nil) -> MockModelInfoRetrieverSession {
     let fakeSession = MockModelInfoRetrieverSession()
-    fakeSession.data = fakeModelJSON.data(using: .utf8)
+    let fakeJSON = customJSON ?? fakeModelJSON
+    fakeSession.data = fakeJSON.data(using: .utf8)
     fakeSession.response = HTTPURLResponse(
       url: url,
       statusCode: statusCode,
@@ -1071,19 +1285,25 @@ extension ModelDownloaderUnitTests {
 
   func fakeModelRetriever(fakeSession: MockModelInfoRetrieverSession,
                           fakeLocalModelInfo: LocalModelInfo? = nil) -> ModelInfoRetriever {
-    // TODO: Replace with a fake one so we can check that is was used correctly by the download task.
+    // TODO: Replace with fake to check if it was used correctly by the download task.
     let modelInfoRetriever = ModelInfoRetriever(
-      modelName: "fakeModelName",
-      projectID: "fakeProjectID",
-      apiKey: "fakeAPIKey",
-      authTokenProvider: successAuthTokenProvider,
-      appName: "fakeAppName",
-      localModelInfo: fakeLocalModelInfo,
-      session: fakeSession
+      modelName: fakeModelName,
+      projectID: fakeProjectID,
+      apiKey: fakeAPIKey,
+      appName: fakeAppName, authTokenProvider: successAuthTokenProvider,
+      session: fakeSession, localModelInfo: fakeLocalModelInfo
     )
     return modelInfoRetriever
   }
 
+  func checkModel(model: CustomModel) {
+    XCTAssertEqual(model.name, fakeModelName)
+    XCTAssertEqual(model.size, fakeModelSize)
+    XCTAssertEqual(model.hash, fakeModelHash)
+    let modelURL = URL(fileURLWithPath: model.path)
+    XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
+  }
+
   func tempFile(name: String = "fake-model-file.tmp") -> URL {
     let tempDirectoryURL = URL(fileURLWithPath: NSTemporaryDirectory(),
                                isDirectory: true)
@@ -1093,32 +1313,38 @@ extension ModelDownloaderUnitTests {
     return tempFileURL
   }
 
-  func deleteFile(at url: URL) throws {
-    try ModelFileManager.removeFile(at: url)
+  func tempModelFile(invalid: Bool = false) -> URL {
+    let modelDirectoryURL = ModelFileManager.modelsDirectory!
+    let modelFileName = invalid ? "fbml_model_fake-model-file.tmp" :
+      "fbml_model@@__FIRAPP_DEFAULT@@\(fakeModelName)"
+    let tempFileURL = modelDirectoryURL.appendingPathComponent(modelFileName)
+      .appendingPathExtension("tflite")
+    let tempData: Data = "fakeModelData".data(using: .utf8)!
+    try? tempData.write(to: tempFileURL)
+    return tempFileURL
   }
 
-  func fakeLocalModelInfo(mismatch: Bool) -> LocalModelInfo {
-    var hash = "fakeModelHash"
-    if mismatch {
-      hash = "wrongModelHash"
+  func deleteFile(at path: String) {
+    let url = URL(fileURLWithPath: path)
+    do {
+      try ModelFileManager.removeFile(at: url)
+    } catch {
+      XCTFail("Failed to delete file at \(path).")
     }
+  }
+
+  func fakeLocalModelInfo(mismatch: Bool = false) -> LocalModelInfo {
+    let hash = mismatch ? "wrongModelHash" : fakeModelHash
     return LocalModelInfo(
-      name: "fakeModelName",
+      name: fakeModelName,
       modelHash: hash,
-      size: 20,
-      path: "fakeModelPath"
+      size: 20
     )
   }
 
   func fakeModelInfo(expired: Bool, oversized: Bool) -> RemoteModelInfo {
-    var expiryTime = Date()
-    var size = fakeModelSize
-    if !expired {
-      expiryTime = expiryTime.addingTimeInterval(1000)
-    }
-    if oversized {
-      size = Int(Int64.max) / 4
-    }
+    let expiryTime = expired ? Date() : Date().addingTimeInterval(1000)
+    let size = oversized ? Int(Int64.max) / 4 : fakeModelSize
     return RemoteModelInfo(name: fakeModelName,
                            downloadURL: fakeDownloadURL,
                            modelHash: fakeModelHash,