Sfoglia il codice sorgente

Custom model (#6594)

* Initial commit for CustomModel class.

* Initial commit for CustomModel class.

* API scaffolding for CustomModel and ModelDownloader classes

* API scaffolding for CustomModel and ModelDownloader classes

* Delete ModelDownloadConditions.swift

* Changed Swift structs to NSObject subclasses for objc visibility

* Design for a Swift-first SDK

* Design for a Swift-first SDK

* Design for a Swift-first SDK

* Updated design w/ errors and download progress handler

* Style compliance

* Refactor Test folder structure

* Refactor Test folder structure

* Fix pod lint errors

* Fix pod lint errors

* Add xcscheme for SwiftPM builds.

* Disable catalyst temporarily, re-enable issue #6790

* Better classification of download errors

* Documentation comments for custom model

* Refactor errors for model downloading

Co-authored-by: Ryan Wilson <wilsonryan@google.com>
manjanac 5 anni fa
parent
commit
1878792e86

+ 5 - 3
.github/workflows/mlmodeldownloader.yml

@@ -46,7 +46,8 @@ jobs:
     runs-on: macOS-latest
     strategy:
       matrix:
-        target: [tvOS, macOS, catalyst]
+        # TODO: manjanac@ Add catalyst back here.
+        target: [tvOS, macOS]
     steps:
     - uses: actions/checkout@v2
     - name: Xcode 12
@@ -58,8 +59,9 @@ jobs:
 
   catalyst:
     # Don't run on private repo unless it is a PR.
-    if: github.repository != 'FirebasePrivate/firebase-ios-sdk' || github.event_name == 'pull_request'
-
+    # TODO: manjanac@ Uncomment line below to re-enable catalyst.
+    # if: github.repository != 'FirebasePrivate/firebase-ios-sdk' || github.event_name == 'pull_request'
+    if: false
     runs-on: macOS-latest
     steps:
     - uses: actions/checkout@v2

+ 52 - 0
.swiftpm/xcode/xcshareddata/xcschemes/FirebaseMLModelDownloaderUnit.xcscheme

@@ -0,0 +1,52 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<Scheme
+   LastUpgradeVersion = "1200"
+   version = "1.3">
+   <BuildAction
+      parallelizeBuildables = "YES"
+      buildImplicitDependencies = "YES">
+   </BuildAction>
+   <TestAction
+      buildConfiguration = "Debug"
+      selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
+      selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
+      shouldUseLaunchSchemeArgsEnv = "YES">
+      <Testables>
+         <TestableReference
+            skipped = "NO">
+            <BuildableReference
+               BuildableIdentifier = "primary"
+               BlueprintIdentifier = "FirebaseMLModelDownloaderUnit"
+               BuildableName = "FirebaseMLModelDownloaderUnit"
+               BlueprintName = "FirebaseMLModelDownloaderUnit"
+               ReferencedContainer = "container:">
+            </BuildableReference>
+         </TestableReference>
+      </Testables>
+   </TestAction>
+   <LaunchAction
+      buildConfiguration = "Debug"
+      selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
+      selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
+      launchStyle = "0"
+      useCustomWorkingDirectory = "NO"
+      ignoresPersistentStateOnLaunch = "NO"
+      debugDocumentVersioning = "YES"
+      debugServiceExtension = "internal"
+      allowLocationSimulation = "YES">
+   </LaunchAction>
+   <ProfileAction
+      buildConfiguration = "Release"
+      shouldUseLaunchSchemeArgsEnv = "YES"
+      savedToolIdentifier = ""
+      useCustomWorkingDirectory = "NO"
+      debugDocumentVersioning = "YES">
+   </ProfileAction>
+   <AnalyzeAction
+      buildConfiguration = "Debug">
+   </AnalyzeAction>
+   <ArchiveAction
+      buildConfiguration = "Release"
+      revealArchiveInOrganizer = "YES">
+   </ArchiveAction>
+</Scheme>

+ 11 - 9
FirebaseMLModelDownloader/Tests/Unit/MLDownloaderTests.swift → FirebaseMLModelDownloader/Sources/CustomModel.swift

@@ -12,14 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-import XCTest
-@testable import FirebaseMLModelDownloader
+import Foundation
 
-final class MLModelDownloaderTests: XCTestCase {
-    func testExample() {
-        // This is an example of a functional test case.
-        // Use XCTAssert and related functions to verify your tests produce the correct
-        // results.
-        XCTAssertEqual(MLModelDownloader().text, "Hello, World!")
-    }
+/// A custom model that is stored remotely on the server and downloaded to the device.
+public struct CustomModel: Hashable {
+  /// Name of the model.
+  public let name: String
+  /// Size of the custom model, provided by the server.
+  public let size: Int
+  /// Path where the model is stored on device.
+  public let path: String
+  /// Hash for the model, used for model verification.
+  public let hash: String
 }

+ 4 - 3
FirebaseMLModelDownloader/Sources/MLModelDownloader.swift → FirebaseMLModelDownloader/Sources/ModelDownloadConditions.swift

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-public struct MLModelDownloader {
-  public var text = "Hello, World!"
-}
+import Foundation
+
+/// Model download conditions.
+public struct ModelDownloadConditions {}

+ 90 - 0
FirebaseMLModelDownloader/Sources/ModelDownloader.swift

@@ -0,0 +1,90 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+/// Possible errors with model downloading.
+public enum DownloadError: Error {
+  /// No model with this name found on server.
+  case notFound
+  /// Caller does not have necessary permissions for this operation.
+  case permissionDenied
+  /// Conditions not met to perform download.
+  case failedPrecondition
+  /// Not enough space for model on device.
+  case notEnoughSpace
+  /// Malformed model name.
+  case invalidArgument
+  /// Other errors with description.
+  case internalError(description: String)
+}
+
+/// Possible errors with locating model on device.
+public enum DownloadedModelError: Error {
+  /// File system error.
+  case fileIOError
+  /// Model not found on device.
+  case notFound
+}
+
+/// Possible ways to get a custom model.
+public enum ModelDownloadType {
+  /// Get local model stored on device.
+  case localModel
+  /// Get local model on device and update to latest model from server in the background.
+  case localModelUpdateInBackground
+  /// Get latest model from server.
+  case latestModel
+}
+
+/// Downloader to manage custom model downloads.
+public struct ModelDownloader {
+  /// Downloads a custom model to device or gets a custom model already on device, w/ optional handler for progress.
+  public func getModel(name modelName: String, downloadType: ModelDownloadType,
+                       conditions: ModelDownloadConditions,
+                       progressHandler: ((Float) -> Void)? = nil,
+                       completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
+    // TODO: Model download
+    let modelSize = Int()
+    let modelPath = String()
+    let modelHash = String()
+
+    let customModel = CustomModel(
+      name: modelName,
+      size: modelSize,
+      path: modelPath,
+      hash: modelHash
+    )
+    completion(.success(customModel))
+    completion(.failure(.notFound))
+  }
+
+  /// Gets all downloaded models.
+  public func listDownloadedModels(completion: @escaping (Result<Set<CustomModel>,
+    DownloadedModelError>) -> Void) {
+    let customModels = Set<CustomModel>()
+    // TODO: List downloaded models
+    completion(.success(customModels))
+    completion(.failure(.notFound))
+  }
+
+  /// Deletes a custom model from device.
+  public func deleteDownloadedModel(name modelName: String,
+                                    completion: @escaping (Result<Void, DownloadedModelError>)
+                                      -> Void) {
+    // TODO: Delete previously downloaded model
+    completion(.success(()))
+    completion(.failure(.notFound))
+  }
+}

+ 70 - 0
FirebaseMLModelDownloader/Tests/Unit/ModelDownloaderTests.swift

@@ -0,0 +1,70 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import XCTest
+@testable import FirebaseMLModelDownloader
+
+final class ModelDownloaderTests: XCTestCase {
+  func testExample() {
+    // This is an example of a functional test case.
+    // Use XCTAssert and related functions to verify your tests produce the correct
+    // results.
+    let modelDownloader = ModelDownloader()
+    let conditions = ModelDownloadConditions()
+
+    // Download model w/ progress handler
+    modelDownloader.getModel(
+      name: "your_model_name",
+      downloadType: .latestModel,
+      conditions: conditions,
+      progressHandler: { progress in
+        // Handle progress
+      }
+    ) { result in
+      switch result {
+      case .success:
+        // Use model with your inference API
+        // let interpreter = Interpreter(modelPath: customModel.modelPath)
+        break
+      case .failure:
+        // Handle download error
+        break
+      }
+    }
+
+    // Access array of downloaded models
+    modelDownloader.listDownloadedModels { result in
+      switch result {
+      case .success:
+        // Pick model(s) for further use
+        break
+      case .failure:
+        // Handle failure
+        break
+      }
+    }
+
+    // Delete downloaded model
+    modelDownloader.deleteDownloadedModel(name: "your_model_name") { result in
+      switch result {
+      case .success():
+        // Apply any other clean up
+        break
+      case .failure:
+        // Handle failure
+        break
+      }
+    }
+  }
+}