| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588 |
- // Copyright 2021 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // Skip keychain tests on Catalyst and macOS. Tests are skipped because they
- // involve interactions with the keychain that require a provisioning profile.
- // This is because MLModelDownloader depends on Installations.
- // See go/firebase-macos-keychain-popups for more details.
- #if !targetEnvironment(macCatalyst) && !os(macOS)
- import XCTest
- @testable import FirebaseCore
- @testable import FirebaseInstallations
- @testable import FirebaseMLModelDownloader
extension UserDefaults {
  /// Builds the suite name that isolates the defaults used by the test named `testName`.
  private static func testSuiteName(for testName: String) -> String {
    "com.google.firebase.ml.test.\(testName)"
  }

  /// Opens the defaults suite for `testName` and wipes anything previously persisted in it,
  /// so the caller starts from an empty store.
  static func createTestInstance(testName: String) -> UserDefaults {
    let suite = testSuiteName(for: testName)
    let freshDefaults = UserDefaults(suiteName: suite)!
    freshDefaults.removePersistentDomain(forName: suite)
    return freshDefaults
  }

  /// Opens the defaults suite for `testName` without clearing it, so values written
  /// earlier (e.g. via `createTestInstance(testName:)`) remain readable.
  static func getTestInstance(testName: String) -> UserDefaults {
    UserDefaults(suiteName: testSuiteName(for: testName))!
  }
}
final class ModelDownloaderIntegrationTests: XCTestCase {
  /// Maximum time, in seconds, that each network-backed step may take.
  private static let networkTimeout: TimeInterval = 5

  /// Configures the default `FirebaseApp` once for the whole test case from the
  /// `GoogleService-Info.plist` bundled with the test target, and enables debug logging.
  override class func setUp() {
    super.setUp()
    // TODO: Use FirebaseApp internal for test app.
    let bundle = Bundle(for: self)
    if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
       let options = FirebaseOptions(contentsOfFile: plistPath) {
      FirebaseApp.configure(options: options)
    } else {
      XCTFail("Could not locate GoogleService-Info.plist.")
    }
    FirebaseConfiguration.shared.setLoggerLevel(.debug)
  }

  /// Empties the on-device models directory so every test starts without downloaded models.
  override func setUp() {
    super.setUp()
    do {
      try ModelFileManager.emptyModelsDirectory()
    } catch {
      XCTFail("Could not empty models directory - \(error)")
    }
  }

  // MARK: - Helpers

  /// Returns the default app configured in `setUp()`, with data collection set as requested.
  ///
  /// - Parameter dataCollectionEnabled: Value assigned to `isDataCollectionDefaultEnabled`.
  /// - Throws: The `XCTUnwrap` failure (attributed to the caller's line) when no default
  ///   app has been configured.
  private func defaultTestApp(dataCollectionEnabled: Bool = false,
                              file: StaticString = #file,
                              line: UInt = #line) throws -> FirebaseApp {
    let app = try XCTUnwrap(
      FirebaseApp.app(),
      "Default app was not configured.",
      file: file,
      line: line
    )
    app.isDataCollectionDefaultEnabled = dataCollectionEnabled
    return app
  }

  /// Asserts that `path` is non-empty and names a file `ModelFileManager` reports as reachable.
  private func assertModelFileReachable(atPath path: String,
                                        file: StaticString = #file,
                                        line: UInt = #line) {
    XCTAssertFalse(path.isEmpty, "Model path is empty.", file: file, line: line)
    XCTAssertTrue(
      ModelFileManager.isFileReachable(at: URL(fileURLWithPath: path)),
      "Model file is not reachable at \(path).",
      file: file,
      line: line
    )
  }

  // MARK: - Model info

  /// Test to download model info - makes an actual network call.
  ///
  /// The fetched info is written to test-scoped defaults and then passed to
  /// `testRetrieveModelInfo(localInfo:)`, which expects a `.notModified` response.
  func testDownloadModelInfo() throws {
    let testApp = try defaultTestApp()
    let testName = String(#function.dropLast(2))
    let testModelName = "pose-detection"
    let projectID = try XCTUnwrap(testApp.options.projectID)
    let apiKey = try XCTUnwrap(testApp.options.apiKey)
    let modelInfoRetriever = ModelInfoRetriever(
      modelName: testModelName,
      projectID: projectID,
      apiKey: apiKey,
      appName: testApp.name,
      installations: Installations.installations(app: testApp)
    )
    let modelInfoDownloadExpectation =
      expectation(description: "Wait for model info to download.")
    modelInfoRetriever.downloadModelInfo(completion: { result in
      switch result {
      case let .success(modelInfoResult):
        switch modelInfoResult {
        case let .modelInfo(modelInfo):
          XCTAssertNotNil(modelInfo.urlExpiryTime)
          XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
          XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
          XCTAssertGreaterThan(modelInfo.size, 0)
          // Persist the info so the follow-up request can present it as the local copy.
          let localModelInfo = LocalModelInfo(from: modelInfo)
          localModelInfo.writeToDefaults(
            .createTestInstance(testName: testName),
            appName: testApp.name
          )
        case .notModified:
          XCTFail("Failed to retrieve model info.")
        }
      case let .failure(error):
        XCTFail("Failed to retrieve model info - \(error)")
      }
      modelInfoDownloadExpectation.fulfill()
    })
    wait(for: [modelInfoDownloadExpectation], timeout: Self.networkTimeout)

    let localInfo = try XCTUnwrap(
      LocalModelInfo(
        fromDefaults: .getTestInstance(testName: testName),
        name: testModelName,
        appName: testApp.name
      ),
      "Could not save model info locally."
    )
    testRetrieveModelInfo(localInfo: localInfo)
  }

  /// Requests model info again, supplying `localInfo` as the locally stored copy, and
  /// expects a `.notModified` result. Invoked from `testDownloadModelInfo()`; because it
  /// takes a parameter, XCTest does not run it as a standalone test.
  func testRetrieveModelInfo(localInfo: LocalModelInfo) {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    guard let projectID = testApp.options.projectID,
          let apiKey = testApp.options.apiKey else {
      XCTFail("Default app is missing a project ID or API key.")
      return
    }
    let testModelName = "pose-detection"
    let modelInfoRetriever = ModelInfoRetriever(
      modelName: testModelName,
      projectID: projectID,
      apiKey: apiKey,
      appName: testApp.name,
      installations: Installations.installations(app: testApp),
      localModelInfo: localInfo
    )
    let modelInfoRetrieveExpectation =
      expectation(description: "Wait for model info to be retrieved.")
    modelInfoRetriever.downloadModelInfo(completion: { result in
      switch result {
      case let .success(modelInfoResult):
        switch modelInfoResult {
        case .modelInfo:
          XCTFail("Local model info is already the latest and should not be set again.")
        case .notModified:
          break
        }
      case let .failure(error):
        XCTFail("Failed to retrieve model info - \(error)")
      }
      modelInfoRetrieveExpectation.fulfill()
    })
    wait(for: [modelInfoRetrieveExpectation], timeout: Self.networkTimeout)
  }

  // MARK: - Model download

  /// Test to download model file - makes an actual network call.
  func testModelDownload() throws {
    let testApp = try defaultTestApp()
    let testName = String(#function.dropLast(2))
    let testModelName = "\(testName)-test-model"
    let urlString =
      "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
    let url = try XCTUnwrap(URL(string: urlString))
    let remoteModelInfo = RemoteModelInfo(
      name: testModelName,
      downloadURL: url,
      modelHash: "mock-valid-hash",
      size: 10,
      urlExpiryTime: Date()
    )
    let conditions = ModelDownloadConditions()
    let downloadExpectation = expectation(description: "Wait for model to download.")
    let downloader = ModelFileDownloader(conditions: conditions)
    let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
      XCTAssertLessThanOrEqual(progress, 1)
      XCTAssertGreaterThanOrEqual(progress, 0)
    }
    let taskCompletion: ModelDownloadTask.Completion = { result in
      switch result {
      case let .success(model):
        self.assertModelFileReachable(atPath: model.path)
        // Remove downloaded model file.
        do {
          try ModelFileManager.removeFile(at: URL(fileURLWithPath: model.path))
        } catch {
          XCTFail("Model removal failed - \(error)")
        }
      case let .failure(error):
        XCTFail("Error: \(error)")
      }
      downloadExpectation.fulfill()
    }
    let downloadTask = ModelDownloadTask(
      remoteModelInfo: remoteModelInfo,
      appName: testApp.name,
      defaults: .createTestInstance(testName: testName),
      downloader: downloader,
      progressHandler: taskProgressHandler,
      completion: taskCompletion
    )
    downloadTask.resume()
    wait(for: [downloadExpectation], timeout: Self.networkTimeout)
    XCTAssertEqual(downloadTask.downloadStatus, .complete)
  }

  /// Calls `getModel` with each download type in turn. `.latestModel` downloads the model;
  /// `.localModelUpdateInBackground` and `.localModel` must then return it without invoking
  /// the progress handler. The final step removes the downloaded file.
  func testGetModel() throws {
    let testApp = try defaultTestApp()
    let testName = String(#function.dropLast(2))
    let testModelName = "image-classification"
    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )

    // Download type - latest model.
    let latestModelExpectation = expectation(description: "Get latest model.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: .latestModel,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
      switch result {
      case let .success(model):
        self.assertModelFileReachable(atPath: model.path)
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }
    wait(for: [latestModelExpectation], timeout: Self.networkTimeout)

    // Download type - local model update in background.
    let backgroundModelExpectation =
      expectation(description: "Get local model and update in background.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: .localModelUpdateInBackground,
      conditions: conditions,
      progressHandler: { _ in
        XCTFail("Model is already available on device.")
      }
    ) { result in
      switch result {
      case let .success(model):
        self.assertModelFileReachable(atPath: model.path)
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      backgroundModelExpectation.fulfill()
    }
    wait(for: [backgroundModelExpectation], timeout: Self.networkTimeout)

    // Download type - local model.
    let localModelExpectation = expectation(description: "Get local model.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: .localModel,
      conditions: conditions,
      progressHandler: { _ in
        XCTFail("Model is already available on device.")
      }
    ) { result in
      switch result {
      case let .success(model):
        self.assertModelFileReachable(atPath: model.path)
        // Remove downloaded model file.
        do {
          try ModelFileManager.removeFile(at: URL(fileURLWithPath: model.path))
        } catch {
          XCTFail("Model removal failed - \(error)")
        }
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      localModelExpectation.fulfill()
    }
    wait(for: [localModelExpectation], timeout: Self.networkTimeout)
  }

  /// An empty model name must fail with `.emptyModelName` on the main thread, and the
  /// progress handler must never be called.
  func testGetModelWhenNameIsEmpty() throws {
    let testApp = try defaultTestApp()
    let testName = String(#function.dropLast(2))
    let emptyModelName = ""
    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )
    let completionExpectation = expectation(description: "getModel")
    // Inverted: the test fails if the progress handler fires at all.
    let progressExpectation = expectation(description: "progressHandler")
    progressExpectation.isInverted = true
    modelDownloader.getModel(
      name: emptyModelName,
      downloadType: .latestModel,
      conditions: conditions,
      progressHandler: { _ in
        progressExpectation.fulfill()
      }
    ) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
      switch result {
      case .failure(.emptyModelName):
        // The expected error.
        break
      default:
        XCTFail("Unexpected result: \(result)")
      }
      completionExpectation.fulfill()
    }
    wait(for: [completionExpectation, progressExpectation], timeout: Self.networkTimeout)
  }

  // MARK: - Downloaded model management

  /// Delete previously downloaded model.
  func testDeleteModel() throws {
    let testApp = try defaultTestApp()
    let testName = String(#function.dropLast(2))
    let testModelName = "pose-detection"
    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )
    let latestModelExpectation = expectation(description: "Get latest model for deletion.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: .latestModel,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
      switch result {
      case let .success(model):
        self.assertModelFileReachable(atPath: model.path)
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }
    wait(for: [latestModelExpectation], timeout: Self.networkTimeout)

    let deleteExpectation = expectation(description: "Wait for model deletion.")
    modelDownloader.deleteDownloadedModel(name: testModelName) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
      switch result {
      case .success:
        break
      case let .failure(error):
        XCTFail("Failed to delete model - \(error)")
      }
      // Fulfill only after all assertions have run.
      deleteExpectation.fulfill()
    }
    wait(for: [deleteExpectation], timeout: Self.networkTimeout)
  }

  /// Test listing models in model directory.
  func testListModels() throws {
    let testApp = try defaultTestApp()
    let testName = String(#function.dropLast(2))
    let testModelName = "pose-detection"
    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )
    let latestModelExpectation = expectation(description: "Get latest model.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: .latestModel,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      switch result {
      case let .success(model):
        self.assertModelFileReachable(atPath: model.path)
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }
    wait(for: [latestModelExpectation], timeout: Self.networkTimeout)

    let listExpectation = expectation(description: "Wait for list models.")
    modelDownloader.listDownloadedModels { result in
      switch result {
      case let .success(models):
        XCTAssertGreaterThan(models.count, 0)
      case let .failure(error):
        XCTFail("Failed to list models - \(error)")
      }
      // Fulfill only after all assertions have run.
      listExpectation.fulfill()
    }
    wait(for: [listExpectation], timeout: Self.networkTimeout)
  }

  // MARK: - Telemetry

  /// Test logging telemetry event.
  ///
  /// Runs a get, a delete, and a get of an unknown model (expected `.notFound`).
  func testLogTelemetryEvent() throws {
    // Flip this to `true` to test logging.
    let testApp = try defaultTestApp(dataCollectionEnabled: false)
    let testModelName = "digit-classification"
    let testName = String(#function.dropLast(2))
    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )

    let latestModelExpectation = expectation(description: "Test get model telemetry.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: .latestModel,
      conditions: conditions
    ) { result in
      switch result {
      case .success:
        break
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }
    wait(for: [latestModelExpectation], timeout: Self.networkTimeout)

    let deleteModelExpectation = expectation(description: "Test delete model telemetry.")
    modelDownloader.deleteDownloadedModel(name: testModelName) { result in
      switch result {
      case .success:
        break
      case let .failure(error):
        XCTFail("Failed to delete model - \(error)")
      }
      deleteModelExpectation.fulfill()
    }
    wait(for: [deleteModelExpectation], timeout: Self.networkTimeout)

    let notFoundModelName = "fakeModelName"
    let notFoundModelExpectation =
      expectation(description: "Test get model telemetry with unknown model.")
    modelDownloader.getModel(
      name: notFoundModelName,
      downloadType: .latestModel,
      conditions: conditions
    ) { result in
      switch result {
      case .success:
        XCTFail("Unexpected model success.")
      case let .failure(error):
        XCTAssertEqual(error, .notFound)
      }
      notFoundModelExpectation.fulfill()
    }
    wait(for: [notFoundModelExpectation], timeout: Self.networkTimeout)
  }

  /// Downloads the latest model with cellular access disallowed and checks that the file is
  /// reachable, with callbacks on the main thread.
  func testGetModelWithConditions() throws {
    let testApp = try defaultTestApp()
    let testModelName = "pose-detection"
    let testName = String(#function.dropLast(2))
    let conditions = ModelDownloadConditions(allowsCellularAccess: false)
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )
    let latestModelExpectation = expectation(description: "Get latest model with conditions.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: .latestModel,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
      switch result {
      case let .success(model):
        self.assertModelFileReachable(atPath: model.path)
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }
    wait(for: [latestModelExpectation], timeout: Self.networkTimeout)
  }
}
- #endif // !targetEnvironment(macCatalyst) && !os(macOS)
|