| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504 |
- // Copyright 2021 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- import XCTest
- @testable import FirebaseCore
- @testable import FirebaseInstallations
- @testable import FirebaseMLModelDownloader
extension UserDefaults {
  /// Builds the suite name that isolates the user defaults of a single test.
  ///
  /// - Parameter testName: Name of the test that owns the suite.
  /// - Returns: `com.google.firebase.ml.test.<testName>`.
  private static func testSuiteName(for testName: String) -> String {
    return "com.google.firebase.ml.test.\(testName)"
  }

  /// Opens the user defaults suite called `suiteName`.
  ///
  /// `UserDefaults(suiteName:)` is failable. A test cannot run meaningfully without its own
  /// isolated defaults, so this stops with a message naming the suite instead of a bare `!`.
  private static func openTestSuite(named suiteName: String) -> UserDefaults {
    guard let defaults = UserDefaults(suiteName: suiteName) else {
      fatalError("Could not create user defaults suite named \(suiteName).")
    }
    return defaults
  }

  /// Returns a new cleared instance of user defaults.
  ///
  /// - Parameter testName: Name of the test that owns the suite.
  /// - Returns: The test's defaults suite with its persistent domain removed.
  static func createTestInstance(testName: String) -> UserDefaults {
    let suiteName = testSuiteName(for: testName)
    let defaults = openTestSuite(named: suiteName)
    defaults.removePersistentDomain(forName: suiteName)
    return defaults
  }

  /// Returns the existing user defaults instance.
  ///
  /// - Parameter testName: Name of the test that owns the suite.
  /// - Returns: The test's defaults suite, without clearing anything.
  static func getTestInstance(testName: String) -> UserDefaults {
    return openTestSuite(named: testSuiteName(for: testName))
  }
}
// TODO: Create the test app with FirebaseApp's internal initializer instead of configuring the
// default app from GoogleService-Info.plist in `setUp()`.
/// Integration tests for model info retrieval, model download, listing, deletion and telemetry.
/// These tests make real network calls using the bundled `GoogleService-Info.plist`.
final class ModelDownloaderIntegrationTests: XCTestCase {
  /// Seconds each test waits for a callback before reporting a timeout.
  private static let networkTimeout: TimeInterval = 5

  /// Configures the default Firebase app once for the class from the test bundle's
  /// `GoogleService-Info.plist` and enables debug logging.
  override class func setUp() {
    super.setUp()
    let bundle = Bundle(for: self)
    if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
       let options = FirebaseOptions(contentsOfFile: plistPath) {
      FirebaseApp.configure(options: options)
    } else {
      XCTFail("Could not locate GoogleService-Info.plist.")
    }
    FirebaseConfiguration.shared.setLoggerLevel(.debug)
  }

  /// Test to download model info - makes an actual network call.
  ///
  /// Downloads model info, stores it in this test's defaults suite, reloads it, and then checks
  /// (via `testRetrieveModelInfo(localInfo:)`) that a second request reports `.notModified`.
  func testDownloadModelInfo() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    guard let projectID = testApp.options.projectID,
          let apiKey = testApp.options.apiKey else {
      XCTFail("Default app options are missing a project ID or API key.")
      return
    }
    // One name for both the write inside the completion and the read below, so they are
    // guaranteed to target the same defaults suite.
    let testName = #function
    let testModelName = "pose-detection"
    let modelInfoRetriever = ModelInfoRetriever(
      modelName: testModelName,
      projectID: projectID,
      apiKey: apiKey,
      installations: Installations.installations(app: testApp),
      appName: testApp.name
    )
    let downloadExpectation = expectation(description: "Wait for model info to download.")
    modelInfoRetriever.downloadModelInfo(completion: { result in
      // Fulfill on every path so a failed assertion does not also cause a timeout.
      defer { downloadExpectation.fulfill() }
      switch result {
      case let .success(modelInfoResult):
        switch modelInfoResult {
        case let .modelInfo(modelInfo):
          XCTAssertNotNil(modelInfo.urlExpiryTime)
          XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
          XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
          XCTAssertGreaterThan(modelInfo.size, 0)
          let localModelInfo = LocalModelInfo(from: modelInfo, path: "mock-valid-path")
          localModelInfo.writeToDefaults(
            .createTestInstance(testName: testName),
            appName: testApp.name
          )
        case .notModified:
          XCTFail("Failed to retrieve model info.")
        }
      case let .failure(error):
        XCTFail("Failed to retrieve model info - \(error)")
      }
    })
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)

    guard let localInfo = LocalModelInfo(
      fromDefaults: .getTestInstance(testName: testName),
      name: testModelName,
      appName: testApp.name
    ) else {
      XCTFail("Could not save model info locally.")
      return
    }
    testRetrieveModelInfo(localInfo: localInfo)
  }

  /// Checks that requesting model info with up-to-date `localInfo` reports `.notModified`.
  ///
  /// XCTest only discovers parameterless `test…` methods, so this runs only when called from
  /// `testDownloadModelInfo()`.
  ///
  /// - Parameter localInfo: Model info previously stored by `testDownloadModelInfo()`.
  func testRetrieveModelInfo(localInfo: LocalModelInfo) {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    guard let projectID = testApp.options.projectID,
          let apiKey = testApp.options.apiKey else {
      XCTFail("Default app options are missing a project ID or API key.")
      return
    }
    let testModelName = "pose-detection"
    let modelInfoRetriever = ModelInfoRetriever(
      modelName: testModelName,
      projectID: projectID,
      apiKey: apiKey,
      installations: Installations.installations(app: testApp),
      appName: testApp.name,
      localModelInfo: localInfo
    )
    let retrieveExpectation = expectation(description: "Wait for model info to be retrieved.")
    modelInfoRetriever.downloadModelInfo(completion: { result in
      defer { retrieveExpectation.fulfill() }
      switch result {
      case let .success(modelInfoResult):
        switch modelInfoResult {
        case .modelInfo:
          XCTFail("Local model info is already the latest and should not be set again.")
        case .notModified:
          break
        }
      case let .failure(error):
        XCTFail("Failed to retrieve model info - \(error)")
      }
    })
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)
  }

  /// Test to download model file - makes an actual network call.
  ///
  /// Downloads a public TFLite file through `ModelDownloadTask`, checks the file is reachable,
  /// removes it, and expects the task to end in the `.complete` state.
  func testModelDownload() throws {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    let testName = #function.dropLast(2)
    let testModelName = "\(testName)-test-model"
    let urlString =
      "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
    let url = try XCTUnwrap(URL(string: urlString))
    let remoteModelInfo = RemoteModelInfo(
      name: testModelName,
      downloadURL: url,
      modelHash: "mock-valid-hash",
      size: 10,
      urlExpiryTime: Date()
    )
    let conditions = ModelDownloadConditions()
    let downloadExpectation = expectation(description: "Wait for model to download.")
    let downloader = ModelFileDownloader(conditions: conditions)
    let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
      XCTAssertLessThanOrEqual(progress, 1)
      XCTAssertGreaterThanOrEqual(progress, 0)
    }
    let taskCompletion: ModelDownloadTask.Completion = { result in
      defer { downloadExpectation.fulfill() }
      switch result {
      case let .success(model):
        guard let modelPath = URL(string: model.path) else {
          XCTFail("Invalid or empty model path.")
          return
        }
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
        // Remove the downloaded model file so later runs start clean.
        do {
          try ModelFileManager.removeFile(at: modelPath)
        } catch {
          XCTFail("Model removal failed - \(error)")
        }
      case let .failure(error):
        XCTFail("Error: \(error)")
      }
    }
    let modelDownloadManager = ModelDownloadTask(
      remoteModelInfo: remoteModelInfo,
      appName: testApp.name,
      defaults: .createTestInstance(testName: #function),
      downloader: downloader,
      progressHandler: taskProgressHandler,
      completion: taskCompletion
    )
    modelDownloadManager.resume()
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)
    XCTAssertEqual(modelDownloadManager.downloadStatus, .complete)
  }

  /// Gets the same model with each download type in turn.
  ///
  /// Order: `.latestModel` (downloads), then `.localModel` and `.localModelUpdateInBackground`.
  /// The last two must be served from disk without reporting progress.
  func testGetModel() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    let testName = String(#function.dropLast(2))
    let testModelName = "image-classification"
    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )

    // Download type - latest model.
    var downloadType: ModelDownloadType = .latestModel
    let latestModelExpectation = expectation(description: "Get latest model.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      defer { latestModelExpectation.fulfill() }
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        guard let modelPath = URL(string: model.path) else {
          XCTFail("Invalid model path.")
          return
        }
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
    }
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)

    // Download type - local model.
    downloadType = .localModel
    let localModelExpectation = expectation(description: "Get local model.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { _ in
        XCTFail("Model is already available on device.")
      }
    ) { result in
      defer { localModelExpectation.fulfill() }
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        guard let modelPath = URL(string: model.path) else {
          XCTFail("Invalid model path.")
          return
        }
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
    }
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)

    // Download type - local model update in background.
    downloadType = .localModelUpdateInBackground
    let backgroundModelExpectation =
      expectation(description: "Get local model and update in background.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { _ in
        XCTFail("Model is already available on device.")
      }
    ) { result in
      defer { backgroundModelExpectation.fulfill() }
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        guard let modelPath = URL(string: model.path) else {
          XCTFail("Invalid model path.")
          return
        }
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
        // Remove the downloaded model file so later runs start clean.
        do {
          try ModelFileManager.removeFile(at: modelPath)
        } catch {
          XCTFail("Model removal failed - \(error)")
        }
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
    }
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)
  }

  /// Delete previously downloaded model.
  ///
  /// Downloads the latest model, deletes it, and checks the file is no longer reachable.
  func testDeleteModel() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    let testName = String(#function.dropLast(2))
    let testModelName = "pose-detection"
    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )
    let downloadType: ModelDownloadType = .latestModel
    // Path of the downloaded file, remembered so deletion can be verified afterwards.
    var downloadedModelPath: URL?
    let latestModelExpectation = expectation(description: "Get latest model for deletion.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      defer { latestModelExpectation.fulfill() }
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        guard let filePath = URL(string: model.path) else {
          XCTFail("Invalid model path.")
          return
        }
        XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
        downloadedModelPath = filePath
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
    }
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)

    // Without waiting here, the deletion callback (and its assertions) could run after the
    // test has already finished.
    let deleteExpectation = expectation(description: "Delete downloaded model.")
    modelDownloader.deleteDownloadedModel(name: testModelName) { result in
      defer { deleteExpectation.fulfill() }
      switch result {
      case .success:
        if let filePath = downloadedModelPath {
          XCTAssertFalse(ModelFileManager.isFileReachable(at: filePath))
        }
      case let .failure(error):
        XCTFail("Failed to delete model - \(error)")
      }
    }
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)
  }

  /// Test listing models in model directory.
  ///
  /// Downloads the latest model, then expects at least one downloaded model to be listed.
  func testListModels() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    let testName = String(#function.dropLast(2))
    let testModelName = "pose-detection"
    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )
    let downloadType: ModelDownloadType = .latestModel
    let latestModelExpectation = expectation(description: "Get latest model.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      defer { latestModelExpectation.fulfill() }
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        guard let filePath = URL(string: model.path) else {
          XCTFail("Invalid model path.")
          return
        }
        XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
    }
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)

    // Wait for the listing callback so its assertions are attributed to this test.
    let listExpectation = expectation(description: "List downloaded models.")
    modelDownloader.listDownloadedModels { result in
      defer { listExpectation.fulfill() }
      switch result {
      case let .success(models):
        XCTAssertGreaterThan(models.count, 0)
      case let .failure(error):
        XCTFail("Failed to list models - \(error)")
      }
    }
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)
  }

  /// Test logging telemetry event.
  ///
  /// Downloads the latest model, logs a successful download event for it, and removes the file.
  func testLogTelemetryEvent() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    let testModelName = "image-classification"
    let testName = #function
    let conditions = ModelDownloadConditions()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )
    let latestModelExpectation = expectation(description: "Test telemetry.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: .latestModel,
      conditions: conditions
    ) { result in
      defer { latestModelExpectation.fulfill() }
      switch result {
      case let .success(model):
        guard let telemetryLogger = TelemetryLogger(app: testApp) else {
          XCTFail("Could not initialize logger.")
          return
        }
        // TODO: Remove actual logging and stub out with mocks.
        telemetryLogger.logModelDownloadEvent(
          eventName: .modelDownload,
          status: .succeeded,
          model: model,
          downloadErrorCode: .noError
        )
        guard let modelPath = URL(string: model.path) else {
          XCTFail("Invalid model path.")
          return
        }
        // Remove the downloaded model file so later runs start clean.
        do {
          try ModelFileManager.removeFile(at: modelPath)
        } catch {
          XCTFail("Model removal failed - \(error)")
        }
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
    }
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)
  }

  /// Gets the latest model with cellular access disallowed.
  func testGetModelWithConditions() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    let testModelName = "pose-detection"
    let testName = #function
    // TODO: Figure out a better way to test this.
    var conditions = ModelDownloadConditions()
    conditions.allowsCellularAccess = false
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createTestInstance(testName: testName),
      app: testApp
    )
    let latestModelExpectation = expectation(description: "Get latest model with conditions.")
    modelDownloader.getModel(
      name: testModelName,
      downloadType: .latestModel,
      conditions: conditions,
      progressHandler: { progress in
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
    ) { result in
      defer { latestModelExpectation.fulfill() }
      switch result {
      case let .success(model):
        XCTAssertNotNil(model.path)
        guard let filePath = URL(string: model.path) else {
          XCTFail("Invalid model path.")
          return
        }
        XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
      case let .failure(error):
        XCTFail("Failed to download model - \(error)")
      }
    }
    waitForExpectations(timeout: Self.networkTimeout, handler: nil)
  }
}
|