Quellcode durchsuchen

ML model downloader fixes (#7515)

* listDownloadedModels: don't cross-check URLs

* Store models in a subfolder.

* Test completion and progress handlers being called on the main thread.

* More tests for callbacks on the main queue.

* Fix public API handlers to be called on the main queue.

* .gitignore ML GoogleService-Info.plist files

* style

* getModel: fail fast when modelName is empty string
Maksym Malyhin vor 5 Jahren
Ursprung
Commit
0341e9ccfd

+ 1 - 1
.gitignore

@@ -36,7 +36,7 @@ FirebaseSegmentation/Tests/Sample/GoogleService-Info.plist
 
 # FirebaseMLModelDownloader integration tests GoogleService-Info.plist
 FirebaseMLModelDownloader/Tests/Integration/Resources/GoogleService-Info.plist
-FirebaseMLModelDownloader/Apps/Sample/Resources/GoogleService-Info.plist
+FirebaseMLModelDownloader/Apps/Sample/**/GoogleService-Info.plist
 
 # FirebasePerformance dev test App and integration tests GoogleService-Info.plist
 FirebasePerformance/**/GoogleService-Info.plist

+ 32 - 25
FirebaseMLModelDownloader/Sources/ModelDownloader.swift

@@ -56,10 +56,10 @@ public class ModelDownloader {
   /// DispatchQueue to manage download task dictionary.
   let taskSerialQueue = DispatchQueue(label: "downloadtask.serial.queue")
 
-  /// Handler that always runs on the main thread
-  let mainQueueHandler = { handler in
+  /// Re-dispatch a function on the main queue.
+  func asyncOnMainQueue(_ work: @autoclosure @escaping () -> Void) {
     DispatchQueue.main.async {
-      handler
+      work()
     }
   }
 
@@ -132,13 +132,18 @@ public class ModelDownloader {
                        conditions: ModelDownloadConditions,
                        progressHandler: ((Float) -> Void)? = nil,
                        completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
+    guard !modelName.isEmpty else {
+      asyncOnMainQueue(completion(.failure(.emptyModelName)))
+      return
+    }
+
     switch downloadType {
     case .localModel:
       if let localModel = getLocalModel(modelName: modelName) {
         DeviceLogger.logEvent(level: .debug,
                               message: ModelDownloader.DebugDescription.localModelFound,
                               messageCode: .localModelFound)
-        mainQueueHandler(completion(.success(localModel)))
+        asyncOnMainQueue(completion(.success(localModel)))
       } else {
         getRemoteModel(
           modelName: modelName,
@@ -153,7 +158,7 @@ public class ModelDownloader {
         DeviceLogger.logEvent(level: .debug,
                               message: ModelDownloader.DebugDescription.localModelFound,
                               messageCode: .localModelFound)
-        mainQueueHandler(completion(.success(localModel)))
+        asyncOnMainQueue(completion(.success(localModel)))
         telemetryLogger?.logModelDownloadEvent(
           eventName: .modelDownload,
           status: .scheduled,
@@ -225,7 +230,7 @@ public class ModelDownloader {
           DeviceLogger.logEvent(level: .debug,
                                 message: description,
                                 messageCode: .modelNameParseError)
-          mainQueueHandler(completion(.failure(.internalError(description: description))))
+          asyncOnMainQueue(completion(.failure(.internalError(description: description))))
           return
         }
         // Check if model information corresponding to model is stored in UserDefaults.
@@ -234,7 +239,7 @@ public class ModelDownloader {
           DeviceLogger.logEvent(level: .debug,
                                 message: description,
                                 messageCode: .noLocalModelInfo)
-          mainQueueHandler(completion(.failure(.internalError(description: description))))
+          asyncOnMainQueue(completion(.failure(.internalError(description: description))))
           return
         }
         // Ensure that local model path is as expected, and reachable.
@@ -242,11 +247,11 @@ public class ModelDownloader {
           appName: appName,
           modelName: modelName
         ),
-          url == modelURL, ModelFileManager.isFileReachable(at: modelURL) else {
+          ModelFileManager.isFileReachable(at: modelURL) else {
           DeviceLogger.logEvent(level: .debug,
                                 message: ModelDownloader.ErrorDescription.outdatedModelPath,
                                 messageCode: .outdatedModelPathError)
-          mainQueueHandler(completion(.failure(.internalError(description: ModelDownloader
+          asyncOnMainQueue(completion(.failure(.internalError(description: ModelDownloader
               .ErrorDescription.outdatedModelPath))))
           return
         }
@@ -262,12 +267,12 @@ public class ModelDownloader {
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloader.ErrorDescription.listModelsFailed(error),
                             messageCode: .listModelsError)
-      mainQueueHandler(completion(.failure(error)))
+      asyncOnMainQueue(completion(.failure(error)))
     } catch {
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloader.ErrorDescription.listModelsFailed(error),
                             messageCode: .listModelsError)
-      mainQueueHandler(completion(.failure(.internalError(description: error
+      asyncOnMainQueue(completion(.failure(.internalError(description: error
           .localizedDescription))))
     }
   }
@@ -290,7 +295,7 @@ public class ModelDownloader {
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloader.ErrorDescription.modelNotFound(modelName),
                             messageCode: .modelNotFound)
-      mainQueueHandler(completion(.failure(.notFound)))
+      asyncOnMainQueue(completion(.failure(.notFound)))
       return
     }
     do {
@@ -305,7 +310,7 @@ public class ModelDownloader {
         eventName: .remoteModelDeleteOnDevice,
         isSuccessful: true
       )
-      mainQueueHandler(completion(.success(())))
+      asyncOnMainQueue(completion(.success(())))
     } catch let error as DownloadedModelError {
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloader.ErrorDescription.modelDeletionFailed(error),
@@ -314,7 +319,7 @@ public class ModelDownloader {
         eventName: .remoteModelDeleteOnDevice,
         isSuccessful: false
       )
-      mainQueueHandler(completion(.failure(error)))
+      asyncOnMainQueue(completion(.failure(error)))
     } catch {
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloader.ErrorDescription.modelDeletionFailed(error),
@@ -323,7 +328,7 @@ public class ModelDownloader {
         eventName: .remoteModelDeleteOnDevice,
         isSuccessful: false
       )
-      mainQueueHandler(completion(.failure(.internalError(description: error
+      asyncOnMainQueue(completion(.failure(.internalError(description: error
           .localizedDescription))))
     }
   }
@@ -417,28 +422,28 @@ extension ModelDownloader {
           // Progress handler for model file download.
           let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
             if let progressHandler = progressHandler {
-              self.mainQueueHandler(progressHandler(progress))
+              self.asyncOnMainQueue(progressHandler(progress))
             }
           }
           // Completion handler for model file download.
           let taskCompletion: ModelDownloadTask.Completion = { result in
             switch result {
             case let .success(model):
-              self.mainQueueHandler(completion(.success(model)))
+              self.asyncOnMainQueue(completion(.success(model)))
             case let .failure(error):
               switch error {
               case .notFound:
-                self.mainQueueHandler(completion(.failure(.notFound)))
+                self.asyncOnMainQueue(completion(.failure(.notFound)))
               case .invalidArgument:
-                self.mainQueueHandler(completion(.failure(.invalidArgument)))
+                self.asyncOnMainQueue(completion(.failure(.invalidArgument)))
               case .permissionDenied:
-                self.mainQueueHandler(completion(.failure(.permissionDenied)))
+                self.asyncOnMainQueue(completion(.failure(.permissionDenied)))
               // This is the error returned when model download URL has expired.
               case .expiredDownloadURL:
                 // Retry model info and model file download, if allowed.
                 guard self.numberOfRetries > 0 else {
                   self
-                    .mainQueueHandler(
+                    .asyncOnMainQueue(
                       completion(.failure(.internalError(description: ModelDownloader
                           .ErrorDescription
                           .expiredModelInfo)))
@@ -458,7 +463,7 @@ extension ModelDownloader {
                   completion: completion
                 )
               default:
-                self.mainQueueHandler(completion(.failure(error)))
+                self.asyncOnMainQueue(completion(.failure(error)))
               }
             }
             self.taskSerialQueue.async {
@@ -502,15 +507,15 @@ extension ModelDownloader {
           guard let localModel = self.getLocalModel(modelName: modelName) else {
             // This can only happen if either local model info or the model file was wiped out after model info request but before server response.
             self
-              .mainQueueHandler(completion(.failure(.internalError(description: ModelDownloader
+              .asyncOnMainQueue(completion(.failure(.internalError(description: ModelDownloader
                   .ErrorDescription.deletedLocalModelInfoOrFile))))
             return
           }
-          self.mainQueueHandler(completion(.success(localModel)))
+          self.asyncOnMainQueue(completion(.success(localModel)))
         }
       // Error retrieving model info.
       case let .failure(error):
-        self.mainQueueHandler(completion(.failure(error)))
+        self.asyncOnMainQueue(completion(.failure(error)))
       }
     }
   }
@@ -530,6 +535,8 @@ public enum DownloadError: Error, Equatable {
   case notEnoughSpace
   /// Malformed model name or Firebase app options.
   case invalidArgument
+  /// Model name is empty.
+  case emptyModelName
   /// Other errors with description.
   case internalError(description: String)
 }

+ 26 - 5
FirebaseMLModelDownloader/Sources/ModelFileManager.swift

@@ -27,11 +27,32 @@ enum ModelFileManager {
 
   /// Root directory of model file storage on device.
   static var modelsDirectory: URL? {
-    #if os(tvOS)
-      return fileManager.urls(for: .cachesDirectory, in: .userDomainMask).first
-    #else
-      return fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first ?? nil
-    #endif
+    let rootDirOptional: URL? = {
+      #if os(tvOS)
+        return fileManager.urls(for: .cachesDirectory, in: .userDomainMask).first
+      #else
+        return fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first ?? nil
+      #endif
+    }()
+
+    guard let rootDirURL = rootDirOptional else {
+      return nil
+    }
+
+    let modelDirURL = rootDirURL.appendingPathComponent(
+      "com.firebase.FirebaseMLModelDownloader",
+      isDirectory: true
+    )
+
+    do {
+      if !fileManager.fileExists(atPath: modelDirURL.absoluteString) {
+        try fileManager.createDirectory(at: modelDirURL, withIntermediateDirectories: true)
+      }
+    } catch {
+      return nil
+    }
+
+    return modelDirURL
   }
 
   /// Name for model file stored on device.

+ 54 - 0
FirebaseMLModelDownloader/Tests/Integration/ModelDownloaderIntegrationTests.swift

@@ -226,10 +226,14 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       downloadType: downloadType,
       conditions: conditions,
       progressHandler: { progress in
+        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
+
         XCTAssertLessThanOrEqual(progress, 1)
         XCTAssertGreaterThanOrEqual(progress, 0)
       }
     ) { result in
+      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
+
       switch result {
       case let .success(model):
         XCTAssertNotNil(model.path)
@@ -299,6 +303,48 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
     wait(for: [localModelExpectation], timeout: 5)
   }
 
+  func testGetModelWhenNameIsEmpty() {
+    guard let testApp = FirebaseApp.app() else {
+      XCTFail("Default app was not configured.")
+      return
+    }
+    let testName = String(#function.dropLast(2))
+    let emptyModelName = ""
+
+    let conditions = ModelDownloadConditions()
+    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
+      .createTestInstance(testName: testName),
+      app: testApp
+    )
+
+    let completionExpectation = expectation(description: "getModel")
+    let progressExpectation = expectation(description: "progressHandler")
+    progressExpectation.isInverted = true
+
+    modelDownloader.getModel(
+      name: emptyModelName,
+      downloadType: .latestModel,
+      conditions: conditions,
+      progressHandler: { progress in
+        progressExpectation.fulfill()
+      }
+    ) { result in
+      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
+
+      switch result {
+      case .failure(.emptyModelName):
+        // The expected error.
+        break
+
+      default:
+        XCTFail("Unexpected result: \(result)")
+      }
+      completionExpectation.fulfill()
+    }
+
+    wait(for: [completionExpectation, progressExpectation], timeout: 5)
+  }
+
   /// Delete previously downloaded model.
   func testDeleteModel() {
     guard let testApp = FirebaseApp.app() else {
@@ -323,10 +369,13 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       downloadType: downloadType,
       conditions: conditions,
       progressHandler: { progress in
+        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
         XCTAssertLessThanOrEqual(progress, 1)
         XCTAssertGreaterThanOrEqual(progress, 0)
       }
     ) { result in
+      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
+
       switch result {
       case let .success(model):
         XCTAssertNotNil(model.path)
@@ -343,6 +392,8 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
     let deleteExpectation = expectation(description: "Wait for model deletion.")
     modelDownloader.deleteDownloadedModel(name: testModelName) { result in
       deleteExpectation.fulfill()
+      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
+
       switch result {
       case .success: break
       case let .failure(error):
@@ -494,10 +545,13 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       downloadType: .latestModel,
       conditions: conditions,
       progressHandler: { progress in
+        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
         XCTAssertLessThanOrEqual(progress, 1)
         XCTAssertGreaterThanOrEqual(progress, 0)
       }
     ) { result in
+      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
+
       switch result {
       case let .success(model):
         XCTAssertNotNil(model.path)