Просмотр исходного кода

Merge duplicate requests (#7422)

* Restores ML Pods after M77.

* Fix Package.swift

* Re-add catalyst to GHA workflow.

* Merge duplicate model download requests.

* Remove print statement.

* WORK IN PROGRESS - Test for merging requests

* Unit tests for merging requests.

* Fix unit test description.

* Uncomment debug logging
manjanac 5 лет назад
Родитель
Сommit
ef3b4d9c19

+ 1 - 0
FirebaseMLModelDownloader/Sources/DeviceLogger.swift

@@ -55,6 +55,7 @@ enum LoggerMessageCode: Int {
   case modelInfoDeleted
   case modelInfoDownloaded
   case modelInfoUnmodified
+  case mergeRequests
   case authTokenError
   case expiredModelInfo
   case modelDownloadError

+ 41 - 11
FirebaseMLModelDownloader/Sources/ModelDownloadTask.swift

@@ -17,8 +17,8 @@ import FirebaseCore
 
 /// Possible states of model downloading.
 enum ModelDownloadStatus {
-  case notStarted
-  case inProgress
+  case ready
+  case downloading
   case complete
 }
 
@@ -43,20 +43,28 @@ class ModelDownloadTask: NSObject {
   /// User defaults to which local model info should ultimately be written.
   private let defaults: UserDefaults
   /// Keeps track of download associated with this model download task.
-  private(set) var downloadStatus: ModelDownloadStatus = .notStarted
+  private(set) var downloadStatus: ModelDownloadStatus = .ready
   /// Downloader instance.
   private let downloader: FileDownloader
   /// Telemetry logger.
   private let telemetryLogger: TelemetryLogger?
+  /// Progress handler.
+  private var progressHandler: ProgressHandler?
+  /// Completion.
+  private var completion: Completion
 
   init(remoteModelInfo: RemoteModelInfo,
        appName: String,
        defaults: UserDefaults,
        downloader: FileDownloader,
+       progressHandler: ProgressHandler? = nil,
+       completion: @escaping Completion,
        telemetryLogger: TelemetryLogger? = nil) {
     self.remoteModelInfo = remoteModelInfo
     self.appName = appName
     self.downloader = downloader
+    self.progressHandler = progressHandler
+    self.completion = completion
     self.telemetryLogger = telemetryLogger
     self.defaults = defaults
   }
@@ -68,28 +76,50 @@ extension ModelDownloadTask {
     return "fbml_model__\(appName)__\(remoteModelInfo.name).tflite"
   }
 
-  func download(progressHandler: ProgressHandler?, completion: @escaping Completion) {
+  /// Check if downloading is not complete for merging requests.
+  func canMergeRequests() -> Bool {
+    return downloadStatus != .complete
+  }
+
+  /// Check if download task can be resumed.
+  func canResume() -> Bool {
+    return downloadStatus == .ready
+  }
+
+  /// Merge duplicate requests. This method is not thread-safe.
+  func merge(newProgressHandler: ProgressHandler? = nil, newCompletion: @escaping Completion) {
+    let originalProgressHandler = progressHandler
+    progressHandler = { progress in
+      originalProgressHandler?(progress)
+      newProgressHandler?(progress)
+    }
+    let originalCompletion = completion
+    completion = { result in
+      originalCompletion(result)
+      newCompletion(result)
+    }
+  }
+
+  func resume() {
     /// Prevent multiple concurrent downloads.
-    guard downloadStatus != .inProgress else {
+    guard downloadStatus != .downloading else {
       DeviceLogger.logEvent(level: .debug,
                             message: ModelDownloadTask.ErrorDescription.anotherDownloadInProgress,
                             messageCode: .anotherDownloadInProgressError)
       telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                              status: .failed,
                                              downloadErrorCode: .downloadFailed)
-      completion(.failure(.internalError(description: ModelDownloadTask.ErrorDescription
-          .anotherDownloadInProgress)))
       return
     }
+    downloadStatus = .downloading
     telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
                                            status: .downloading,
                                            downloadErrorCode: .noError)
-    downloadStatus = .inProgress
     downloader.downloadFile(with: remoteModelInfo.downloadURL,
                             progressHandler: { downloadedBytes, totalBytes in
                               /// Fraction of model file downloaded.
                               let calculatedProgress = Float(downloadedBytes) / Float(totalBytes)
-                              progressHandler?(calculatedProgress)
+                              self.progressHandler?(calculatedProgress)
                             }) { result in
       self.downloadStatus = .complete
       switch result {
@@ -101,7 +131,7 @@ extension ModelDownloadTask {
         self.handleResponse(
           response: response.urlResponse,
           tempURL: response.fileURL,
-          completion: completion
+          completion: self.completion
         )
       case let .failure(error):
         var downloadError: DownloadError
@@ -141,7 +171,7 @@ extension ModelDownloadTask {
             downloadErrorCode: .downloadFailed
           )
         }
-        completion(.failure(downloadError))
+        self.completion(.failure(downloadError))
       }
     }
   }

+ 45 - 11
FirebaseMLModelDownloader/Sources/ModelDownloader.swift

@@ -80,6 +80,11 @@ public class ModelDownloader {
     }
   }
 
+  /// Download task associated with the model currently being downloaded.
+  private var currentDownloadTask: [String: ModelDownloadTask] = [:]
+  /// DispatchQueue to manage download task dictionary.
+  let taskSerialQueue = DispatchQueue(label: "downloadtask.serial.queue")
+
   /// Shared dictionary mapping app name to a specific instance of model downloader.
   // TODO: Switch to using Firebase components.
   private static var modelDownloaderDictionary: [String: ModelDownloader] = [:]
@@ -359,7 +364,6 @@ extension ModelDownloader {
       localModelInfo: localModelInfo,
       telemetryLogger: telemetryLogger
     )
-
     let downloader = ModelFileDownloader(conditions: conditions)
     downloadInfoAndModel(
       modelName: modelName,
@@ -378,25 +382,18 @@ extension ModelDownloader {
                             progressHandler: ((Float) -> Void)? = nil,
                             completion: @escaping (Result<CustomModel, DownloadError>)
                               -> Void) {
-    // TODO: Merge if there are multiple same requests.
     modelInfoRetriever.downloadModelInfo { result in
       switch result {
       case let .success(downloadModelInfoResult):
         switch downloadModelInfoResult {
         /// New model info was downloaded from server.
         case let .modelInfo(remoteModelInfo):
-          let downloadTask = ModelDownloadTask(
-            remoteModelInfo: remoteModelInfo,
-            appName: self.appName,
-            defaults: self.userDefaults,
-            downloader: downloader,
-            telemetryLogger: self.telemetryLogger
-          )
-          downloadTask.download(progressHandler: { progress in
+          let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
             if let progressHandler = progressHandler {
               self.mainQueueHandler(progressHandler(progress))
             }
-          }) { result in
+          }
+          let taskCompletion: ModelDownloadTask.Completion = { result in
             switch result {
             case let .success(model):
               self.mainQueueHandler(completion(.success(model)))
@@ -436,6 +433,41 @@ extension ModelDownloader {
                 self.mainQueueHandler(completion(.failure(error)))
               }
             }
+
+            self.taskSerialQueue.async {
+              /// Stop keeping track of current download task.
+              self.currentDownloadTask.removeValue(forKey: modelName)
+            }
+          }
+
+          self.taskSerialQueue.sync {
+            /// Merge duplicate requests if there is already a download in progress for the same model.
+            if let downloadTask = self.currentDownloadTask[modelName],
+              downloadTask.canMergeRequests() {
+              downloadTask.merge(
+                newProgressHandler: taskProgressHandler,
+                newCompletion: taskCompletion
+              )
+              DeviceLogger.logEvent(level: .debug,
+                                    message: ModelDownloader.DebugDescription.mergingRequests,
+                                    messageCode: .mergeRequests)
+              if downloadTask.canResume() {
+                downloadTask.resume()
+              }
+            } else {
+              let downloadTask = ModelDownloadTask(
+                remoteModelInfo: remoteModelInfo,
+                appName: self.appName,
+                defaults: self.userDefaults,
+                downloader: downloader,
+                progressHandler: taskProgressHandler,
+                completion: taskCompletion,
+                telemetryLogger: self.telemetryLogger
+              )
+              /// Keep track of current download task to allow for merging duplicate requests.
+              self.currentDownloadTask[modelName] = downloadTask
+              downloadTask.resume()
+            }
           }
         /// Local model info is the latest model info.
         case .notModified:
@@ -490,6 +522,8 @@ extension ModelDownloader {
     static let noLocalModelInfo = { (name: String) in
       "No local model info for model file named: \(name)."
     }
+
+    static let mergingRequests = "Merging duplicate download requests."
   }
 
   /// Error descriptions.

+ 1 - 24
FirebaseMLModelDownloader/Sources/ModelInfoRetriever.swift

@@ -16,13 +16,6 @@ import Foundation
 import FirebaseCore
 import FirebaseInstallations
 
-/// Possible states of model downloading.
-enum ModelInfoDownloadStatus {
-  case notStarted
-  case inProgress
-  case complete
-}
-
 /// URL Session to use while retrieving model info.
 protocol ModelInfoRetrieverSession {
   func getModelInfo(with request: URLRequest,
@@ -83,13 +76,11 @@ class ModelInfoRetriever {
   private let apiKey: String
   /// Current Firebase app name.
   private let appName: String
-  /// Keeps track of download associated with this model download task.
-  private(set) var downloadStatus: ModelInfoDownloadStatus = .notStarted
   /// Local model info to validate model freshness.
   private let localModelInfo: LocalModelInfo?
   /// Telemetry logger.
   private let telemetryLogger: TelemetryLogger?
-
+  /// Auth token provider.
   typealias AuthTokenProvider = (_ completion: @escaping (Result<String, DownloadError>) -> Void)
     -> Void
   private let authTokenProvider: AuthTokenProvider
@@ -153,18 +144,6 @@ class ModelInfoRetriever {
   /// Get model info from server.
   func downloadModelInfo(completion: @escaping (Result<DownloadModelInfoResult, DownloadError>)
     -> Void) {
-    /// Prevent multiple concurrent downloads.
-    guard downloadStatus != .inProgress else {
-      DeviceLogger.logEvent(level: .debug,
-                            message: ModelInfoRetriever.ErrorDescription.anotherDownloadInProgress,
-                            messageCode: .anotherDownloadInProgressError)
-      telemetryLogger?.logModelDownloadEvent(eventName: .modelDownload,
-                                             status: .failed,
-                                             downloadErrorCode: .downloadFailed)
-      completion(.failure(.internalError(description: ModelInfoRetriever.ErrorDescription
-          .anotherDownloadInProgress)))
-      return
-    }
     authTokenProvider { result in
       switch result {
       /// Successfully received FIS token.
@@ -186,11 +165,9 @@ class ModelInfoRetriever {
               .invalidModelInfoFetchURL)))
           return
         }
-        self.downloadStatus = .inProgress
         /// Download model info.
         self.session.getModelInfo(with: request) {
           data, response, error in
-          self.downloadStatus = .complete
           if let downloadError = error {
             let description = ModelInfoRetriever.ErrorDescription
               .failedModelInfoRetrieval(downloadError.localizedDescription)

+ 13 - 9
FirebaseMLModelDownloader/Tests/Integration/ModelDownloaderIntegrationTests.swift

@@ -163,17 +163,11 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
     let conditions = ModelDownloadConditions()
     let expectation = self.expectation(description: "Wait for model to download.")
     let downloader = ModelFileDownloader(conditions: conditions)
-    let modelDownloadManager = ModelDownloadTask(
-      remoteModelInfo: remoteModelInfo,
-      appName: testApp.name,
-      defaults: .createTestInstance(testName: #function),
-      downloader: downloader
-    )
-
-    modelDownloadManager.download(progressHandler: { progress in
+    let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
       XCTAssertLessThanOrEqual(progress, 1)
       XCTAssertGreaterThanOrEqual(progress, 0)
-    }) { result in
+    }
+    let taskCompletion: ModelDownloadTask.Completion = { result in
       switch result {
       case let .success(model):
         guard let modelPath = URL(string: model.path) else {
@@ -192,6 +186,16 @@ final class ModelDownloaderIntegrationTests: XCTestCase {
       }
       expectation.fulfill()
     }
+    let modelDownloadManager = ModelDownloadTask(
+      remoteModelInfo: remoteModelInfo,
+      appName: testApp.name,
+      defaults: .createTestInstance(testName: #function),
+      downloader: downloader,
+      progressHandler: taskProgressHandler,
+      completion: taskCompletion
+    )
+
+    modelDownloadManager.resume()
     waitForExpectations(timeout: 5, handler: nil)
     XCTAssertEqual(modelDownloadManager.downloadStatus, .complete)
   }

+ 160 - 24
FirebaseMLModelDownloader/Tests/Unit/ModelDownloaderUnitTests.swift

@@ -62,6 +62,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
   }
 
   override class func tearDown() {
+    try? FileManager.default.removeItem(atPath: NSTemporaryDirectory())
     FirebaseApp.app()?.delete { _ in }
   }
 
@@ -355,16 +356,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
     progressExpectation.expectedFulfillmentCount = 2
 
     let completionExpectation = expectation(description: "Completion handler")
-    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
-                                              appName: "fakeAppName",
-                                              defaults: .createTestInstance(testName: #function),
-                                              downloader: downloader)
-
-    modelDownloadTask.download(progressHandler: {
+    let taskProgressHandler: ModelDownloadTask.ProgressHandler = {
       progress in
       progressExpectation.fulfill()
       XCTAssertEqual(progress, 0.4)
-    }) { result in
+    }
+    let taskCompletion: ModelDownloadTask.Completion = { result in
       completionExpectation.fulfill()
       switch result {
       case let .success(model):
@@ -378,6 +375,14 @@ final class ModelDownloaderUnitTests: XCTestCase {
         XCTFail("Error - \(error)")
       }
     }
+    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
+                                              appName: "fakeAppName",
+                                              defaults: .createTestInstance(testName: #function),
+                                              downloader: downloader,
+                                              progressHandler: taskProgressHandler,
+                                              completion: taskCompletion)
+
+    modelDownloadTask.resume()
     wait(for: [fileDownloaderExpectation], timeout: 0.1)
 
     downloader.progressHandler?(100, 250)
@@ -408,11 +413,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
     }
 
     let completionExpectation = expectation(description: "Completion handler")
-    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
-                                              appName: "fakeAppName",
-                                              defaults: .createTestInstance(testName: #function),
-                                              downloader: downloader)
-    modelDownloadTask.download(progressHandler: nil) { result in
+    let taskCompletion: ModelDownloadTask.Completion = { result in
       completionExpectation.fulfill()
       switch result {
       case .success: XCTFail("Unexpected successful model download.")
@@ -420,6 +421,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
         XCTAssertEqual(error, .expiredDownloadURL)
       }
     }
+    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
+                                              appName: "fakeAppName",
+                                              defaults: .createTestInstance(testName: #function),
+                                              downloader: downloader,
+                                              completion: taskCompletion)
+    modelDownloadTask.resume()
     wait(for: [fileDownloaderExpectation], timeout: 0.1)
 
     downloader
@@ -446,11 +453,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
     }
 
     let completionExpectation = expectation(description: "Completion handler")
-    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
-                                              appName: "fakeAppName",
-                                              defaults: .createTestInstance(testName: #function),
-                                              downloader: downloader)
-    modelDownloadTask.download(progressHandler: nil) { result in
+    let taskCompletion: ModelDownloadTask.Completion = { result in
       completionExpectation.fulfill()
       switch result {
       case .success: XCTFail("Unexpected successful model download.")
@@ -458,6 +461,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
         XCTAssertEqual(error, .invalidArgument)
       }
     }
+    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
+                                              appName: "fakeAppName",
+                                              defaults: .createTestInstance(testName: #function),
+                                              downloader: downloader,
+                                              completion: taskCompletion)
+    modelDownloadTask.resume()
     wait(for: [fileDownloaderExpectation], timeout: 0.1)
 
     downloader
@@ -468,27 +477,148 @@ final class ModelDownloaderUnitTests: XCTestCase {
     try? ModelFileManager.removeFile(at: tempFileURL)
   }
 
-  /// Get model if server returns an error due to model not found.
-  func testModelDownloadWith404() {
+  /// Get model if multiple duplicate requests are made.
+  func testModelDownloadWithMergeRequests() {
     let tempFileURL = tempFile()
+    let numberOfRequests = 5
     let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
     let fakeResponse = HTTPURLResponse(url: fakeFileURL,
-                                       statusCode: 404,
+                                       statusCode: 200,
                                        httpVersion: nil,
                                        headerFields: nil)!
     let downloader = MockModelFileDownloader()
     let fileDownloaderExpectation = expectation(description: "File downloader")
+
     downloader.downloadFileHandler = { url in
       fileDownloaderExpectation.fulfill()
       XCTAssertEqual(url, self.fakeDownloadURL)
     }
 
     let completionExpectation = expectation(description: "Completion handler")
+    completionExpectation.expectedFulfillmentCount = numberOfRequests + 1
+
+    let taskCompletion: ModelDownloadTask.Completion = { result in
+      completionExpectation.fulfill()
+      switch result {
+      case let .success(model):
+        XCTAssertEqual(model.name, self.fakeModelName)
+        XCTAssertEqual(model.size, self.fakeModelSize)
+        XCTAssertEqual(model.hash, self.fakeModelHash)
+        let modelPath = URL(string: model.path)!
+        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
+      case let .failure(error):
+        XCTFail("Error - \(error)")
+      }
+    }
+
     let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                               appName: "fakeAppName",
                                               defaults: .createTestInstance(testName: #function),
-                                              downloader: downloader)
-    modelDownloadTask.download(progressHandler: nil) { result in
+                                              downloader: downloader,
+                                              completion: taskCompletion)
+
+    for i in 1 ... numberOfRequests {
+      let newTaskCompletion: ModelDownloadTask.Completion = { result in
+        completionExpectation.fulfill()
+        switch result {
+        case let .success(model):
+          XCTAssertEqual(model.name, self.fakeModelName)
+          XCTAssertEqual(model.size, self.fakeModelSize)
+          XCTAssertEqual(model.hash, self.fakeModelHash)
+          let modelPath = URL(string: model.path)!
+          XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
+          if i == numberOfRequests {
+            try? self.deleteFile(at: modelPath)
+          }
+        case let .failure(error):
+          XCTFail("Error - \(error)")
+        }
+      }
+      modelDownloadTask.resume()
+      modelDownloadTask.merge(newCompletion: newTaskCompletion)
+    }
+
+    downloader
+      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponse,
+                                                   fileURL: tempFileURL)))
+    wait(for: [fileDownloaderExpectation], timeout: 1)
+    wait(for: [completionExpectation], timeout: 2)
+
+    try? ModelFileManager.removeFile(at: tempFileURL)
+  }
+
+  /// Get model if multiple duplicate requests are made (test at the API surface).
+  func testGetModelWithMergeRequests() {
+    let modelDownloader = ModelDownloader.modelDownloader()
+    let conditions = ModelDownloadConditions()
+    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
+    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
+    let tempFileURL = tempFile()
+    let numberOfRequests = 6
+    let fakeResponse = HTTPURLResponse(url: fakeFileURL,
+                                       statusCode: 200,
+                                       httpVersion: nil,
+                                       headerFields: nil)!
+    let downloader = MockModelFileDownloader()
+    let fileDownloaderExpectation = expectation(description: "File downloader")
+
+    downloader.downloadFileHandler = { url in
+      fileDownloaderExpectation.fulfill()
+      XCTAssertEqual(url, self.fakeDownloadURL)
+    }
+
+    let completionExpectation = expectation(description: "Completion handler")
+    completionExpectation.expectedFulfillmentCount = numberOfRequests
+
+    for i in 1 ... numberOfRequests {
+      modelDownloader.downloadInfoAndModel(modelName: "fakeModelName",
+                                           modelInfoRetriever: modelInfoRetriever,
+                                           downloader: downloader,
+                                           conditions: conditions) { result in
+        completionExpectation.fulfill()
+        switch result {
+        case let .success(model):
+          XCTAssertEqual(model.name, self.fakeModelName)
+          XCTAssertEqual(model.size, self.fakeModelSize)
+          XCTAssertEqual(model.hash, self.fakeModelHash)
+          let modelPath = URL(string: model.path)!
+          XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
+          if i == numberOfRequests {
+            try? self.deleteFile(at: modelPath)
+          }
+        case let .failure(error):
+          XCTFail("Error - \(error)")
+        }
+      }
+    }
+    wait(for: [fileDownloaderExpectation], timeout: 0.5)
+
+    downloader
+      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponse,
+                                                   fileURL: tempFileURL)))
+
+    wait(for: [completionExpectation], timeout: 1)
+
+    try? ModelFileManager.removeFile(at: tempFileURL)
+  }
+
+  /// Get model if server returns an error due to model not found.
+  func testModelDownloadWith404() {
+    let tempFileURL = tempFile()
+    let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
+    let fakeResponse = HTTPURLResponse(url: fakeFileURL,
+                                       statusCode: 404,
+                                       httpVersion: nil,
+                                       headerFields: nil)!
+    let downloader = MockModelFileDownloader()
+    let fileDownloaderExpectation = expectation(description: "File downloader")
+    downloader.downloadFileHandler = { url in
+      fileDownloaderExpectation.fulfill()
+      XCTAssertEqual(url, self.fakeDownloadURL)
+    }
+
+    let completionExpectation = expectation(description: "Completion handler")
+    let taskCompletion: ModelDownloadTask.Completion = { result in
       completionExpectation.fulfill()
       switch result {
       case .success: XCTFail("Unexpected successful model download.")
@@ -496,6 +626,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
         XCTAssertEqual(error, .notFound)
       }
     }
+    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
+                                              appName: "fakeAppName",
+                                              defaults: .createTestInstance(testName: #function),
+                                              downloader: downloader,
+                                              completion: taskCompletion)
+    modelDownloadTask.resume()
     wait(for: [fileDownloaderExpectation], timeout: 0.1)
 
     downloader
@@ -948,10 +1084,10 @@ extension ModelDownloaderUnitTests {
     return modelInfoRetriever
   }
 
-  func tempFile() -> URL {
+  func tempFile(name: String = "fake-model-file.tmp") -> URL {
     let tempDirectoryURL = URL(fileURLWithPath: NSTemporaryDirectory(),
                                isDirectory: true)
-    let tempFileURL = tempDirectoryURL.appendingPathComponent("fake-model-file.tmp")
+    let tempFileURL = tempDirectoryURL.appendingPathComponent(name)
     let tempData: Data = "fakeModelData".data(using: .utf8)!
     try? tempData.write(to: tempFileURL)
     return tempFileURL