|
|
@@ -28,6 +28,7 @@ private enum MockOptions {
|
|
|
// TODO: Create separate files for each class tested.
|
|
|
final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
let fakeModelName = "fakeModelName"
|
|
|
+ let fakemodelPath = "fakemodelPath"
|
|
|
let fakeModelHash = "fakeModelHash"
|
|
|
let fakeDownloadURL = URL(string: "www.fake-download-url.com")!
|
|
|
let fakeFileURL = URL(string: "www.fake-model-file.com")!
|
|
|
@@ -35,6 +36,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
let fakeModelSize = 20
|
|
|
let fakeProjectID = "fakeProjectID"
|
|
|
let fakeAPIKey = "fakeAPIKey"
|
|
|
+ let fakeAppName = "fakeAppName"
|
|
|
var fakeModelJSON: String {
|
|
|
"""
|
|
|
{
|
|
|
@@ -61,6 +63,14 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
FirebaseConfiguration.shared.setLoggerLevel(.debug)
|
|
|
}
|
|
|
|
|
|
+ override func setUp() {
|
|
|
+ do {
|
|
|
+ try ModelFileManager.emptyModelsDirectory()
|
|
|
+ } catch {
|
|
|
+ XCTFail("Could not empty models directory.")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
override class func tearDown() {
|
|
|
try? FileManager.default.removeItem(atPath: NSTemporaryDirectory())
|
|
|
FirebaseApp.app()?.delete { _ in }
|
|
|
@@ -95,7 +105,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
/// Test invalid model file deletion.
|
|
|
func testDeleteInvalidModel() {
|
|
|
let modelDownloader = ModelDownloader.modelDownloader()
|
|
|
- modelDownloader.deleteDownloadedModel(name: "fake-model-name") { result in
|
|
|
+ modelDownloader.deleteDownloadedModel(name: fakeModelName) { result in
|
|
|
switch result {
|
|
|
case .success: XCTFail("Unexpected successful model deletion.")
|
|
|
case let .failure(error):
|
|
|
@@ -104,79 +114,116 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- /// Test listing models in model directory.
|
|
|
- // TODO: Add unit test.
|
|
|
+ func testGetModelNameFromFilePath() {
|
|
|
+ var fakeURL = URL(fileURLWithPath: "modelDirectory/@@appName@@\(fakeModelName).tflite")
|
|
|
+ XCTAssertEqual(ModelFileManager.getModelNameFromFilePath(fakeURL), fakeModelName)
|
|
|
+ fakeURL = URL(fileURLWithPath: "modelDirectory/--appName--\(fakeModelName).tflite")
|
|
|
+ XCTAssertNil(ModelFileManager.getModelNameFromFilePath(fakeURL))
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Test listing models.
|
|
|
func testListModels() {
|
|
|
+ guard let testApp = FirebaseApp.app() else {
|
|
|
+ XCTFail("Default app was not configured.")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ let testName = String(#function.dropLast(2))
|
|
|
+ let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
|
|
|
+ .createUnitTestInstance(testName: testName),
|
|
|
+ app: testApp
|
|
|
+ )
|
|
|
+
|
|
|
+ let tempModelFileURL = tempModelFile()
|
|
|
+ let completionExpectation = expectation(description: "Completion handler")
|
|
|
+
|
|
|
+ let localModelInfo = fakeLocalModelInfo()
|
|
|
+ localModelInfo.writeToDefaults(.getUnitTestInstance(testName: testName), appName: testApp.name)
|
|
|
+
|
|
|
+ modelDownloader.listDownloadedModels { result in
|
|
|
+ completionExpectation.fulfill()
|
|
|
+ switch result {
|
|
|
+ case let .success(models):
|
|
|
+ /// Expected success since model info was written to user defaults.
|
|
|
+ XCTAssertEqual(models.first?.path, tempModelFileURL.path)
|
|
|
+ case let .failure(error):
|
|
|
+ XCTFail("Error - \(error)")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ wait(for: [completionExpectation], timeout: 0.5)
|
|
|
+ try? ModelFileManager.removeFile(at: tempModelFileURL)
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Test invalid path in model directory.
|
|
|
+ func testListModelsInvalidPath() {
|
|
|
let modelDownloader = ModelDownloader.modelDownloader()
|
|
|
+ let invalidTempModelFileURL = tempModelFile(invalid: true)
|
|
|
+ let completionExpectation = expectation(description: "Completion handler")
|
|
|
|
|
|
modelDownloader.listDownloadedModels { result in
|
|
|
+ completionExpectation.fulfill()
|
|
|
switch result {
|
|
|
- case .success: break
|
|
|
- case .failure: break
|
|
|
+ case .success:
|
|
|
+ XCTFail("Unexpected success of list models.")
|
|
|
+ case let .failure(error):
|
|
|
+ switch error {
|
|
|
+ case let .internalError(description):
|
|
|
+ XCTAssertTrue(description
|
|
|
+ .contains("List models failed due to unexpected model file name"))
|
|
|
+ default: XCTFail("Expected failure due to bad model name.")
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
+ wait(for: [completionExpectation], timeout: 0.5)
|
|
|
+ try? ModelFileManager.removeFile(at: invalidTempModelFileURL)
|
|
|
}
|
|
|
|
|
|
- func testGetModel() {
|
|
|
- // This is an example of a functional test case.
|
|
|
- // Use XCTAssert and related functions to verify your tests produce the correct
|
|
|
- // results.
|
|
|
+ /// Test listing models without local info in user defaults.
|
|
|
+ func testListModelsNoLocalInfo() {
|
|
|
guard let testApp = FirebaseApp.app() else {
|
|
|
XCTFail("Default app was not configured.")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- let modelDownloader = ModelDownloader.modelDownloader()
|
|
|
- let modelDownloaderWithApp = ModelDownloader.modelDownloader(app: testApp)
|
|
|
-
|
|
|
- /// These should point to the same instance.
|
|
|
- XCTAssert(modelDownloader === modelDownloaderWithApp)
|
|
|
+ let testName = String(#function.dropLast(2))
|
|
|
+ let tempModelFileURL = tempModelFile()
|
|
|
+ let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
|
|
|
+ .createUnitTestInstance(testName: testName),
|
|
|
+ app: testApp
|
|
|
+ )
|
|
|
|
|
|
- let conditions = ModelDownloadConditions()
|
|
|
+ let completionExpectation = expectation(description: "Completion handler")
|
|
|
|
|
|
- // Download model w/ progress handler
|
|
|
- modelDownloader.getModel(
|
|
|
- name: "your_model_name",
|
|
|
- downloadType: .latestModel,
|
|
|
- conditions: conditions,
|
|
|
- progressHandler: { progress in
|
|
|
- // Handle progress
|
|
|
- }
|
|
|
- ) { result in
|
|
|
+ modelDownloader.listDownloadedModels { result in
|
|
|
+ completionExpectation.fulfill()
|
|
|
switch result {
|
|
|
case .success:
|
|
|
- // Use model with your inference API
|
|
|
- // let interpreter = Interpreter(modelPath: customModel.modelPath)
|
|
|
- break
|
|
|
- case .failure:
|
|
|
- // Handle download error
|
|
|
- break
|
|
|
+ XCTFail("Unexpected success of list models.")
|
|
|
+ case let .failure(error):
|
|
|
+ switch error {
|
|
|
+ /// Expected failure since no model info was written to user defaults.
|
|
|
+ case let .internalError(description):
|
|
|
+ XCTAssertTrue(description.contains("List models failed due to no local model info"))
|
|
|
+ default: XCTFail("Expected failure due to bad model name.")
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
+ wait(for: [completionExpectation], timeout: 0.5)
|
|
|
+ try? ModelFileManager.removeFile(at: tempModelFileURL)
|
|
|
+ }
|
|
|
|
|
|
- // Access array of downloaded models
|
|
|
- modelDownloaderWithApp.listDownloadedModels { result in
|
|
|
- switch result {
|
|
|
- case .success:
|
|
|
- // Pick model(s) for further use
|
|
|
- break
|
|
|
- case .failure:
|
|
|
- // Handle failure
|
|
|
- break
|
|
|
- }
|
|
|
- }
|
|
|
+ /// Test how URL conversion behaves if there are spaces in the path.
|
|
|
+ func testURLConversion() {
|
|
|
+ /// Spaces in the string only convert to URL when using URL(fileURLWithPath: ).
|
|
|
+ let fakeURLWithSpace = URL(string: "file:///fakeDir1/fake%20Dir2/fakeFile")!
|
|
|
|
|
|
- // Delete downloaded model
|
|
|
- modelDownloader.deleteDownloadedModel(name: "your_model_name") { result in
|
|
|
- switch result {
|
|
|
- case .success():
|
|
|
- // Apply any other clean up
|
|
|
- break
|
|
|
- case .failure:
|
|
|
- // Handle failure
|
|
|
- break
|
|
|
- }
|
|
|
- }
|
|
|
+ XCTAssertEqual(fakeURLWithSpace, URL(string: fakeURLWithSpace.absoluteString))
|
|
|
+ XCTAssertEqual(fakeURLWithSpace, URL(fileURLWithPath: fakeURLWithSpace.path))
|
|
|
+ XCTAssertNil(URL(string: fakeURLWithSpace.path))
|
|
|
+
|
|
|
+ /// Strings without spaces should work fine either way.
|
|
|
+ let fakeURLWithoutSpace = URL(string: "fakeDir1/fakeDir2/fakeFile")!
|
|
|
+ XCTAssertEqual(fakeURLWithoutSpace, URL(string: fakeURLWithoutSpace.absoluteString)!)
|
|
|
+ XCTAssertEqual(fakeURLWithoutSpace, URL(string: fakeURLWithoutSpace.path)!)
|
|
|
}
|
|
|
|
|
|
/// Test on-device logging.
|
|
|
@@ -189,10 +236,10 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
/// Compare proto serialization methods.
|
|
|
func testTelemetryEncoding() {
|
|
|
let fakeModel = CustomModel(
|
|
|
- name: "fakeModelName",
|
|
|
+ name: fakeModelName,
|
|
|
size: 10,
|
|
|
- path: "fakeModelPath",
|
|
|
- hash: "fakeModelHash"
|
|
|
+ path: fakemodelPath,
|
|
|
+ hash: fakeModelHash
|
|
|
)
|
|
|
var modelOptions = ModelOptions()
|
|
|
modelOptions.setModelOptions(model: fakeModel)
|
|
|
@@ -338,6 +385,56 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
wait(for: [completionExpectation], timeout: 0.5)
|
|
|
}
|
|
|
|
|
|
+ /// Get model info if model name does not exist.
|
|
|
+ func testGetModelInfoWith404() {
|
|
|
+ let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 404)
|
|
|
+ let localModelInfo = fakeLocalModelInfo(mismatch: false)
|
|
|
+ let modelInfoRetriever = fakeModelRetriever(
|
|
|
+ fakeSession: session,
|
|
|
+ fakeLocalModelInfo: localModelInfo
|
|
|
+ )
|
|
|
+ let completionExpectation = expectation(description: "Completion handler")
|
|
|
+
|
|
|
+ modelInfoRetriever.downloadModelInfo { result in
|
|
|
+ completionExpectation.fulfill()
|
|
|
+ switch result {
|
|
|
+ case let .success(fakemodelInfoResult):
|
|
|
+ switch fakemodelInfoResult {
|
|
|
+ case .modelInfo: XCTFail("Unexpected new model info from server.")
|
|
|
+ case .notModified: XCTFail("Expected failure since model name is not availabl.")
|
|
|
+ }
|
|
|
+ case let .failure(error):
|
|
|
+ XCTAssertEqual(error, .notFound)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ wait(for: [completionExpectation], timeout: 0.5)
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Get model info if resource exhausted due to too many requests.
|
|
|
+ func testGetModelInfoWith429() {
|
|
|
+ let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 429)
|
|
|
+ let localModelInfo = fakeLocalModelInfo(mismatch: false)
|
|
|
+ let modelInfoRetriever = fakeModelRetriever(
|
|
|
+ fakeSession: session,
|
|
|
+ fakeLocalModelInfo: localModelInfo
|
|
|
+ )
|
|
|
+ let completionExpectation = expectation(description: "Completion handler")
|
|
|
+
|
|
|
+ modelInfoRetriever.downloadModelInfo { result in
|
|
|
+ completionExpectation.fulfill()
|
|
|
+ switch result {
|
|
|
+ case let .success(fakemodelInfoResult):
|
|
|
+ switch fakemodelInfoResult {
|
|
|
+ case .modelInfo: XCTFail("Unexpected new model info from server.")
|
|
|
+ case .notModified: XCTFail("Expected failure since resource exhausted.")
|
|
|
+ }
|
|
|
+ case let .failure(error):
|
|
|
+ XCTAssertEqual(error, .resourceExhausted)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ wait(for: [completionExpectation], timeout: 0.5)
|
|
|
+ }
|
|
|
+
|
|
|
/// Get model if server returns a model file.
|
|
|
func testModelDownloadWith200() {
|
|
|
let tempFileURL = tempFile()
|
|
|
@@ -365,19 +462,18 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
completionExpectation.fulfill()
|
|
|
switch result {
|
|
|
case let .success(model):
|
|
|
- XCTAssertEqual(model.name, self.fakeModelName)
|
|
|
- XCTAssertEqual(model.size, self.fakeModelSize)
|
|
|
- XCTAssertEqual(model.hash, self.fakeModelHash)
|
|
|
- let modelPath = URL(string: model.path)!
|
|
|
- XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
|
|
|
- try? self.deleteFile(at: modelPath)
|
|
|
+ self.checkModel(model: model)
|
|
|
+ self.deleteFile(at: model.path)
|
|
|
case let .failure(error):
|
|
|
XCTFail("Error - \(error)")
|
|
|
}
|
|
|
}
|
|
|
+ let testName = String(#function.dropLast(2))
|
|
|
let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
|
|
|
appName: "fakeAppName",
|
|
|
- defaults: .createTestInstance(testName: #function),
|
|
|
+ defaults: .createUnitTestInstance(
|
|
|
+ testName: testName
|
|
|
+ ),
|
|
|
downloader: downloader,
|
|
|
progressHandler: taskProgressHandler,
|
|
|
completion: taskCompletion)
|
|
|
@@ -421,9 +517,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
XCTAssertEqual(error, .expiredDownloadURL)
|
|
|
}
|
|
|
}
|
|
|
+ let testName = String(#function.dropLast(2))
|
|
|
let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
|
|
|
appName: "fakeAppName",
|
|
|
- defaults: .createTestInstance(testName: #function),
|
|
|
+ defaults: .createUnitTestInstance(
|
|
|
+ testName: testName
|
|
|
+ ),
|
|
|
downloader: downloader,
|
|
|
completion: taskCompletion)
|
|
|
modelDownloadTask.resume()
|
|
|
@@ -461,9 +560,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
XCTAssertEqual(error, .invalidArgument)
|
|
|
}
|
|
|
}
|
|
|
+ let testName = String(#function.dropLast(2))
|
|
|
let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
|
|
|
appName: "fakeAppName",
|
|
|
- defaults: .createTestInstance(testName: #function),
|
|
|
+ defaults: .createUnitTestInstance(
|
|
|
+ testName: testName
|
|
|
+ ),
|
|
|
downloader: downloader,
|
|
|
completion: taskCompletion)
|
|
|
modelDownloadTask.resume()
|
|
|
@@ -477,6 +579,41 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
try? ModelFileManager.removeFile(at: tempFileURL)
|
|
|
}
|
|
|
|
|
|
+ /// Get model info if Firebase ML API is not enabled or permission denied.
|
|
|
+ func testGetModelInfoWith403() {
|
|
|
+ let errorMessage =
|
|
|
+ """
|
|
|
+ {
|
|
|
+ "error":
|
|
|
+ {
|
|
|
+ "message":"API Disabled."
|
|
|
+ }
|
|
|
+ }
|
|
|
+ """
|
|
|
+ let session = fakeModelInfoSessionWithURL(
|
|
|
+ fakeDownloadURL,
|
|
|
+ statusCode: 403,
|
|
|
+ customJSON: errorMessage
|
|
|
+ )
|
|
|
+ let localModelInfo = fakeLocalModelInfo(mismatch: false)
|
|
|
+ let modelInfoRetriever = fakeModelRetriever(
|
|
|
+ fakeSession: session,
|
|
|
+ fakeLocalModelInfo: localModelInfo
|
|
|
+ )
|
|
|
+ let completionExpectation = expectation(description: "Completion handler")
|
|
|
+
|
|
|
+ modelInfoRetriever.downloadModelInfo { result in
|
|
|
+ completionExpectation.fulfill()
|
|
|
+ switch result {
|
|
|
+ case .success:
|
|
|
+ XCTFail("Unexpected new model info from server.")
|
|
|
+ case let .failure(error):
|
|
|
+ XCTAssertEqual(error, .permissionDenied)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ wait(for: [completionExpectation], timeout: 0.5)
|
|
|
+ }
|
|
|
+
|
|
|
/// Get model if multiple duplicate requests are made.
|
|
|
func testModelDownloadWithMergeRequests() {
|
|
|
let tempFileURL = tempFile()
|
|
|
@@ -501,19 +638,18 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
completionExpectation.fulfill()
|
|
|
switch result {
|
|
|
case let .success(model):
|
|
|
- XCTAssertEqual(model.name, self.fakeModelName)
|
|
|
- XCTAssertEqual(model.size, self.fakeModelSize)
|
|
|
- XCTAssertEqual(model.hash, self.fakeModelHash)
|
|
|
- let modelPath = URL(string: model.path)!
|
|
|
- XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
|
|
|
+ self.checkModel(model: model)
|
|
|
case let .failure(error):
|
|
|
XCTFail("Error - \(error)")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ let testName = String(#function.dropLast(2))
|
|
|
let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
|
|
|
- appName: "fakeAppName",
|
|
|
- defaults: .createTestInstance(testName: #function),
|
|
|
+ appName: fakeAppName,
|
|
|
+ defaults: .createUnitTestInstance(
|
|
|
+ testName: testName
|
|
|
+ ),
|
|
|
downloader: downloader,
|
|
|
completion: taskCompletion)
|
|
|
|
|
|
@@ -522,13 +658,9 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
completionExpectation.fulfill()
|
|
|
switch result {
|
|
|
case let .success(model):
|
|
|
- XCTAssertEqual(model.name, self.fakeModelName)
|
|
|
- XCTAssertEqual(model.size, self.fakeModelSize)
|
|
|
- XCTAssertEqual(model.hash, self.fakeModelHash)
|
|
|
- let modelPath = URL(string: model.path)!
|
|
|
- XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
|
|
|
+ self.checkModel(model: model)
|
|
|
if i == numberOfRequests {
|
|
|
- try? self.deleteFile(at: modelPath)
|
|
|
+ self.deleteFile(at: model.path)
|
|
|
}
|
|
|
case let .failure(error):
|
|
|
XCTFail("Error - \(error)")
|
|
|
@@ -578,13 +710,9 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
completionExpectation.fulfill()
|
|
|
switch result {
|
|
|
case let .success(model):
|
|
|
- XCTAssertEqual(model.name, self.fakeModelName)
|
|
|
- XCTAssertEqual(model.size, self.fakeModelSize)
|
|
|
- XCTAssertEqual(model.hash, self.fakeModelHash)
|
|
|
- let modelPath = URL(string: model.path)!
|
|
|
- XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
|
|
|
+ self.checkModel(model: model)
|
|
|
if i == numberOfRequests {
|
|
|
- try? self.deleteFile(at: modelPath)
|
|
|
+ self.deleteFile(at: model.path)
|
|
|
}
|
|
|
case let .failure(error):
|
|
|
XCTFail("Error - \(error)")
|
|
|
@@ -626,9 +754,12 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
XCTAssertEqual(error, .notFound)
|
|
|
}
|
|
|
}
|
|
|
+ let testName = String(#function.dropLast(2))
|
|
|
let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
|
|
|
appName: "fakeAppName",
|
|
|
- defaults: .createTestInstance(testName: #function),
|
|
|
+ defaults: .createUnitTestInstance(
|
|
|
+ testName: testName
|
|
|
+ ),
|
|
|
downloader: downloader,
|
|
|
completion: taskCompletion)
|
|
|
modelDownloadTask.resume()
|
|
|
@@ -642,6 +773,91 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
try? ModelFileManager.removeFile(at: tempFileURL)
|
|
|
}
|
|
|
|
|
|
+ /// Get model if server returns an error due to model not found.
|
|
|
+ func testModelDownloadNetworkError() {
|
|
|
+ let tempFileURL = tempFile()
|
|
|
+ let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
|
|
|
+ let downloader = MockModelFileDownloader()
|
|
|
+ let fileDownloaderExpectation = expectation(description: "File downloader")
|
|
|
+ downloader.downloadFileHandler = { url in
|
|
|
+ fileDownloaderExpectation.fulfill()
|
|
|
+ XCTAssertEqual(url, self.fakeDownloadURL)
|
|
|
+ }
|
|
|
+
|
|
|
+ let completionExpectation = expectation(description: "Completion handler")
|
|
|
+ let taskCompletion: ModelDownloadTask.Completion = { result in
|
|
|
+ completionExpectation.fulfill()
|
|
|
+ switch result {
|
|
|
+ case .success: XCTFail("Unexpected successful model download.")
|
|
|
+ case let .failure(error):
|
|
|
+ XCTAssertEqual(error, .failedPrecondition)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ let testName = String(#function.dropLast(2))
|
|
|
+ let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
|
|
|
+ appName: "fakeAppName",
|
|
|
+ defaults: .createUnitTestInstance(
|
|
|
+ testName: testName
|
|
|
+ ),
|
|
|
+ downloader: downloader,
|
|
|
+ completion: taskCompletion)
|
|
|
+ modelDownloadTask.resume()
|
|
|
+ wait(for: [fileDownloaderExpectation], timeout: 0.1)
|
|
|
+
|
|
|
+ let offlineError = DownloadError.internalError(description: "Network appears to be offline.")
|
|
|
+ downloader
|
|
|
+ .completion?(.failure(FileDownloaderError.networkError(offlineError)))
|
|
|
+ wait(for: [completionExpectation], timeout: 0.5)
|
|
|
+
|
|
|
+ try? ModelFileManager.removeFile(at: tempFileURL)
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Get model with a valid server response.
|
|
|
+ func testGetModelSuccessful() {
|
|
|
+ let modelDownloader = ModelDownloader.modelDownloader()
|
|
|
+ let conditions = ModelDownloadConditions()
|
|
|
+ let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
|
|
|
+ let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
|
|
|
+ let downloader = MockModelFileDownloader()
|
|
|
+ let tempFileURL = tempFile()
|
|
|
+
|
|
|
+ let fakeResponseNotFound = HTTPURLResponse(url: fakeFileURL,
|
|
|
+ statusCode: 200,
|
|
|
+ httpVersion: nil,
|
|
|
+ headerFields: nil)!
|
|
|
+
|
|
|
+ let fileDownloaderExpectation = expectation(description: "File downloader")
|
|
|
+ let completionExpectation = expectation(description: "Completion handler")
|
|
|
+
|
|
|
+ // Handles initial response.
|
|
|
+ downloader.downloadFileHandler = { url in
|
|
|
+ fileDownloaderExpectation.fulfill()
|
|
|
+ XCTAssertEqual(url, self.fakeDownloadURL)
|
|
|
+ }
|
|
|
+
|
|
|
+ modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
|
|
|
+ modelInfoRetriever: modelInfoRetriever,
|
|
|
+ downloader: downloader,
|
|
|
+ conditions: conditions) { result in
|
|
|
+ completionExpectation.fulfill()
|
|
|
+ switch result {
|
|
|
+ case let .success(model):
|
|
|
+ self.checkModel(model: model)
|
|
|
+ self.deleteFile(at: model.path)
|
|
|
+ case let .failure(error):
|
|
|
+ XCTFail("Error - \(error)")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ wait(for: [fileDownloaderExpectation], timeout: 0.5)
|
|
|
+
|
|
|
+ downloader
|
|
|
+ .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseNotFound,
|
|
|
+ fileURL: tempFileURL)))
|
|
|
+ wait(for: [completionExpectation], timeout: 0.5)
|
|
|
+
|
|
|
+ try? ModelFileManager.removeFile(at: tempFileURL)
|
|
|
+ }
|
|
|
+
|
|
|
/// Get model if download url expired but retry was successful.
|
|
|
func testGetModelSuccessfulRetry() {
|
|
|
let modelDownloader = ModelDownloader.modelDownloader()
|
|
|
@@ -678,12 +894,8 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
switch result {
|
|
|
// Retry successful - model returned.
|
|
|
case let .success(model):
|
|
|
- XCTAssertEqual(model.name, self.fakeModelName)
|
|
|
- XCTAssertEqual(model.size, self.fakeModelSize)
|
|
|
- XCTAssertEqual(model.hash, self.fakeModelHash)
|
|
|
- let modelPath = URL(string: model.path)!
|
|
|
- XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
|
|
|
- try? self.deleteFile(at: modelPath)
|
|
|
+ self.checkModel(model: model)
|
|
|
+ self.deleteFile(at: model.path)
|
|
|
case let .failure(error):
|
|
|
XCTFail("Error - \(error)")
|
|
|
}
|
|
|
@@ -771,7 +983,7 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
httpVersion: nil,
|
|
|
headerFields: nil)!
|
|
|
|
|
|
- let fileDownloaderExpectation = expectation(description: "File downloader before retry")
|
|
|
+ let fileDownloaderExpectation = expectation(description: "File downloader")
|
|
|
let completionExpectation = expectation(description: "Completion handler")
|
|
|
|
|
|
// Handles initial response.
|
|
|
@@ -786,7 +998,6 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
conditions: conditions) { result in
|
|
|
completionExpectation.fulfill()
|
|
|
switch result {
|
|
|
- // Retry failed due to another expired url - error returned.
|
|
|
case .success: XCTFail("Unexpected successful model download.")
|
|
|
case let .failure(error):
|
|
|
switch error {
|
|
|
@@ -974,12 +1185,8 @@ final class ModelDownloaderUnitTests: XCTestCase {
|
|
|
switch result {
|
|
|
// One retry failed but the next retry succeeded - model returned.
|
|
|
case let .success(model):
|
|
|
- XCTAssertEqual(model.name, self.fakeModelName)
|
|
|
- XCTAssertEqual(model.size, self.fakeModelSize)
|
|
|
- XCTAssertEqual(model.hash, self.fakeModelHash)
|
|
|
- let modelPath = URL(string: model.path)!
|
|
|
- XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
|
|
|
- try? self.deleteFile(at: modelPath)
|
|
|
+ self.checkModel(model: model)
|
|
|
+ self.deleteFile(at: model.path)
|
|
|
case let .failure(error):
|
|
|
XCTFail("Error - \(error)")
|
|
|
}
|
|
|
@@ -1044,21 +1251,28 @@ class MockModelFileDownloader: FileDownloader {
|
|
|
|
|
|
extension UserDefaults {
|
|
|
/// Returns a new cleared instance of user defaults.
|
|
|
- static func createTestInstance(testName: String) -> UserDefaults {
|
|
|
+ static func createUnitTestInstance(testName: String) -> UserDefaults {
|
|
|
let suiteName = "com.google.firebase.ml.test.\(testName)"
|
|
|
- // TODO: reconsider force unwrapping
|
|
|
let defaults = UserDefaults(suiteName: suiteName)!
|
|
|
defaults.removePersistentDomain(forName: suiteName)
|
|
|
return defaults
|
|
|
}
|
|
|
+
|
|
|
+ /// Returns the existing user defaults instance.
|
|
|
+ static func getUnitTestInstance(testName: String) -> UserDefaults {
|
|
|
+ let suiteName = "com.google.firebase.ml.test.\(testName)"
|
|
|
+ return UserDefaults(suiteName: suiteName)!
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// MARK: - Helpers
|
|
|
|
|
|
extension ModelDownloaderUnitTests {
|
|
|
- func fakeModelInfoSessionWithURL(_ url: URL, statusCode: Int) -> MockModelInfoRetrieverSession {
|
|
|
+ func fakeModelInfoSessionWithURL(_ url: URL, statusCode: Int,
|
|
|
+ customJSON: String? = nil) -> MockModelInfoRetrieverSession {
|
|
|
let fakeSession = MockModelInfoRetrieverSession()
|
|
|
- fakeSession.data = fakeModelJSON.data(using: .utf8)
|
|
|
+ let fakeJSON = customJSON ?? fakeModelJSON
|
|
|
+ fakeSession.data = fakeJSON.data(using: .utf8)
|
|
|
fakeSession.response = HTTPURLResponse(
|
|
|
url: url,
|
|
|
statusCode: statusCode,
|
|
|
@@ -1071,19 +1285,25 @@ extension ModelDownloaderUnitTests {
|
|
|
|
|
|
func fakeModelRetriever(fakeSession: MockModelInfoRetrieverSession,
|
|
|
fakeLocalModelInfo: LocalModelInfo? = nil) -> ModelInfoRetriever {
|
|
|
- // TODO: Replace with a fake one so we can check that is was used correctly by the download task.
|
|
|
+ // TODO: Replace with fake to check if it was used correctly by the download task.
|
|
|
let modelInfoRetriever = ModelInfoRetriever(
|
|
|
- modelName: "fakeModelName",
|
|
|
- projectID: "fakeProjectID",
|
|
|
- apiKey: "fakeAPIKey",
|
|
|
- authTokenProvider: successAuthTokenProvider,
|
|
|
- appName: "fakeAppName",
|
|
|
- localModelInfo: fakeLocalModelInfo,
|
|
|
- session: fakeSession
|
|
|
+ modelName: fakeModelName,
|
|
|
+ projectID: fakeProjectID,
|
|
|
+ apiKey: fakeAPIKey,
|
|
|
+ appName: fakeAppName, authTokenProvider: successAuthTokenProvider,
|
|
|
+ session: fakeSession, localModelInfo: fakeLocalModelInfo
|
|
|
)
|
|
|
return modelInfoRetriever
|
|
|
}
|
|
|
|
|
|
+ func checkModel(model: CustomModel) {
|
|
|
+ XCTAssertEqual(model.name, fakeModelName)
|
|
|
+ XCTAssertEqual(model.size, fakeModelSize)
|
|
|
+ XCTAssertEqual(model.hash, fakeModelHash)
|
|
|
+ let modelURL = URL(fileURLWithPath: model.path)
|
|
|
+ XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
|
|
|
+ }
|
|
|
+
|
|
|
func tempFile(name: String = "fake-model-file.tmp") -> URL {
|
|
|
let tempDirectoryURL = URL(fileURLWithPath: NSTemporaryDirectory(),
|
|
|
isDirectory: true)
|
|
|
@@ -1093,32 +1313,38 @@ extension ModelDownloaderUnitTests {
|
|
|
return tempFileURL
|
|
|
}
|
|
|
|
|
|
- func deleteFile(at url: URL) throws {
|
|
|
- try ModelFileManager.removeFile(at: url)
|
|
|
+ func tempModelFile(invalid: Bool = false) -> URL {
|
|
|
+ let modelDirectoryURL = ModelFileManager.modelsDirectory!
|
|
|
+ let modelFileName = invalid ? "fbml_model_fake-model-file.tmp" :
|
|
|
+ "fbml_model@@__FIRAPP_DEFAULT@@\(fakeModelName)"
|
|
|
+ let tempFileURL = modelDirectoryURL.appendingPathComponent(modelFileName)
|
|
|
+ .appendingPathExtension("tflite")
|
|
|
+ let tempData: Data = "fakeModelData".data(using: .utf8)!
|
|
|
+ try? tempData.write(to: tempFileURL)
|
|
|
+ return tempFileURL
|
|
|
}
|
|
|
|
|
|
- func fakeLocalModelInfo(mismatch: Bool) -> LocalModelInfo {
|
|
|
- var hash = "fakeModelHash"
|
|
|
- if mismatch {
|
|
|
- hash = "wrongModelHash"
|
|
|
+ func deleteFile(at path: String) {
|
|
|
+ let url = URL(fileURLWithPath: path)
|
|
|
+ do {
|
|
|
+ try ModelFileManager.removeFile(at: url)
|
|
|
+ } catch {
|
|
|
+ XCTFail("Failed to delete file at \(path).")
|
|
|
}
|
|
|
+ }
|
|
|
+
|
|
|
+ func fakeLocalModelInfo(mismatch: Bool = false) -> LocalModelInfo {
|
|
|
+ let hash = mismatch ? "wrongModelHash" : fakeModelHash
|
|
|
return LocalModelInfo(
|
|
|
- name: "fakeModelName",
|
|
|
+ name: fakeModelName,
|
|
|
modelHash: hash,
|
|
|
- size: 20,
|
|
|
- path: "fakeModelPath"
|
|
|
+ size: 20
|
|
|
)
|
|
|
}
|
|
|
|
|
|
func fakeModelInfo(expired: Bool, oversized: Bool) -> RemoteModelInfo {
|
|
|
- var expiryTime = Date()
|
|
|
- var size = fakeModelSize
|
|
|
- if !expired {
|
|
|
- expiryTime = expiryTime.addingTimeInterval(1000)
|
|
|
- }
|
|
|
- if oversized {
|
|
|
- size = Int(Int64.max) / 4
|
|
|
- }
|
|
|
+ let expiryTime = expired ? Date() : Date().addingTimeInterval(1000)
|
|
|
+ let size = oversized ? Int(Int64.max) / 4 : fakeModelSize
|
|
|
return RemoteModelInfo(name: fakeModelName,
|
|
|
downloadURL: fakeDownloadURL,
|
|
|
modelHash: fakeModelHash,
|