// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest
@testable import FirebaseCore
@testable import FirebaseInstallations
@testable import FirebaseMLModelDownloader

extension UserDefaults {
  /// Builds the per-test suite name.
  ///
  /// Shared by `createTestInstance` and `getTestInstance`, so a value written through one is
  /// always read back through the same defaults domain by the other.
  private static func testSuiteName(for testName: String) -> String {
    return "com.google.firebase.ml.test.\(testName)"
  }

  /// Returns a new cleared instance of user defaults.
  ///
  /// - Parameter testName: Name used to derive a per-test suite, keeping persisted state of
  ///   different tests apart.
  /// - Returns: A defaults instance whose persistent domain has just been removed.
  static func createTestInstance(testName: String) -> UserDefaults {
    let suiteName = testSuiteName(for: testName)
    let defaults = UserDefaults(suiteName: suiteName)!
    defaults.removePersistentDomain(forName: suiteName)
    return defaults
  }

  /// Returns the existing user defaults instance.
  ///
  /// - Parameter testName: Same name previously passed to `createTestInstance(testName:)`.
  /// - Returns: The defaults for that suite, without clearing anything.
  static func getTestInstance(testName: String) -> UserDefaults {
    return UserDefaults(suiteName: testSuiteName(for: testName))!
  }
}

final class ModelDownloaderIntegrationTests: XCTestCase {
  /// Configures the default `FirebaseApp` once for the whole test class.
  ///
  /// The app is configured from the `GoogleService-Info.plist` bundled with the test target,
  /// and Firebase logging is set to `.debug`.
  override class func setUp() {
    super.setUp()
    // TODO: Use FirebaseApp internal for test app.
    let bundle = Bundle(for: self)
    if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
       let options = FirebaseOptions(contentsOfFile: plistPath) {
      FirebaseApp.configure(options: options)
    } else {
      XCTFail("Could not locate GoogleService-Info.plist.")
    }
    FirebaseConfiguration.shared.setLoggerLevel(.debug)
  }

  /// Empties the models directory before each test.
  ///
  /// This ensures every test starts without model files left behind by a previous test.
  override func setUp() {
    // XCTestCase subclasses are expected to call through to super.
    super.setUp()
    do {
      try ModelFileManager.emptyModelsDirectory()
    } catch {
      XCTFail("Could not empty models directory - \(error)")
    }
  }

  /// Test to download model info - makes an actual network call.
  ///
  /// On success the retrieved info is:
  /// - written to per-test user defaults,
  /// - read back,
  /// - then handed to `testRetrieveModelInfo(localInfo:)`, which checks that a second request
  ///   reports the model as not modified.
  func testDownloadModelInfo() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    // `#function` ends in "()"; drop it so the name can be used in the defaults suite name.
    let testName = String(#function.dropLast(2))
    let testModelName = "pose-detection"

    let modelInfoRetriever = ModelInfoRetriever(
      modelName: testModelName,
      projectID: testApp.options.projectID!,
      apiKey: testApp.options.apiKey!,
      appName: testApp.name,
      installations: Installations.installations(app: testApp)
    )

    let modelInfoDownloadExpectation =
      expectation(description: "Wait for model info to download.")
    modelInfoRetriever.downloadModelInfo(completion: { result in
      switch result {
      case let .success(modelInfoResult):
        switch modelInfoResult {
        case let .modelInfo(modelInfo):
          XCTAssertNotNil(modelInfo.urlExpiryTime)
          XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
          XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
          XCTAssertGreaterThan(modelInfo.size, 0)
          let localModelInfo = LocalModelInfo(from: modelInfo)
          localModelInfo.writeToDefaults(
            .createTestInstance(testName: testName),
            appName: testApp.name
          )
        case .notModified:
          XCTFail("Failed to retrieve model info.")
        }
      case let .failure(error):
        XCTFail("Failed to retrieve model info - \(error)")
      }
      modelInfoDownloadExpectation.fulfill()
    })

    wait(for: [modelInfoDownloadExpectation], timeout: 5)

    guard let localInfo = LocalModelInfo(
      fromDefaults: .getTestInstance(testName: testName),
      name: testModelName,
      appName: testApp.name
    ) else {
      XCTFail("Could not save model info locally.")
      return
    }
    testRetrieveModelInfo(localInfo: localInfo)
  }

  /// Requests model info again with `localInfo` attached and expects `.notModified`.
  ///
  /// Invoked from `testDownloadModelInfo()`. Because it takes a parameter, XCTest does not
  /// discover it as a standalone test.
  ///
  /// - Parameter localInfo: Model info previously persisted by `testDownloadModelInfo()`.
  func testRetrieveModelInfo(localInfo: LocalModelInfo) {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }

    let testModelName = "pose-detection"

    let modelInfoRetriever = ModelInfoRetriever(
      modelName: testModelName,
      projectID: testApp.options.projectID!,
      apiKey: testApp.options.apiKey!,
      appName: testApp.name,
      installations: Installations.installations(app: testApp),
      localModelInfo: localInfo
    )

    let modelInfoRetrieveExpectation =
      expectation(description: "Wait for model info to be retrieved.")
    modelInfoRetriever.downloadModelInfo(completion: { result in
      switch result {
      case let .success(modelInfoResult):
        switch modelInfoResult {
        case .modelInfo:
          XCTFail("Local model info is already the latest and should not be set again.")
        case .notModified:
          break
        }
      case let .failure(error):
        XCTFail("Failed to retrieve model info - \(error)")
      }
      modelInfoRetrieveExpectation.fulfill()
    })

    wait(for: [modelInfoRetrieveExpectation], timeout: 5)
  }

  /// Test to download model file - makes an actual network call.
  ///
  /// Downloads a model from a fixed TF Hub URL through `ModelDownloadTask`. It then:
  /// - checks that the file is reachable on disk,
  /// - checks that the task reports `.complete`,
  /// - removes the downloaded file.
  func testModelDownload() throws {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let testName = String(#function.dropLast(2))
    let testModelName = "\(testName)-test-model"

    let urlString =
      "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
    let url = URL(string: urlString)!

    // Hash, size and expiry are placeholders.
    // NOTE(review): assumes the file download path does not validate them against the
    // downloaded bytes — confirm.
    let remoteModelInfo = RemoteModelInfo(
      name: testModelName,
      downloadURL: url,
      modelHash: "mock-valid-hash",
      size: 10,
      urlExpiryTime: Date()
    )

    let conditions = ModelDownloadConditions()
    let downloadExpectation = expectation(description: "Wait for model to download.")
    let downloader = ModelFileDownloader(conditions: conditions)

    let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
      XCTAssertLessThanOrEqual(progress, 1)
      XCTAssertGreaterThanOrEqual(progress, 0)
    }
    let taskCompletion: ModelDownloadTask.Completion = { result in
      switch result {
      case let .success(model):
        let modelURL = URL(fileURLWithPath: model.path)
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
        // Remove downloaded model file.
        do {
          try ModelFileManager.removeFile(at: modelURL)
        } catch {
          XCTFail("Model removal failed - \(error)")
        }
      case let .failure(error):
        XCTFail("Error: \(error)")
      }
      downloadExpectation.fulfill()
    }

    let downloadTask = ModelDownloadTask(
      remoteModelInfo: remoteModelInfo,
      appName: testApp.name,
      defaults: .createTestInstance(testName: testName),
      downloader: downloader,
      progressHandler: taskProgressHandler,
      completion: taskCompletion
    )

    downloadTask.resume()
    wait(for: [downloadExpectation], timeout: 5)
    XCTAssertEqual(downloadTask.downloadStatus, .complete)
  }

  /// Exercises `getModel` with each download type, in sequence, on the same model.
  ///
  /// The sequence is `.latestModel`, then `.localModelUpdateInBackground`, then `.localModel`.
  /// The later two expect the file fetched by the first, so their progress handlers must never
  /// fire. The downloaded file is removed at the end.
  func testGetModel() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let testName = String(#function.dropLast(2))
    let testModelName = "image-classification"

    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )

    /// Test download type - latest model.
    var downloadType: ModelDownloadType = .latestModel
    let latestModelExpectation = expectation(description: "Get latest model.")

    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        let modelURL = URL(fileURLWithPath: model.path)
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }

    wait(for: [latestModelExpectation], timeout: 5)

    /// Test download type - local model update in background.
    downloadType = .localModelUpdateInBackground
    let backgroundModelExpectation =
      expectation(description: "Get local model and update in background.")

    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { _ in
        XCTFail("Model is already available on device.")
      }
    ) { result in
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        let modelURL = URL(fileURLWithPath: model.path)
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      backgroundModelExpectation.fulfill()
    }

    wait(for: [backgroundModelExpectation], timeout: 5)

    /// Test download type - local model.
    downloadType = .localModel
    let localModelExpectation = expectation(description: "Get local model.")

    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { _ in
        XCTFail("Model is already available on device.")
      }
    ) { result in
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        let modelURL = URL(fileURLWithPath: model.path)
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
        // Remove downloaded model file.
        do {
          try ModelFileManager.removeFile(at: modelURL)
        } catch {
          XCTFail("Model removal failed - \(error)")
        }
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      localModelExpectation.fulfill()
    }

    wait(for: [localModelExpectation], timeout: 5)
  }

  /// Verifies `getModel` with an empty name behaves correctly.
  ///
  /// It must fail with `.emptyModelName` on the main thread and must never report progress.
  /// The progress expectation is inverted, so fulfilling it fails the test.
  func testGetModelWhenNameIsEmpty() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let testName = String(#function.dropLast(2))
    let emptyModelName = ""

    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )

    let completionExpectation = expectation(description: "getModel")
    let progressExpectation = expectation(description: "progressHandler")
    progressExpectation.isInverted = true

    modelDownloader.getModel(
      name: emptyModelName,
      downloadType: .latestModel,
      conditions: conditions,
      progressHandler: { _ in
        progressExpectation.fulfill()
      }
    ) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")

      switch result {
      case .failure(.emptyModelName):
        // The expected error.
        break
      default:
        XCTFail("Unexpected result: \(result)")
      }
      completionExpectation.fulfill()
    }

    wait(for: [completionExpectation, progressExpectation], timeout: 5)
  }

  /// Delete previously downloaded model.
  ///
  /// Downloads the latest model first, then deletes it and expects success on the main thread.
  func testDeleteModel() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let testName = String(#function.dropLast(2))
    let testModelName = "pose-detection"

    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )

    let downloadType: ModelDownloadType = .latestModel
    let latestModelExpectation = expectation(description: "Get latest model for deletion.")

    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        let filePath = URL(fileURLWithPath: model.path)
        XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }

    wait(for: [latestModelExpectation], timeout: 5)

    let deleteExpectation = expectation(description: "Wait for model deletion.")
    modelDownloader.deleteDownloadedModel(name: testModelName) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
      switch result {
      case .success:
        break
      case let .failure(error):
        XCTFail("Failed to delete model - \(error)")
      }
      // Fulfil only after the assertions ran, so `wait` cannot return ahead of them.
      deleteExpectation.fulfill()
    }

    wait(for: [deleteExpectation], timeout: 5)
  }

  /// Test listing models in model directory.
  ///
  /// Downloads the latest model first, then expects `listDownloadedModels` to report at least
  /// one model.
  func testListModels() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let testName = String(#function.dropLast(2))
    let testModelName = "pose-detection"

    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )

    let downloadType: ModelDownloadType = .latestModel
    let latestModelExpectation = expectation(description: "Get latest model.")

    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        let filePath = URL(fileURLWithPath: model.path)
        XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }

    wait(for: [latestModelExpectation], timeout: 5)

    let listExpectation = expectation(description: "Wait for list models.")
    modelDownloader.listDownloadedModels { result in
      switch result {
      case let .success(models):
        XCTAssertGreaterThan(models.count, 0)
      case let .failure(error):
        XCTFail("Failed to list models - \(error)")
      }
      // Fulfil only after the assertions ran, so `wait` cannot return ahead of them.
      listExpectation.fulfill()
    }

    wait(for: [listExpectation], timeout: 5)
  }

  /// Test logging telemetry event.
  ///
  /// Runs three operations:
  /// - a successful download,
  /// - a delete,
  /// - a lookup of an unknown model, which must fail with `.notFound`.
  ///
  /// Data collection is disabled by default, so nothing is actually logged unless the flag
  /// below is flipped.
  func testLogTelemetryEvent() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    // Flip this to `true` to test logging.
    testApp.isDataCollectionDefaultEnabled = false

    let testModelName = "digit-classification"
    let testName = String(#function.dropLast(2))

    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )

    let latestModelExpectation = expectation(description: "Test get model telemetry.")

    modelDownloader.getModel(
      name: testModelName,
      downloadType: .latestModel,
      conditions: conditions
    ) { result in
      switch result {
      case .success:
        break
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }

    wait(for: [latestModelExpectation], timeout: 5)

    let deleteModelExpectation = expectation(description: "Test delete model telemetry.")
    modelDownloader.deleteDownloadedModel(name: testModelName) { result in
      switch result {
      case .success(()):
        break
      case let .failure(error):
        XCTFail("Failed to delete model - \(error)")
      }
      deleteModelExpectation.fulfill()
    }

    wait(for: [deleteModelExpectation], timeout: 5)

    let notFoundModelName = "fakeModelName"
    let notFoundModelExpectation =
      expectation(description: "Test get model telemetry with unknown model.")

    modelDownloader.getModel(
      name: notFoundModelName,
      downloadType: .latestModel,
      conditions: conditions
    ) { result in
      switch result {
      case .success:
        XCTFail("Unexpected model success.")
      case let .failure(error):
        XCTAssertEqual(error, .notFound)
      }
      notFoundModelExpectation.fulfill()
    }

    wait(for: [notFoundModelExpectation], timeout: 5)
  }

  /// Downloads the latest model with cellular access disallowed.
  ///
  /// Expects progress and completion on the main thread and a reachable file on success.
  func testGetModelWithConditions() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let testModelName = "pose-detection"
    let testName = String(#function.dropLast(2))

    let conditions = ModelDownloadConditions(allowsCellularAccess: false)
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )

    let latestModelExpectation = expectation(description: "Get latest model with conditions.")

    modelDownloader.getModel(
      name: testModelName,
      downloadType: .latestModel,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        let filePath = URL(fileURLWithPath: model.path)
        XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
      latestModelExpectation.fulfill()
    }

    wait(for: [latestModelExpectation], timeout: 5)
  }
}