ModelDownloaderUnitTests.swift 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Skip keychain tests on Catalyst and macOS. Tests are skipped because they
  15. // involve interactions with the keychain that require a provisioning profile.
  16. // This is because MLModelDownloader depends on Installations.
  17. // See go/firebase-macos-keychain-popups for more details.
  18. #if !targetEnvironment(macCatalyst) && !os(macOS)
  19. @testable import FirebaseCore
  20. @testable import FirebaseInstallations
  21. @testable import FirebaseMLModelDownloader
  22. import XCTest
  23. #if SWIFT_PACKAGE
  24. @_implementationOnly import GoogleUtilities_Logger
  25. @_implementationOnly import GoogleUtilities_UserDefaults
  26. #else
  27. @_implementationOnly import GoogleUtilities
  28. #endif // SWIFT_PACKAGE
  /// Fixed values used to build the `FirebaseOptions` for the default test app.
  private enum MockOptions {
    /// Google app ID passed to `FirebaseOptions(googleAppID:gcmSenderID:)`.
    static let appID: String = "1:123:ios:123abc"
    /// GCM sender ID passed to `FirebaseOptions(googleAppID:gcmSenderID:)`.
    static let gcmSenderID: String = "mock-sender-id"
    /// Project ID assigned to `FirebaseOptions.projectID`.
    static let projectID: String = "mock-project-id"
    /// API key assigned to `FirebaseOptions.apiKey`.
    static let apiKey: String = "ABcdEf-APIKeyWithValidFormat_0123456789"
  }
  36. // TODO: Create separate files for each class tested.
  37. final class ModelDownloaderUnitTests: XCTestCase {
  // MARK: - Test fixtures

  let fakeModelName: String = "fakeModelName"
  let fakemodelPath: String = "fakemodelPath"
  let fakeModelHash: String = "fakeModelHash"
  let fakeDownloadURL: URL = URL(string: "www.fake-download-url.com")!
  let fakeFileURL: URL = URL(string: "www.fake-model-file.com")!
  let fakeExpiryTime: String = "2021-01-20T04:20:10.220Z"
  let fakeModelSize: Int = 20
  let fakeProjectID: String = "fakeProjectID"
  let fakeAPIKey: String = "fakeAPIKey"
  let fakeAppName: String = "fakeAppName"

  /// JSON model-info payload assembled from the download URL, expiry time and size fixtures.
  var fakeModelJSON: String {
    return """
    {
    "downloadUri":"\(fakeDownloadURL)",
    "expireTime":"\(fakeExpiryTime)",
    "sizeBytes":"\(fakeModelSize)"
    }
    """
  }

  /// Auth-token provider stub that immediately completes with a fixed fake FIS token.
  let successAuthTokenProvider: (@escaping (Result<String, DownloadError>) -> Void) -> Void = {
    completion in
    completion(.success("fakeFISToken"))
  }
  /// Class-level setup: registers the default `FirebaseApp` using `MockOptions` and switches
  /// Firebase logging to debug level.
  override class func setUp() {
    let firebaseOptions = FirebaseOptions(
      googleAppID: MockOptions.appID,
      gcmSenderID: MockOptions.gcmSenderID
    )
    firebaseOptions.apiKey = MockOptions.apiKey
    firebaseOptions.projectID = MockOptions.projectID
    FirebaseApp.configure(options: firebaseOptions)
    FirebaseConfiguration.shared.setLoggerLevel(.debug)
  }
  /// Per-test setup: starts every test from an empty local models directory.
  override func setUp() {
    super.setUp()
    do {
      try ModelFileManager.emptyModelsDirectory()
    } catch {
      // Include the underlying error so a dirty-state failure can actually be diagnosed.
      XCTFail("Could not empty models directory. Error: \(error)")
    }
  }
  /// Class-level teardown: removes the process temporary directory and deletes the default app.
  /// NOTE: this removes the whole `NSTemporaryDirectory()`, not only files these tests created.
  override class func tearDown() {
    let temporaryDirectoryPath = NSTemporaryDirectory()
    try? FileManager.default.removeItem(atPath: temporaryDirectoryPath)
    if let defaultApp = FirebaseApp.app() {
      defaultApp.delete { _ in }
    }
  }
  /// Verifies that requesting a downloader for the default app explicitly, or implicitly,
  /// returns the very same instance.
  func testModelDownloaderInit() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    defaultApp.isDataCollectionDefaultEnabled = false
    let explicitAppDownloader = ModelDownloader.modelDownloader(app: defaultApp)
    let implicitAppDownloader = ModelDownloader.modelDownloader()
    XCTAssert(explicitAppDownloader === implicitAppDownloader)
  }
  /// Verifies that deleting the app discards its downloader, so a new instance is handed out
  /// once the default app has been configured again.
  func testAppDeletion() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    defaultApp.isDataCollectionDefaultEnabled = false
    let downloaderBeforeDeletion = ModelDownloader.modelDownloader()
    defaultApp.delete { didDelete in
      XCTAssertTrue(didDelete)
    }
    // Re-run the class-level configuration to register a fresh default app.
    ModelDownloaderUnitTests.setUp()
    let downloaderAfterDeletion = ModelDownloader.modelDownloader()
    XCTAssert(downloaderBeforeDeletion !== downloaderAfterDeletion)
  }
  /// Verifies that deleting a model that was never downloaded reports a failure.
  func testDeleteInvalidModel() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    // Without an expectation the test could return before the completion runs, and the
    // assertion inside it would never be evaluated.
    let completionExpectation = expectation(description: "Completion handler")
    modelDownloader.deleteDownloadedModel(name: fakeModelName) { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model deletion.")
      case .failure:
        // Expected: setUp() empties the models directory, so there is nothing to delete.
        break
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies that a model name is parsed from an `@@appName@@`-prefixed file name, and that a
  /// file name using a different separator yields `nil`.
  func testGetModelNameFromFilePath() {
    let wellFormedURL = URL(fileURLWithPath: "modelDirectory/@@appName@@\(fakeModelName).tflite")
    XCTAssertEqual(ModelFileManager.getModelNameFromFilePath(wellFormedURL), fakeModelName)
    let malformedURL = URL(fileURLWithPath: "modelDirectory/--appName--\(fakeModelName).tflite")
    XCTAssertNil(ModelFileManager.getModelNameFromFilePath(malformedURL))
  }
  /// Verifies that a model whose info is stored in user defaults is listed with the path of its
  /// file on disk.
  func testListModels() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    // `#function` is "testListModels()"; strip the trailing "()" to get the test name.
    let testName = String(#function.dropLast(2))
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createUnitTestInstance(testName: testName),
      app: testApp
    )
    let modelFileURL = tempModelFile()
    let completionExpectation = expectation(description: "Completion handler")
    fakeLocalModelInfo().writeToDefaults(
      .getUnitTestInstance(testName: testName),
      appName: testApp.name
    )
    modelDownloader.listDownloadedModels { result in
      completionExpectation.fulfill()
      switch result {
      case .success(let models):
        // Success is expected because the model info was written to user defaults above.
        XCTAssertEqual(models.first?.path, modelFileURL.path)
      case .failure(let error):
        XCTFail("Error - \(error)")
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: modelFileURL)
  }
  /// Verifies that an unexpectedly named file in the models directory makes listing fail with
  /// an internal error.
  func testListModelsInvalidPath() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let invalidModelFileURL = tempModelFile(invalid: true)
    let completionExpectation = expectation(description: "Completion handler")
    modelDownloader.listDownloadedModels { result in
      completionExpectation.fulfill()
      guard case let .failure(error) = result else {
        XCTFail("Unexpected success of list models.")
        return
      }
      guard case let .internalError(description) = error else {
        XCTFail("Expected failure due to bad model name.")
        return
      }
      XCTAssertTrue(
        description.contains("List models failed due to unexpected model file name")
      )
    }
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: invalidModelFileURL)
  }
  /// Verifies that a model file without matching info in user defaults makes listing fail with
  /// an internal error.
  func testListModelsNoLocalInfo() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    // `#function` is "testListModelsNoLocalInfo()"; strip the trailing "()".
    let testName = String(#function.dropLast(2))
    let tempModelFileURL = tempModelFile()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createUnitTestInstance(testName: testName),
      app: testApp
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelDownloader.listDownloadedModels { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected success of list models.")
      case let .failure(error):
        switch error {
        // Expected failure since no model info was written to user defaults.
        case let .internalError(description):
          XCTAssertTrue(description.contains("List models failed due to no local model info"))
        default:
          // Message previously said "bad model name", copied from the invalid-path test.
          XCTFail("Expected failure due to no local model info.")
        }
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempModelFileURL)
  }
  /// Checks how `URL` round-trips behave for a path containing a percent-encoded space, and for
  /// one without spaces.
  func testURLConversion() {
    let spacedURL = URL(string: "file:///fakeDir1/fake%20Dir2/fakeFile")!
    XCTAssertEqual(spacedURL, URL(string: spacedURL.absoluteString))
    XCTAssertEqual(spacedURL, URL(fileURLWithPath: spacedURL.path))

    // Strings without spaces should work fine either way.
    let unspacedURL = URL(string: "fakeDir1/fakeDir2/fakeFile")!
    XCTAssertEqual(unspacedURL, URL(string: unspacedURL.absoluteString)!)
    XCTAssertEqual(unspacedURL, URL(string: unspacedURL.path)!)
  }
  /// Smoke test: emits an error-level event through `DeviceLogger` (no assertions; it only has
  /// to run).
  func testDeviceLogging() {
    let testMessage = "This is a test error logged to device."
    DeviceLogger.logEvent(level: .error, message: testMessage, messageCode: .testError)
  }
  /// Compares binary and JSON serialization of `ModelOptions`. Both must decode to the same
  /// value, and the binary form must be smaller.
  func testTelemetryEncoding() {
    let fakeModel = CustomModel(
      name: fakeModelName,
      size: 10,
      path: fakemodelPath,
      hash: fakeModelHash
    )
    var options = ModelOptions()
    options.setModelOptions(modelName: fakeModel.name, modelHash: fakeModel.hash)
    guard
      let binaryPayload = try? options.serializedData(),
      let jsonPayload = try? options.jsonUTF8Data(),
      let decodedFromBinary = try? ModelOptions(serializedData: binaryPayload),
      let decodedFromJSON = try? ModelOptions(jsonUTF8Data: jsonPayload)
    else {
      XCTFail("Encoding error.")
      return
    }
    XCTAssertNotNil(binaryPayload)
    XCTAssertNotNil(jsonPayload)
    XCTAssertLessThan(binaryPayload.count, jsonPayload.count)
    XCTAssertEqual(decodedFromBinary, decodedFromJSON)
  }
  /// Verifies that a 200 response yields new model info matching the fixture values.
  func testGetModelInfoWith200() {
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: fakeSession)
    let completionExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(.modelInfo(remoteModelInfo)):
        XCTAssertEqual(remoteModelInfo.name, self.fakeModelName)
        XCTAssertEqual(remoteModelInfo.downloadURL, self.fakeDownloadURL)
        XCTAssertEqual(remoteModelInfo.size, self.fakeModelSize)
        XCTAssertEqual(remoteModelInfo.modelHash, self.fakeModelHash)
      case .success(.notModified):
        XCTFail("Expected new model info from server.")
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies that a 304 response with matching local model info reports "not modified".
  func testGetModelInfoWith304() {
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let storedInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(
      fakeSession: fakeSession,
      fakeLocalModelInfo: storedInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case .success(.notModified):
        break
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies that a 304 response without any local model info fails with an internal error.
  func testGetModelInfoWith304MissingLocalInfo() {
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let retriever = fakeModelRetriever(fakeSession: fakeSession)
    let completionExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since local model info was not set.")
      case let .failure(.internalError(description: errorMessage)):
        XCTAssertEqual(errorMessage, "Model info was deleted unexpectedly.")
      case .failure:
        XCTFail("Expected failure since local model info was not set.")
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Get model info if model info is not modified but the local model hash does not match.
  func testGetModelInfoWith304HashMismatch() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let localModelInfo = fakeLocalModelInfo(mismatch: true)
    let modelInfoRetriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: localModelInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelInfoRetriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since model hash does not match.")
      case let .failure(.internalError(description: errorMessage)):
        XCTAssertEqual(errorMessage, "Unexpected model hash value.")
      case .failure:
        // Previously copy-pasted as "local model info was not set"; local info IS set here.
        XCTFail("Expected failure since model hash does not match.")
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies that a 400 response for model info maps to `.invalidArgument`.
  func testGetModelInfoWith400() {
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 400)
    let storedInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(
      fakeSession: fakeSession,
      fakeLocalModelInfo: storedInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since model name is invalid.")
      case let .failure(error):
        XCTAssertEqual(error, .invalidArgument)
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Get model info if model name does not exist; expects `.notFound`.
  func testGetModelInfoWith404() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 404)
    let localModelInfo = fakeLocalModelInfo(mismatch: false)
    let modelInfoRetriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: localModelInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelInfoRetriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(fakemodelInfoResult):
        switch fakemodelInfoResult {
        case .modelInfo: XCTFail("Unexpected new model info from server.")
        // Fixed typo in the failure message ("availabl." -> "available.").
        case .notModified: XCTFail("Expected failure since model name is not available.")
        }
      case let .failure(error):
        XCTAssertEqual(error, .notFound)
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies that a 429 (too many requests) response maps to `.resourceExhausted`.
  func testGetModelInfoWith429() {
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 429)
    let storedInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(
      fakeSession: fakeSession,
      fakeLocalModelInfo: storedInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since resource exhausted.")
      case let .failure(error):
        XCTAssertEqual(error, .resourceExhausted)
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies that a 200 file download reports progress and completes with a valid model.
  func testModelDownloadWith200() {
    let downloadedFileURL = tempFile()
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    let progressReported = expectation(description: "Progress handler")
    progressReported.expectedFulfillmentCount = 2
    let taskFinished = expectation(description: "Completion handler")
    let onProgress: ModelDownloadTask.ProgressHandler = { fraction in
      progressReported.fulfill()
      XCTAssertEqual(fraction, 0.4)
    }
    let onCompletion: ModelDownloadTask.Completion = { result in
      taskFinished.fulfill()
      switch result {
      case .success(let model):
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      case .failure(let error):
        XCTFail("Error - \(error)")
      }
    }
    // `#function` is "testModelDownloadWith200()"; strip the trailing "()".
    let testName = String(#function.dropLast(2))
    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: fakeAppName,
      defaults: .createUnitTestInstance(testName: testName),
      downloader: mockDownloader,
      progressHandler: onProgress,
      completion: onCompletion
    )
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)
    // Two identical progress callbacks: 100 of 250 bytes, i.e. 0.4 each time.
    for _ in 0 ..< 2 {
      mockDownloader.progressHandler?(100, 250)
    }
    wait(for: [progressReported], timeout: 0.5)
    let downloaderResponse = FileDownloaderResponse(
      urlResponse: okResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [taskFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Verifies that a 400 download response for expired model info maps to `.expiredDownloadURL`.
  func testModelDownloadWith400Expired() {
    let downloadedFileURL = tempFile()
    let expiredModelInfo = fakeModelInfo(expired: true, oversized: false)
    let badRequestResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    let taskFinished = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { result in
      taskFinished.fulfill()
      guard case let .failure(error) = result else {
        XCTFail("Unexpected successful model download.")
        return
      }
      XCTAssertEqual(error, .expiredDownloadURL)
    }
    // `#function` is "testModelDownloadWith400Expired()"; strip the trailing "()".
    let testName = String(#function.dropLast(2))
    let task = ModelDownloadTask(
      remoteModelInfo: expiredModelInfo,
      appName: fakeAppName,
      defaults: .createUnitTestInstance(testName: testName),
      downloader: mockDownloader,
      completion: onCompletion
    )
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)
    let downloaderResponse = FileDownloaderResponse(
      urlResponse: badRequestResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [taskFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Verifies that a 400 download response for unexpired model info maps to `.invalidArgument`.
  func testModelDownloadWith400Invalid() {
    let downloadedFileURL = tempFile()
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let badRequestResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    let taskFinished = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { result in
      taskFinished.fulfill()
      guard case let .failure(error) = result else {
        XCTFail("Unexpected successful model download.")
        return
      }
      XCTAssertEqual(error, .invalidArgument)
    }
    // `#function` is "testModelDownloadWith400Invalid()"; strip the trailing "()".
    let testName = String(#function.dropLast(2))
    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: fakeAppName,
      defaults: .createUnitTestInstance(testName: testName),
      downloader: mockDownloader,
      completion: onCompletion
    )
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)
    let downloaderResponse = FileDownloaderResponse(
      urlResponse: badRequestResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [taskFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Verifies that a 403 response carrying an "API Disabled." error body maps to
  /// `.permissionDenied`.
  func testGetModelInfoWith403() {
    let apiDisabledBody =
      """
      {
      "error":
      {
      "message":"API Disabled."
      }
      }
      """
    let fakeSession = fakeModelInfoSessionWithURL(
      fakeDownloadURL,
      statusCode: 403,
      customJSON: apiDisabledBody
    )
    let storedInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(
      fakeSession: fakeSession,
      fakeLocalModelInfo: storedInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      guard case let .failure(error) = result else {
        XCTFail("Unexpected new model info from server.")
        return
      }
      XCTAssertEqual(error, .permissionDenied)
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies that several completions merged into a single `ModelDownloadTask` each receive
  /// the downloaded model, alongside the task's own completion.
  func testModelDownloadWithMergeRequests() {
    let downloadedFileURL = tempFile()
    let mergedRequestCount = 5
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    let allCompleted = expectation(description: "Completion handler")
    // One fulfillment per merged completion plus one for the task's original completion.
    allCompleted.expectedFulfillmentCount = mergedRequestCount + 1
    let originalCompletion: ModelDownloadTask.Completion = { result in
      allCompleted.fulfill()
      switch result {
      case .success(let model):
        self.checkModel(model: model)
      case .failure(let error):
        XCTFail("Error - \(error)")
      }
    }
    // `#function` is "testModelDownloadWithMergeRequests()"; strip the trailing "()".
    let testName = String(#function.dropLast(2))
    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: fakeAppName,
      defaults: .createUnitTestInstance(testName: testName),
      downloader: mockDownloader,
      completion: originalCompletion
    )
    for requestNumber in 1 ... mergedRequestCount {
      let mergedCompletion: ModelDownloadTask.Completion = { result in
        allCompleted.fulfill()
        switch result {
        case .success(let model):
          self.checkModel(model: model)
          // Only the completion registered last removes the downloaded model file.
          if requestNumber == mergedRequestCount {
            self.deleteFile(at: model.path)
          }
        case .failure(let error):
          XCTFail("Error - \(error)")
        }
      }
      task.resume()
      task.merge(newCompletion: mergedCompletion)
    }
    let downloaderResponse = FileDownloaderResponse(
      urlResponse: okResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [downloadStarted], timeout: 1)
    wait(for: [allCompleted], timeout: 2)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Verifies, through `ModelDownloader.downloadInfoAndModel`, that concurrent requests for the
  /// same model all complete successfully while the file downloader is started once.
  func testGetModelWithMergeRequests() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: fakeSession)
    let downloadedFileURL = tempFile()
    let requestCount = 6
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()
    // A single (non-multi-fulfill) expectation: the file download must be started once.
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    let allCompleted = expectation(description: "Completion handler")
    allCompleted.expectedFulfillmentCount = requestCount
    for requestNumber in 1 ... requestCount {
      modelDownloader.downloadInfoAndModel(
        modelName: fakeModelName,
        modelInfoRetriever: retriever,
        downloader: mockDownloader,
        conditions: downloadConditions
      ) { result in
        allCompleted.fulfill()
        switch result {
        case .success(let model):
          self.checkModel(model: model)
          // Only the last request's completion removes the downloaded model file.
          if requestNumber == requestCount {
            self.deleteFile(at: model.path)
          }
        case .failure(let error):
          XCTFail("Error - \(error)")
        }
      }
    }
    wait(for: [downloadStarted], timeout: 0.5)
    let downloaderResponse = FileDownloaderResponse(
      urlResponse: okResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [allCompleted], timeout: 1)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if server returns an error due to model not found.
  ///
  /// Drives a `ModelDownloadTask` directly and completes the mock file download with an
  /// HTTP 404 response; the task is expected to fail with `.notFound`.
  func testModelDownloadWith404() {
    let tempFileURL = tempFile()
    let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
    let fakeResponse = HTTPURLResponse(url: fakeFileURL,
                                       statusCode: 404,
                                       httpVersion: nil,
                                       headerFields: nil)!
    let downloader = MockModelFileDownloader()
    let fileDownloaderExpectation = expectation(description: "File downloader")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    let completionExpectation = expectation(description: "Completion handler")
    let taskCompletion: ModelDownloadTask.Completion = { result in
      completionExpectation.fulfill()
      switch result {
      case .success: XCTFail("Unexpected successful model download.")
      case let .failure(error):
        XCTAssertEqual(error, .notFound)
      }
    }
    // `#function` ends in "()"; strip it so the defaults suite is named after the test.
    let testName = String(#function.dropLast(2))
    // Use the shared `fakeAppName` fixture rather than a hard-coded string literal.
    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                              appName: fakeAppName,
                                              defaults: .createUnitTestInstance(
                                                testName: testName
                                              ),
                                              downloader: downloader,
                                              completion: taskCompletion)
    modelDownloadTask.resume()
    wait(for: [fileDownloaderExpectation], timeout: 0.1)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponse,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if the file download fails with a network error.
  ///
  /// Drives a `ModelDownloadTask` directly and completes the mock file download with
  /// `FileDownloaderError.networkError`; the task is expected to fail with
  /// `.failedPrecondition`.
  func testModelDownloadNetworkError() {
    let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
    let downloader = MockModelFileDownloader()
    let fileDownloaderExpectation = expectation(description: "File downloader")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    let completionExpectation = expectation(description: "Completion handler")
    let taskCompletion: ModelDownloadTask.Completion = { result in
      completionExpectation.fulfill()
      switch result {
      case .success: XCTFail("Unexpected successful model download.")
      case let .failure(error):
        XCTAssertEqual(error, .failedPrecondition)
      }
    }
    // `#function` ends in "()"; strip it so the defaults suite is named after the test.
    let testName = String(#function.dropLast(2))
    // Use the shared `fakeAppName` fixture rather than a hard-coded string literal.
    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                              appName: fakeAppName,
                                              defaults: .createUnitTestInstance(
                                                testName: testName
                                              ),
                                              downloader: downloader,
                                              completion: taskCompletion)
    modelDownloadTask.resume()
    wait(for: [fileDownloaderExpectation], timeout: 0.1)
    // The download fails before any file is produced, so no temp file is needed here.
    let offlineError = DownloadError.internalError(description: "Network appears to be offline.")
    downloader
      .completion?(.failure(FileDownloaderError.networkError(offlineError)))
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Get model with a valid server response.
  ///
  /// The fake model-info session and the single file download both answer with HTTP 200,
  /// so the request should succeed with a model matching the fixtures.
  func testGetModelSuccessful() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    // HTTP 200 response for the file download.
    let fakeResponseSucceeded = HTTPURLResponse(url: fakeFileURL,
                                                statusCode: 200,
                                                httpVersion: nil,
                                                headerFields: nil)!
    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(model):
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      case let .failure(error):
        XCTFail("Error - \(error)")
      }
    }
    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseSucceeded,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if download url expired but retry was successful.
  ///
  /// The first file download is answered with HTTP 400 (the "expired URL" case), which must
  /// trigger one more download request; that retry is answered with HTTP 200 and the model
  /// is returned.
  func testGetModelSuccessfulRetry() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let fileDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()

    // Builds a stub HTTP response for the fake file URL with the given status code.
    let stubResponse: (Int) -> HTTPURLResponse = { statusCode in
      HTTPURLResponse(url: self.fakeFileURL,
                      statusCode: statusCode,
                      httpVersion: nil,
                      headerFields: nil)!
    }
    // Completes whichever download is currently pending on the mock downloader.
    let finishDownload: (HTTPURLResponse) -> Void = { response in
      fileDownloader.completion?(.success(FileDownloaderResponse(urlResponse: response,
                                                                  fileURL: downloadedFileURL)))
    }
    // Creates an expectation fulfilled by the next download request and checks its URL.
    let expectDownload: (String) -> XCTestExpectation = { description in
      let downloadRequested = self.expectation(description: description)
      fileDownloader.downloadFileHandler = { url in
        downloadRequested.fulfill()
        XCTAssertEqual(url, self.fakeDownloadURL)
      }
      return downloadRequested
    }

    // Handles initial response.
    let initialDownload = expectDownload("File downloader before retry")
    let modelReturned = expectation(description: "Completion handler.")
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: retriever,
                                         downloader: fileDownloader,
                                         conditions: conditions) { result in
      modelReturned.fulfill()
      switch result {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        // Retry successful - model returned.
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      }
    }
    wait(for: [initialDownload], timeout: 0.5)

    // Handles response after one retry.
    let retriedDownload = expectDownload("File downloader after retry")
    finishDownload(stubResponse(400))
    wait(for: [retriedDownload], timeout: 0.5)

    finishDownload(stubResponse(200))
    wait(for: [modelReturned], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if model doesn't exist.
  ///
  /// The single file download is answered with HTTP 404; the request is expected to fail
  /// with `.notFound`.
  func testGetModelNotFound() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    let fakeResponseNotFound = HTTPURLResponse(url: fakeFileURL,
                                               statusCode: 404,
                                               httpVersion: nil,
                                               headerFields: nil)!
    // This test involves no retry, so the expectation is not labelled "before retry".
    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      // The 404 file response is expected to surface as `.notFound`.
      case .success: XCTFail("Unexpected successful model download.")
      case let .failure(error):
        switch error {
        case .notFound: break
        // Report the actual error so an unexpected failure is diagnosable.
        default: XCTFail("Expected failure due to model not found, got \(error).")
        }
      }
    }
    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseNotFound,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if invalid or insufficient permissions.
  ///
  /// The single file download is answered with HTTP 403; the request is expected to fail
  /// with `.permissionDenied`.
  func testGetModelUnauthorized() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    // HTTP 403 (forbidden) response for the file download.
    let fakeResponseForbidden = HTTPURLResponse(url: fakeFileURL,
                                                statusCode: 403,
                                                httpVersion: nil,
                                                headerFields: nil)!
    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      case .success: XCTFail("Unexpected successful model download.")
      case let .failure(error):
        switch error {
        case .permissionDenied: break
        // Report the actual error so an unexpected failure is diagnosable.
        default: XCTFail("Expected failure due to bad permissions, got \(error).")
        }
      }
    }
    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseForbidden,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if download url expired and retry url also expired.
  ///
  /// Both the initial download and the single retry are answered with HTTP 400; the request
  /// is expected to fail with `.internalError("Unable to update expired model info.")`.
  func testGetModelFailedRetry() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let fileDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    let expiredResponse = HTTPURLResponse(url: fakeFileURL,
                                          statusCode: 400,
                                          httpVersion: nil,
                                          headerFields: nil)!

    // Answers the pending download with the "expired URL" (HTTP 400) response.
    let finishWithExpiredURL: () -> Void = {
      fileDownloader.completion?(.success(FileDownloaderResponse(urlResponse: expiredResponse,
                                                                  fileURL: downloadedFileURL)))
    }
    // Creates an expectation fulfilled by the next download request and checks its URL.
    let expectDownload: (String) -> XCTestExpectation = { description in
      let downloadRequested = self.expectation(description: description)
      fileDownloader.downloadFileHandler = { url in
        downloadRequested.fulfill()
        XCTAssertEqual(url, self.fakeDownloadURL)
      }
      return downloadRequested
    }

    // Handles initial response.
    let initialDownload = expectDownload("File downloader before retry")
    let requestFailed = expectation(description: "Completion handler")
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: retriever,
                                         downloader: fileDownloader,
                                         conditions: conditions) { result in
      requestFailed.fulfill()
      // Retry failed due to another expired url - error returned.
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      case let .failure(.internalError(description: errorMessage)):
        XCTAssertEqual(errorMessage, "Unable to update expired model info.")
      case .failure:
        XCTFail("Expected failure due to expired model info.")
      }
    }
    wait(for: [initialDownload], timeout: 0.5)

    // Handles response after one retry.
    let retriedDownload = expectDownload("File downloader after retry")
    finishWithExpiredURL()
    wait(for: [retriedDownload], timeout: 0.5)

    finishWithExpiredURL()
    wait(for: [requestFailed], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if multiple retries all returned expired urls.
  ///
  /// With `numberOfRetries = 2`, the initial download and both retries are answered with
  /// HTTP 400, after which the request is expected to fail with
  /// `.internalError("Unable to update expired model info.")`.
  func testGetModelMultipleFailedRetries() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    modelDownloader.numberOfRetries = 2
    let conditions = ModelDownloadConditions()
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let fileDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    let expiredResponse = HTTPURLResponse(url: fakeFileURL,
                                          statusCode: 400,
                                          httpVersion: nil,
                                          headerFields: nil)!

    // Answers the pending download with the "expired URL" (HTTP 400) response.
    let finishWithExpiredURL: () -> Void = {
      fileDownloader.completion?(.success(FileDownloaderResponse(urlResponse: expiredResponse,
                                                                  fileURL: downloadedFileURL)))
    }
    // Creates an expectation fulfilled by the next download request and checks its URL.
    let expectDownload: (String) -> XCTestExpectation = { description in
      let downloadRequested = self.expectation(description: description)
      fileDownloader.downloadFileHandler = { url in
        downloadRequested.fulfill()
        XCTAssertEqual(url, self.fakeDownloadURL)
      }
      return downloadRequested
    }

    // Handles initial response.
    let initialDownload = expectDownload("File downloader before retry")
    let requestFailed = expectation(description: "Completion handler")
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: retriever,
                                         downloader: fileDownloader,
                                         conditions: conditions) { result in
      requestFailed.fulfill()
      // Two retries failed due to expired urls - error returned.
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      case let .failure(.internalError(description: errorMessage)):
        XCTAssertEqual(errorMessage, "Unable to update expired model info.")
      case .failure:
        XCTFail("Expected failure due to expired model info.")
      }
    }
    wait(for: [initialDownload], timeout: 0.5)

    // Each expired answer should trigger one retry download; answer every retry as expired too.
    let retryDescriptions = ["file downloader after first retry",
                             "File downloader after second retry"]
    for retryDescription in retryDescriptions {
      let retriedDownload = expectDownload(retryDescription)
      finishWithExpiredURL()
      wait(for: [retriedDownload], timeout: 0.5)
    }

    finishWithExpiredURL()
    wait(for: [requestFailed], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if a retry finally returns an unexpired download url.
  ///
  /// With `numberOfRetries = 4`, the initial download and the first retry are answered with
  /// HTTP 400 and the second retry with HTTP 200, so the model is returned.
  func testGetModelMultipleRetriesWithSuccess() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    modelDownloader.numberOfRetries = 4
    let conditions = ModelDownloadConditions()
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let fileDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()

    // Builds a stub HTTP response for the fake file URL with the given status code.
    let stubResponse: (Int) -> HTTPURLResponse = { statusCode in
      HTTPURLResponse(url: self.fakeFileURL,
                      statusCode: statusCode,
                      httpVersion: nil,
                      headerFields: nil)!
    }
    // Completes whichever download is currently pending on the mock downloader.
    let finishDownload: (HTTPURLResponse) -> Void = { response in
      fileDownloader.completion?(.success(FileDownloaderResponse(urlResponse: response,
                                                                  fileURL: downloadedFileURL)))
    }
    // Creates an expectation fulfilled by the next download request and checks its URL.
    let expectDownload: (String) -> XCTestExpectation = { description in
      let downloadRequested = self.expectation(description: description)
      fileDownloader.downloadFileHandler = { url in
        downloadRequested.fulfill()
        XCTAssertEqual(url, self.fakeDownloadURL)
      }
      return downloadRequested
    }

    // Handles initial response.
    let initialDownload = expectDownload("File downloader before retry")
    let modelReturned = expectation(description: "Completion handler")
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: retriever,
                                         downloader: fileDownloader,
                                         conditions: conditions) { result in
      modelReturned.fulfill()
      switch result {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        // One retry failed but the next retry succeeded - model returned.
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      }
    }
    wait(for: [initialDownload], timeout: 0.5)

    // Answer the initial download and the first retry as expired; each triggers a retry.
    let retryDescriptions = ["File downloader after first retry",
                             "File downloader after second retry"]
    for retryDescription in retryDescriptions {
      let retriedDownload = expectDownload(retryDescription)
      finishDownload(stubResponse(400))
      wait(for: [retriedDownload], timeout: 0.5)
    }

    finishDownload(stubResponse(200))
    wait(for: [modelReturned], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  1155. }
  /// Mock URL session for testing.
  ///
  /// Every `getModelInfo(with:completion:)` call ignores the request and synchronously
  /// invokes `completion` with the currently configured `data`, `response` and `error`.
  class MockModelInfoRetrieverSession: ModelInfoRetrieverSession {
    /// Response body handed to the completion handler.
    var data: Data?
    /// URL response handed to the completion handler.
    var response: URLResponse?
    /// Error handed to the completion handler.
    var error: Error?

    func getModelInfo(with request: URLRequest,
                      completion: @escaping (Data?, URLResponse?, Error?) -> Void) {
      let reply = (body: data, urlResponse: response, failure: error)
      completion(reply.body, reply.urlResponse, reply.failure)
    }
  }
  /// Mock file downloader that never touches the network.
  ///
  /// `downloadFile(with:progressHandler:completion:)` only records the handlers it receives and
  /// then notifies `downloadFileHandler`. Tests finish the "download" later by invoking the
  /// stored `completion` themselves.
  class MockModelFileDownloader: FileDownloader {
    /// Progress handler from the most recent `downloadFile` call.
    var progressHandler: ProgressHandler?
    /// Completion handler from the most recent `downloadFile` call.
    var completion: CompletionHandler?
    /// Test hook invoked with the requested URL each time a download is started.
    var downloadFileHandler: ((_ url: URL) -> Void)?

    func downloadFile(with url: URL, progressHandler: @escaping ProgressHandler,
                      completion: @escaping CompletionHandler) {
      // Store both handlers before notifying the hook, so the hook already sees them.
      self.completion = completion
      self.progressHandler = progressHandler
      if let onDownloadRequested = downloadFileHandler {
        onDownloadRequested(url)
      }
    }
  }
  extension GULUserDefaults {
    /// Returns user defaults for the per-test suite `com.google.firebase.ml.test.<testName>`.
    ///
    /// NOTE(review): this does not clear values persisted in that suite by earlier runs; it
    /// constructs exactly the same instance as `getUnitTestInstance(testName:)`.
    static func createUnitTestInstance(testName: String) -> GULUserDefaults {
      return GULUserDefaults(suiteName: unitTestSuiteName(for: testName))
    }

    /// Returns user defaults for the same per-test suite as `createUnitTestInstance(testName:)`.
    ///
    /// Presumably values written through either instance are visible through the other,
    /// because both are opened on the same suite name.
    static func getUnitTestInstance(testName: String) -> GULUserDefaults {
      return GULUserDefaults(suiteName: unitTestSuiteName(for: testName))
    }

    /// Single source of truth for the per-test suite name, so both accessors always agree.
    private static func unitTestSuiteName(for testName: String) -> String {
      return "com.google.firebase.ml.test.\(testName)"
    }
  }
  // MARK: - Helpers

  extension ModelDownloaderUnitTests {
    /// Returns a mock model-info session whose canned reply has the given `url` and
    /// `statusCode`, an `Etag: fakeModelHash` header, and `customJSON` (or `fakeModelJSON`
    /// when `customJSON` is nil) as the body.
    func fakeModelInfoSessionWithURL(_ url: URL, statusCode: Int,
                                     customJSON: String? = nil) -> MockModelInfoRetrieverSession {
      let fakeSession = MockModelInfoRetrieverSession()
      let fakeJSON = customJSON ?? fakeModelJSON
      fakeSession.data = Data(fakeJSON.utf8)
      fakeSession.response = HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: ["Etag": "fakeModelHash"]
      )
      fakeSession.error = nil
      return fakeSession
    }

    /// Returns a `ModelInfoRetriever` configured with the fake project/app fixtures that
    /// talks to `fakeSession`, optionally seeded with `fakeLocalModelInfo`.
    func fakeModelRetriever(fakeSession: MockModelInfoRetrieverSession,
                            fakeLocalModelInfo: LocalModelInfo? = nil) -> ModelInfoRetriever {
      // TODO: Replace with fake to check if it was used correctly by the download task.
      let modelInfoRetriever = ModelInfoRetriever(
        modelName: fakeModelName,
        projectID: fakeProjectID,
        apiKey: fakeAPIKey,
        appName: fakeAppName, authTokenProvider: successAuthTokenProvider,
        session: fakeSession, localModelInfo: fakeLocalModelInfo
      )
      return modelInfoRetriever
    }

    /// Asserts that `model` matches the fake name/size/hash fixtures and that its file is
    /// reachable on disk.
    func checkModel(model: CustomModel) {
      XCTAssertEqual(model.name, fakeModelName)
      XCTAssertEqual(model.size, fakeModelSize)
      XCTAssertEqual(model.hash, fakeModelHash)
      let modelURL = URL(fileURLWithPath: model.path)
      XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
    }

    /// Writes placeholder bytes to a file called `name` in the temporary directory and
    /// returns its URL. The write is best-effort (`try?`); failures are not reported here.
    func tempFile(name: String = "fake-model-file.tmp") -> URL {
      let tempDirectoryURL = URL(fileURLWithPath: NSTemporaryDirectory(),
                                 isDirectory: true)
      let tempFileURL = tempDirectoryURL.appendingPathComponent(name)
      try? Data("fakeModelData".utf8).write(to: tempFileURL)
      return tempFileURL
    }

    /// Writes placeholder bytes into the models directory and returns the file URL.
    ///
    /// - Parameter invalid: When `false`, the file is named
    ///   `fbml_model@@__FIRAPP_DEFAULT@@<fakeModelName>.tflite`; when `true`, it is named
    ///   `fbml_model_fake-model-file.tmp.tflite`, which does not follow that pattern.
    /// The write is best-effort (`try?`); failures are not reported here.
    func tempModelFile(invalid: Bool = false) -> URL {
      // Force-unwrapped: the helper cannot proceed without a models directory.
      let modelDirectoryURL = ModelFileManager.modelsDirectory!
      let modelFileName = invalid ? "fbml_model_fake-model-file.tmp" :
        "fbml_model@@__FIRAPP_DEFAULT@@\(fakeModelName)"
      let tempFileURL = modelDirectoryURL.appendingPathComponent(modelFileName)
        .appendingPathExtension("tflite")
      try? Data("fakeModelData".utf8).write(to: tempFileURL)
      return tempFileURL
    }

    /// Removes the file at `path`, failing the current test (with the underlying error) if
    /// removal throws.
    func deleteFile(at path: String) {
      let url = URL(fileURLWithPath: path)
      do {
        try ModelFileManager.removeFile(at: url)
      } catch {
        XCTFail("Failed to delete file at \(path): \(error)")
      }
    }

    /// Returns local model info for the fake model, with a wrong hash when `mismatch` is true.
    func fakeLocalModelInfo(mismatch: Bool = false) -> LocalModelInfo {
      let hash = mismatch ? "wrongModelHash" : fakeModelHash
      return LocalModelInfo(
        name: fakeModelName,
        modelHash: hash,
        size: 20
      )
    }

    /// Returns remote model info for the fake model.
    ///
    /// - Parameters:
    ///   - expired: When `true`, the URL expiry time is "now"; otherwise it is 1000 seconds ahead.
    ///   - oversized: When `true`, the size is `Int.max / 4` instead of `fakeModelSize`.
    func fakeModelInfo(expired: Bool, oversized: Bool) -> RemoteModelInfo {
      let expiryTime = expired ? Date() : Date().addingTimeInterval(1000)
      // `Int.max / 4` equals the former `Int(Int64.max) / 4` on 64-bit platforms, but does not
      // trap on platforms with a 32-bit `Int`, where `Int(Int64.max)` overflows.
      let size = oversized ? Int.max / 4 : fakeModelSize
      return RemoteModelInfo(name: fakeModelName,
                             downloadURL: fakeDownloadURL,
                             modelHash: fakeModelHash,
                             size: size,
                             urlExpiryTime: expiryTime)
    }
  }
  1269. #endif // !targetEnvironment(macCatalyst) && !os(macOS)