ModelDownloaderUnitTests.swift 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Skip keychain tests on Catalyst and macOS. Tests are skipped because they
  15. // involve interactions with the keychain that require a provisioning profile.
  16. // This is because MLModelDownloader depends on Installations.
  17. // See go/firebase-macos-keychain-popups for more details.
  18. #if !targetEnvironment(macCatalyst) && !os(macOS)
  19. import XCTest
  20. @testable import FirebaseCore
  21. @testable import FirebaseInstallations
  22. @testable import FirebaseMLModelDownloader
  /// Placeholder option values used to configure the default `FirebaseApp` for these tests.
  private enum MockOptions {
    /// Google app ID passed to `FirebaseOptions(googleAppID:gcmSenderID:)`.
    static let appID: String = "1:123:ios:123abc"
    /// GCM sender ID passed to `FirebaseOptions(googleAppID:gcmSenderID:)`.
    static let gcmSenderID: String = "mock-sender-id"
    /// Value assigned to `FirebaseOptions.projectID`.
    static let projectID: String = "mock-project-id"
    /// Value assigned to `FirebaseOptions.apiKey`.
    static let apiKey: String = "ABcdEf-APIKeyWithValidFormat_0123456789"
  }
  30. // TODO: Create separate files for each class tested.
  31. final class ModelDownloaderUnitTests: XCTestCase {
  // MARK: - Fixtures

  /// Name of the fake model used throughout the tests.
  let fakeModelName: String = "fakeModelName"
  /// Fake on-device model path used when building a `CustomModel`.
  let fakemodelPath: String = "fakemodelPath"
  /// Fake model hash value.
  let fakeModelHash: String = "fakeModelHash"
  /// Fake model download URL (note: no scheme, so this parses as a relative URL).
  let fakeDownloadURL: URL = URL(string: "www.fake-download-url.com")!
  /// Fake URL used for simulated HTTP responses of the file downloader.
  let fakeFileURL: URL = URL(string: "www.fake-model-file.com")!
  /// Fake expiry timestamp (ISO 8601, UTC) embedded in `fakeModelJSON`.
  let fakeExpiryTime: String = "2021-01-20T04:20:10.220Z"
  /// Fake model size embedded in `fakeModelJSON`.
  let fakeModelSize: Int = 20
  let fakeProjectID: String = "fakeProjectID"
  let fakeAPIKey: String = "fakeAPIKey"
  let fakeAppName: String = "fakeAppName"

  /// JSON payload built from the fixtures above (download URI, expiry time, size).
  /// `sizeBytes` is encoded as a JSON string rather than a number.
  var fakeModelJSON: String {
    return """
    {
    "downloadUri":"\(fakeDownloadURL)",
    "expireTime":"\(fakeExpiryTime)",
    "sizeBytes":"\(fakeModelSize)"
    }
    """
  }

  /// Auth token provider that immediately succeeds with a fixed fake FIS token.
  let successAuthTokenProvider: (@escaping (Result<String, DownloadError>) -> Void) -> Void = {
    completion in
    completion(.success("fakeFISToken"))
  }
  /// Configures the default `FirebaseApp` from `MockOptions` and turns on debug logging.
  ///
  /// XCTest runs this once before the class's first test; `testAppDeletion` also calls it
  /// directly to recreate the default app after deleting it.
  override class func setUp() {
    let mockOptions = FirebaseOptions(googleAppID: MockOptions.appID,
                                      gcmSenderID: MockOptions.gcmSenderID)
    mockOptions.projectID = MockOptions.projectID
    mockOptions.apiKey = MockOptions.apiKey
    FirebaseApp.configure(options: mockOptions)
    FirebaseConfiguration.shared.setLoggerLevel(.debug)
  }
  /// Runs before every test: empties the models directory so that model files written by
  /// one test cannot be observed by another.
  override func setUp() {
    super.setUp()
    do {
      try ModelFileManager.emptyModelsDirectory()
    } catch {
      // Include the underlying error so a failure here is diagnosable.
      XCTFail("Could not empty models directory: \(error)")
    }
  }
  /// Runs once after all tests of the class: best-effort cleanup of temporary files, then
  /// deletes the default `FirebaseApp` configured in `setUp()`.
  override class func tearDown() {
    // NOTE(review): this removes the process's entire temporary directory, not only the
    // files created by these tests. Failures are deliberately ignored (best effort).
    try? FileManager.default.removeItem(atPath: NSTemporaryDirectory())
    FirebaseApp.app()?.delete { _ in }
    super.tearDown()
  }
  /// Checks that the same downloader instance is returned for an app, whether the app is
  /// passed explicitly or the default app is implied.
  func testModelDownloaderInit() {
    guard let defaultApp = FirebaseApp.app() else {
      return XCTFail("Default app was not configured.")
    }
    defaultApp.isDataCollectionDefaultEnabled = false
    let explicitAppDownloader = ModelDownloader.modelDownloader(app: defaultApp)
    let defaultAppDownloader = ModelDownloader.modelDownloader()
    XCTAssert(explicitAppDownloader === defaultAppDownloader)
  }
  /// Checks that deleting an app drops its associated downloader, so a different instance is
  /// returned once the default app has been configured again.
  func testAppDeletion() {
    guard let defaultApp = FirebaseApp.app() else {
      return XCTFail("Default app was not configured.")
    }
    defaultApp.isDataCollectionDefaultEnabled = false
    let downloaderBeforeDeletion = ModelDownloader.modelDownloader()
    defaultApp.delete { deleted in
      XCTAssertTrue(deleted)
    }
    // Recreate the default app. NOTE(review): assumes `delete` has completed by this point.
    ModelDownloaderUnitTests.setUp()
    let downloaderAfterDeletion = ModelDownloader.modelDownloader()
    XCTAssert(downloaderBeforeDeletion !== downloaderAfterDeletion)
  }
  /// Test invalid model file deletion: deleting a model that was never downloaded must fail.
  func testDeleteInvalidModel() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    // Wait for the completion handler. If it is invoked asynchronously, the test could
    // otherwise end before the assertions below run and pass without checking anything.
    let completionExpectation = expectation(description: "Completion handler")
    modelDownloader.deleteDownloadedModel(name: fakeModelName) { result in
      completionExpectation.fulfill()
      switch result {
      case .success: XCTFail("Unexpected successful model deletion.")
      case let .failure(error):
        XCTAssertNotNil(error)
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Checks model-name extraction from file paths: a file name with the `@@appName@@` prefix
  /// yields the model name, while one with `--appName--` yields nil.
  func testGetModelNameFromFilePath() {
    let wellFormedURL = URL(fileURLWithPath: "modelDirectory/@@appName@@\(fakeModelName).tflite")
    XCTAssertEqual(ModelFileManager.getModelNameFromFilePath(wellFormedURL), fakeModelName)

    let malformedURL = URL(fileURLWithPath: "modelDirectory/--appName--\(fakeModelName).tflite")
    XCTAssertNil(ModelFileManager.getModelNameFromFilePath(malformedURL))
  }
  /// Checks that a model file with matching model info in user defaults is listed.
  func testListModels() {
    guard let testApp = FirebaseApp.app() else {
      return XCTFail("Default app was not configured.")
    }
    testApp.isDataCollectionDefaultEnabled = false
    // This test's name without the trailing "()", passed to the unit-test defaults helpers.
    let testName = String(#function.dropLast(2))
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createUnitTestInstance(testName: testName),
      app: testApp
    )
    let modelFileURL = tempModelFile()
    defer { try? ModelFileManager.removeFile(at: modelFileURL) }
    let listExpectation = expectation(description: "Completion handler")
    fakeLocalModelInfo().writeToDefaults(
      .getUnitTestInstance(testName: testName),
      appName: testApp.name
    )
    modelDownloader.listDownloadedModels { result in
      listExpectation.fulfill()
      switch result {
      case let .success(models):
        // Model info was written to user defaults above, so the file should be listed.
        XCTAssertEqual(models.first?.path, modelFileURL.path)
      case let .failure(error):
        XCTFail("Error - \(error)")
      }
    }
    wait(for: [listExpectation], timeout: 0.5)
  }
  /// Checks that listing fails when the models directory holds a file whose name does not
  /// follow the expected model file naming.
  func testListModelsInvalidPath() {
    guard let testApp = FirebaseApp.app() else {
      return XCTFail("Default app was not configured.")
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let badModelFileURL = tempModelFile(invalid: true)
    defer { try? ModelFileManager.removeFile(at: badModelFileURL) }
    let listExpectation = expectation(description: "Completion handler")
    modelDownloader.listDownloadedModels { result in
      listExpectation.fulfill()
      guard case let .failure(error) = result else {
        return XCTFail("Unexpected success of list models.")
      }
      if case let .internalError(description) = error {
        XCTAssertTrue(description
          .contains("List models failed due to unexpected model file name"))
      } else {
        XCTFail("Expected failure due to bad model name.")
      }
    }
    wait(for: [listExpectation], timeout: 0.5)
  }
  /// Test listing models without local info in user defaults: a model file exists on disk,
  /// but no model info was written for it, so listing must fail.
  func testListModelsNoLocalInfo() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let testName = String(#function.dropLast(2))
    let tempModelFileURL = tempModelFile()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createUnitTestInstance(testName: testName),
      app: testApp
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelDownloader.listDownloadedModels { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected success of list models.")
      case let .failure(error):
        switch error {
        // Expected failure since no model info was written to user defaults.
        case let .internalError(description):
          XCTAssertTrue(description.contains("List models failed due to no local model info"))
        default:
          XCTFail("Expected failure due to no local model info.")
        }
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempModelFileURL)
  }
  /// Documents how `URL` conversions behave when the path contains a space.
  func testURLConversion() {
    // A path with a (percent-encoded) space converts back via `URL(fileURLWithPath:)`,
    // but `URL(string:)` on the decoded path returns nil.
    // NOTE(review): SDKs that percent-encode invalid characters in `URL(string:)`
    // (reportedly apps linked on iOS 17+) may make the `XCTAssertNil` below fail — confirm.
    let urlWithSpace = URL(string: "file:///fakeDir1/fake%20Dir2/fakeFile")!
    XCTAssertEqual(urlWithSpace, URL(string: urlWithSpace.absoluteString))
    XCTAssertEqual(urlWithSpace, URL(fileURLWithPath: urlWithSpace.path))
    XCTAssertNil(URL(string: urlWithSpace.path))

    // Without spaces, `URL(string:)` round-trips both the absolute string and the path.
    let urlWithoutSpace = URL(string: "fakeDir1/fakeDir2/fakeFile")!
    XCTAssertEqual(urlWithoutSpace, URL(string: urlWithoutSpace.absoluteString)!)
    XCTAssertEqual(urlWithoutSpace, URL(string: urlWithoutSpace.path)!)
  }
  /// Smoke test for on-device logging: logs one error-level event; nothing is asserted.
  func testDeviceLogging() {
    let message = "This is a test error logged to device."
    DeviceLogger.logEvent(level: .error, message: message, messageCode: .testError)
  }
  /// Compare proto serialization methods: `ModelOptions` must round-trip through both the
  /// binary and the JSON encodings, and the binary form must be the smaller of the two.
  func testTelemetryEncoding() {
    let fakeModel = CustomModel(
      name: fakeModelName,
      size: 10,
      path: fakemodelPath,
      hash: fakeModelHash
    )
    var modelOptions = ModelOptions()
    modelOptions.setModelOptions(modelName: fakeModel.name, modelHash: fakeModel.hash)
    do {
      let binaryData = try modelOptions.serializedData()
      let jsonData = try modelOptions.jsonUTF8Data()
      let binaryEvent = try ModelOptions(serializedData: binaryData)
      let jsonEvent = try ModelOptions(jsonUTF8Data: jsonData)
      XCTAssertLessThan(binaryData.count, jsonData.count)
      XCTAssertEqual(binaryEvent, jsonEvent)
    } catch {
      // Report the actual encoding/decoding error instead of discarding it with `try?`.
      XCTFail("Encoding error: \(error)")
    }
  }
  /// Checks that a 200 response yields new model info matching the fake fixtures.
  func testGetModelInfoWith200() {
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let retrievalExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      retrievalExpectation.fulfill()
      switch result {
      case let .success(.modelInfo(remoteModelInfo)):
        XCTAssertEqual(remoteModelInfo.name, self.fakeModelName)
        XCTAssertEqual(remoteModelInfo.downloadURL, self.fakeDownloadURL)
        XCTAssertEqual(remoteModelInfo.size, self.fakeModelSize)
        XCTAssertEqual(remoteModelInfo.modelHash, self.fakeModelHash)
      case .success(.notModified):
        XCTFail("Expected new model info from server.")
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      }
    }
    wait(for: [retrievalExpectation], timeout: 0.5)
  }
  /// Checks that a 304 response with matching local model info is reported as not modified.
  func testGetModelInfoWith304() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let retriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: fakeLocalModelInfo(mismatch: false)
    )
    let retrievalExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      retrievalExpectation.fulfill()
      switch result {
      case .success(.notModified):
        break
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      }
    }
    wait(for: [retrievalExpectation], timeout: 0.5)
  }
  /// Checks that a 304 response fails with an internal error when no local model info exists.
  func testGetModelInfoWith304MissingLocalInfo() {
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    )
    let retrievalExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      retrievalExpectation.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since local model info was not set.")
      case let .failure(.internalError(description: errorMessage)):
        XCTAssertEqual(errorMessage, "Model info was deleted unexpectedly.")
      case .failure:
        XCTFail("Expected failure since local model info was not set.")
      }
    }
    wait(for: [retrievalExpectation], timeout: 0.5)
  }
  /// Get model info if model info is not modified but the local model hash does not match
  /// the server's: the retriever must fail with an internal error.
  func testGetModelInfoWith304HashMismatch() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let localModelInfo = fakeLocalModelInfo(mismatch: true)
    let modelInfoRetriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: localModelInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelInfoRetriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(fakemodelInfoResult):
        switch fakemodelInfoResult {
        case .modelInfo: XCTFail("Unexpected new model info from server.")
        case .notModified: XCTFail("Expected failure since model hash does not match.")
        }
      case let .failure(error):
        switch error {
        case let .internalError(description: errorMessage):
          XCTAssertEqual(errorMessage, "Unexpected model hash value.")
        default:
          XCTFail("Expected failure since model hash does not match.")
        }
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Checks that a 400 response (invalid model name) fails with `.invalidArgument`.
  func testGetModelInfoWith400() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 400)
    let retriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: fakeLocalModelInfo(mismatch: false)
    )
    let retrievalExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      retrievalExpectation.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since model name is invalid.")
      case let .failure(error):
        XCTAssertEqual(error, .invalidArgument)
      }
    }
    wait(for: [retrievalExpectation], timeout: 0.5)
  }
  /// Get model info if model name does not exist: the retriever must fail with `.notFound`.
  func testGetModelInfoWith404() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 404)
    let localModelInfo = fakeLocalModelInfo(mismatch: false)
    let modelInfoRetriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: localModelInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelInfoRetriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(fakemodelInfoResult):
        switch fakemodelInfoResult {
        case .modelInfo: XCTFail("Unexpected new model info from server.")
        case .notModified: XCTFail("Expected failure since model name is not available.")
        }
      case let .failure(error):
        XCTAssertEqual(error, .notFound)
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Checks that a 429 response (too many requests) fails with `.resourceExhausted`.
  func testGetModelInfoWith429() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 429)
    let retriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: fakeLocalModelInfo(mismatch: false)
    )
    let retrievalExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      retrievalExpectation.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since resource exhausted.")
      case let .failure(error):
        XCTAssertEqual(error, .resourceExhausted)
      }
    }
    wait(for: [retrievalExpectation], timeout: 0.5)
  }
  /// Get model if server returns a model file.
  ///
  /// Drives a `ModelDownloadTask` through a mock file downloader: the download must start at
  /// `fakeDownloadURL`, forward two progress updates, and complete with a model.
  func testModelDownloadWith200() {
    let downloadedFileURL = tempFile()
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }
    let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
    let okResponse = HTTPURLResponse(url: fakeFileURL,
                                     statusCode: 200,
                                     httpVersion: nil,
                                     headerFields: nil)!

    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let progressReported = expectation(description: "Progress handler")
    progressReported.expectedFulfillmentCount = 2
    let downloadFinished = expectation(description: "Completion handler")

    let onProgress: ModelDownloadTask.ProgressHandler = { fractionCompleted in
      progressReported.fulfill()
      XCTAssertEqual(fractionCompleted, 0.4)
    }
    let onCompletion: ModelDownloadTask.Completion = { result in
      downloadFinished.fulfill()
      switch result {
      case let .success(model):
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      case let .failure(error):
        XCTFail("Error - \(error)")
      }
    }

    let task = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                 appName: "fakeAppName",
                                 defaults: .createUnitTestInstance(
                                   testName: String(#function.dropLast(2))
                                 ),
                                 downloader: mockDownloader,
                                 progressHandler: onProgress,
                                 completion: onCompletion)
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)

    // Simulate two progress callbacks of (100, 250); each is expected to surface as 0.4.
    mockDownloader.progressHandler?(100, 250)
    mockDownloader.progressHandler?(100, 250)
    wait(for: [progressReported], timeout: 0.5)

    mockDownloader.completion?(
      .success(FileDownloaderResponse(urlResponse: okResponse, fileURL: downloadedFileURL))
    )
    wait(for: [downloadFinished], timeout: 0.5)
  }
  /// Get model if server returns an error due to expired url: a 400 response for expired
  /// model info must complete with `.expiredDownloadURL`.
  func testModelDownloadWith400Expired() {
    let downloadedFileURL = tempFile()
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }
    let expiredModelInfo = fakeModelInfo(expired: true, oversized: false)
    let badRequestResponse = HTTPURLResponse(url: fakeFileURL,
                                             statusCode: 400,
                                             httpVersion: nil,
                                             headerFields: nil)!

    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let downloadFinished = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { result in
      downloadFinished.fulfill()
      if case let .failure(error) = result {
        XCTAssertEqual(error, .expiredDownloadURL)
      } else {
        XCTFail("Unexpected successful model download.")
      }
    }

    let task = ModelDownloadTask(remoteModelInfo: expiredModelInfo,
                                 appName: "fakeAppName",
                                 defaults: .createUnitTestInstance(
                                   testName: String(#function.dropLast(2))
                                 ),
                                 downloader: mockDownloader,
                                 completion: onCompletion)
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)
    mockDownloader.completion?(
      .success(FileDownloaderResponse(urlResponse: badRequestResponse,
                                      fileURL: downloadedFileURL))
    )
    wait(for: [downloadFinished], timeout: 0.5)
  }
  /// Get model if server returns an error due to invalid model name: a 400 response for
  /// unexpired model info must complete with `.invalidArgument`.
  func testModelDownloadWith400Invalid() {
    let downloadedFileURL = tempFile()
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }
    let currentModelInfo = fakeModelInfo(expired: false, oversized: false)
    let badRequestResponse = HTTPURLResponse(url: fakeFileURL,
                                             statusCode: 400,
                                             httpVersion: nil,
                                             headerFields: nil)!

    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let downloadFinished = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { result in
      downloadFinished.fulfill()
      if case let .failure(error) = result {
        XCTAssertEqual(error, .invalidArgument)
      } else {
        XCTFail("Unexpected successful model download.")
      }
    }

    let task = ModelDownloadTask(remoteModelInfo: currentModelInfo,
                                 appName: "fakeAppName",
                                 defaults: .createUnitTestInstance(
                                   testName: String(#function.dropLast(2))
                                 ),
                                 downloader: mockDownloader,
                                 completion: onCompletion)
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)
    mockDownloader.completion?(
      .success(FileDownloaderResponse(urlResponse: badRequestResponse,
                                      fileURL: downloadedFileURL))
    )
    wait(for: [downloadFinished], timeout: 0.5)
  }
  /// Checks that a 403 response (Firebase ML API disabled or permission denied) fails with
  /// `.permissionDenied`.
  func testGetModelInfoWith403() {
    let apiDisabledJSON =
      """
      {
      "error":
      {
      "message":"API Disabled."
      }
      }
      """
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(
        fakeDownloadURL,
        statusCode: 403,
        customJSON: apiDisabledJSON
      ),
      fakeLocalModelInfo: fakeLocalModelInfo(mismatch: false)
    )
    let retrievalExpectation = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      retrievalExpectation.fulfill()
      if case let .failure(error) = result {
        XCTAssertEqual(error, .permissionDenied)
      } else {
        XCTFail("Unexpected new model info from server.")
      }
    }
    wait(for: [retrievalExpectation], timeout: 0.5)
  }
  /// Get model if multiple duplicate requests are made.
  ///
  /// Resumes one `ModelDownloadTask` repeatedly and merges an extra completion per request;
  /// the single simulated download must then satisfy the original completion and every
  /// merged one.
  func testModelDownloadWithMergeRequests() {
    let tempFileURL = tempFile()
    let numberOfRequests = 5
    let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
    let fakeResponse = HTTPURLResponse(url: fakeFileURL,
                                       statusCode: 200,
                                       httpVersion: nil,
                                       headerFields: nil)!
    let downloader = MockModelFileDownloader()
    let fileDownloaderExpectation = expectation(description: "File downloader")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    let completionExpectation = expectation(description: "Completion handler")
    // The task's own completion plus one merged completion per request.
    completionExpectation.expectedFulfillmentCount = numberOfRequests + 1
    let taskCompletion: ModelDownloadTask.Completion = { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(model):
        self.checkModel(model: model)
      case let .failure(error):
        XCTFail("Error - \(error)")
      }
    }
    let testName = String(#function.dropLast(2))
    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                              appName: fakeAppName,
                                              defaults: .createUnitTestInstance(
                                                testName: testName
                                              ),
                                              downloader: downloader,
                                              completion: taskCompletion)
    for i in 1 ... numberOfRequests {
      let newTaskCompletion: ModelDownloadTask.Completion = { result in
        completionExpectation.fulfill()
        switch result {
        case let .success(model):
          self.checkModel(model: model)
          // Only the last merged completion removes the downloaded model file.
          if i == numberOfRequests {
            self.deleteFile(at: model.path)
          }
        case let .failure(error):
          XCTFail("Error - \(error)")
        }
      }
      // Repeated `resume()` calls exercise duplicate-request handling.
      modelDownloadTask.resume()
      modelDownloadTask.merge(newCompletion: newTaskCompletion)
    }
    // Make sure the download has started before simulating its completion, as the other
    // download tests do. Otherwise, if the download starts asynchronously,
    // `downloader.completion` may still be nil, the optional call below silently does
    // nothing, and the test times out instead of failing clearly.
    wait(for: [fileDownloaderExpectation], timeout: 1)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponse,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 2)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if multiple duplicate requests are made, exercised through
  /// `downloadInfoAndModel`. The mock file downloader is expected to be invoked once, and
  /// every request's completion must receive the model.
  func testGetModelWithMergeRequests() {
    guard let testApp = FirebaseApp.app() else {
      return XCTFail("Default app was not configured.")
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let modelInfoRetriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let downloadedFileURL = tempFile()
    let requestCount = 6
    let okResponse = HTTPURLResponse(url: fakeFileURL,
                                     statusCode: 200,
                                     httpVersion: nil,
                                     headerFields: nil)!

    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let allRequestsFinished = expectation(description: "Completion handler")
    allRequestsFinished.expectedFulfillmentCount = requestCount
    for requestNumber in 1 ... requestCount {
      let isLastRequest = requestNumber == requestCount
      modelDownloader.downloadInfoAndModel(modelName: "fakeModelName",
                                           modelInfoRetriever: modelInfoRetriever,
                                           downloader: mockDownloader,
                                           conditions: conditions) { result in
        allRequestsFinished.fulfill()
        switch result {
        case let .success(model):
          self.checkModel(model: model)
          // Only the last request's completion removes the downloaded model file.
          if isLastRequest {
            self.deleteFile(at: model.path)
          }
        case let .failure(error):
          XCTFail("Error - \(error)")
        }
      }
    }

    wait(for: [downloadStarted], timeout: 0.5)
    mockDownloader.completion?(
      .success(FileDownloaderResponse(urlResponse: okResponse, fileURL: downloadedFileURL))
    )
    wait(for: [allRequestsFinished], timeout: 1)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if server returns an error due to model not found.
  ///
  /// Drives a `ModelDownloadTask` directly: the mocked file download answers HTTP 404,
  /// and the task's completion must report `.notFound`.
  func testModelDownloadWith404() {
    let downloadedFileURL = tempFile()
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let notFoundResponse = HTTPURLResponse(url: fakeFileURL,
                                           statusCode: 404,
                                           httpVersion: nil,
                                           headerFields: nil)!
    let mockDownloader = MockModelFileDownloader()

    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let downloadFinished = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { result in
      downloadFinished.fulfill()
      if case let .failure(error) = result {
        XCTAssertEqual(error, .notFound)
      } else {
        XCTFail("Unexpected successful model download.")
      }
    }

    // Each test gets its own freshly cleared defaults suite, keyed by the test's name.
    let isolatedDefaults = UserDefaults.createUnitTestInstance(
      testName: String(#function.dropLast(2))
    )
    let task = ModelDownloadTask(remoteModelInfo: modelInfo,
                                 appName: "fakeAppName",
                                 defaults: isolatedDefaults,
                                 downloader: mockDownloader,
                                 completion: onCompletion)
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)

    let notFoundResult = FileDownloaderResponse(urlResponse: notFoundResponse,
                                                fileURL: downloadedFileURL)
    mockDownloader.completion?(.success(notFoundResult))
    wait(for: [downloadFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if the file download fails with a network error.
  ///
  /// Simulates an offline device: the mocked downloader fails with
  /// `FileDownloaderError.networkError`, and the task's completion must report
  /// `.failedPrecondition`.
  func testModelDownloadNetworkError() {
    let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
    let downloader = MockModelFileDownloader()
    let fileDownloaderExpectation = expectation(description: "File downloader")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    let completionExpectation = expectation(description: "Completion handler")
    let taskCompletion: ModelDownloadTask.Completion = { result in
      completionExpectation.fulfill()
      switch result {
      case .success: XCTFail("Unexpected successful model download.")
      case let .failure(error):
        XCTAssertEqual(error, .failedPrecondition)
      }
    }
    let testName = String(#function.dropLast(2))
    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                              appName: "fakeAppName",
                                              defaults: .createUnitTestInstance(
                                                testName: testName
                                              ),
                                              downloader: downloader,
                                              completion: taskCompletion)
    modelDownloadTask.resume()
    wait(for: [fileDownloaderExpectation], timeout: 0.1)
    // The failure path never hands back a file URL, so no temporary file fixture is needed.
    let offlineError = DownloadError.internalError(description: "Network appears to be offline.")
    downloader
      .completion?(.failure(FileDownloaderError.networkError(offlineError)))
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Get model with a valid server response.
  ///
  /// Both the model-info request and the file download answer HTTP 200, so the
  /// completion must deliver a model that passes `checkModel`.
  func testGetModelSuccessful() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    // Successful (HTTP 200) file-download response.
    let fakeResponse = HTTPURLResponse(url: fakeFileURL,
                                       statusCode: 200,
                                       httpVersion: nil,
                                       headerFields: nil)!
    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(model):
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      case let .failure(error):
        XCTFail("Error - \(error)")
      }
    }
    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponse,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if download url expired but retry was successful.
  ///
  /// Flow: the first file download answers HTTP 400 (used here to stand in for an
  /// expired download URL), the downloader issues a retry, and the retried download
  /// answers HTTP 200, so the caller must still receive a valid model.
  func testGetModelSuccessfulRetry() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false

    let downloaderUnderTest = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let infoSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: infoSession)
    let mockDownloader = MockModelFileDownloader()
    let scratchFileURL = tempFile()
    let expiredResponse = HTTPURLResponse(url: fakeFileURL, statusCode: 400,
                                          httpVersion: nil, headerFields: nil)!
    let okResponse = HTTPURLResponse(url: fakeFileURL, statusCode: 200,
                                     httpVersion: nil, headerFields: nil)!
    // Answers the download attempt the mock captured most recently.
    let deliver: (HTTPURLResponse) -> Void = { response in
      mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: response,
                                                                 fileURL: scratchFileURL)))
    }

    let initialAttempt = expectation(description: "File downloader before retry")
    let modelDelivered = expectation(description: "Completion handler.")
    mockDownloader.downloadFileHandler = { requestedURL in
      initialAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    downloaderUnderTest.downloadInfoAndModel(modelName: fakeModelName,
                                             modelInfoRetriever: retriever,
                                             downloader: mockDownloader,
                                             conditions: downloadConditions) { result in
      modelDelivered.fulfill()
      switch result {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        // The retry went through, so a usable model must come back.
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      }
    }
    wait(for: [initialAttempt], timeout: 0.5)

    // Re-arm the hook before answering, so the retried request is observed separately.
    let retryAttempt = expectation(description: "File downloader after retry")
    mockDownloader.downloadFileHandler = { requestedURL in
      retryAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    deliver(expiredResponse)
    wait(for: [retryAttempt], timeout: 0.5)
    deliver(okResponse)
    wait(for: [modelDelivered], timeout: 0.5)
    try? ModelFileManager.removeFile(at: scratchFileURL)
  }
  /// Get model if model doesn't exist.
  ///
  /// The file download answers HTTP 404, which must be reported to the caller as
  /// `.notFound`.
  func testGetModelNotFound() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    let fakeResponseNotFound = HTTPURLResponse(url: fakeFileURL,
                                               statusCode: 404,
                                               httpVersion: nil,
                                               headerFields: nil)!
    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      case .failure(.notFound):
        // Model does not exist on the server - expected error returned.
        break
      case .failure:
        XCTFail("Expected failure due to model not found.")
      }
    }
    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseNotFound,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if invalid or insufficient permissions.
  ///
  /// The file download answers HTTP 403, which must be reported to the caller as
  /// `.permissionDenied`.
  func testGetModelUnauthorized() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    // HTTP 403 Forbidden file-download response.
    let fakeResponseForbidden = HTTPURLResponse(url: fakeFileURL,
                                                statusCode: 403,
                                                httpVersion: nil,
                                                headerFields: nil)!
    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      case .failure(.permissionDenied):
        break
      case .failure:
        XCTFail("Expected failure due to bad permissions.")
      }
    }
    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseForbidden,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if download url expired and retry url also expired.
  ///
  /// Both the initial download and its single retry answer HTTP 400, so the caller
  /// must receive `.internalError` with "Unable to update expired model info.".
  func testGetModelFailedRetry() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false

    let downloaderUnderTest = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let mockDownloader = MockModelFileDownloader()
    let scratchFileURL = tempFile()
    let expiredResponse = HTTPURLResponse(url: fakeFileURL, statusCode: 400,
                                          httpVersion: nil, headerFields: nil)!
    // Answers the most recently captured download attempt with the "expired" response.
    let answerWithExpiredURL: () -> Void = {
      mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: expiredResponse,
                                                                 fileURL: scratchFileURL)))
    }

    let initialAttempt = expectation(description: "File downloader before retry")
    let errorDelivered = expectation(description: "Completion handler")
    mockDownloader.downloadFileHandler = { requestedURL in
      initialAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    downloaderUnderTest.downloadInfoAndModel(modelName: fakeModelName,
                                             modelInfoRetriever: retriever,
                                             downloader: mockDownloader,
                                             conditions: downloadConditions) { result in
      errorDelivered.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      case let .failure(.internalError(description: message)):
        // The retry also hit an expired URL.
        XCTAssertEqual(message, "Unable to update expired model info.")
      case .failure:
        XCTFail("Expected failure due to expired model info.")
      }
    }
    wait(for: [initialAttempt], timeout: 0.5)

    let retryAttempt = expectation(description: "File downloader after retry")
    mockDownloader.downloadFileHandler = { requestedURL in
      retryAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    answerWithExpiredURL()
    wait(for: [retryAttempt], timeout: 0.5)
    answerWithExpiredURL()
    wait(for: [errorDelivered], timeout: 0.5)
    try? ModelFileManager.removeFile(at: scratchFileURL)
  }
  /// Get model if multiple retries all returned expired urls.
  ///
  /// With `numberOfRetries = 2`, the initial download and both retries answer
  /// HTTP 400, so the caller must receive `.internalError` with
  /// "Unable to update expired model info.".
  func testGetModelMultipleFailedRetries() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false

    let downloaderUnderTest = ModelDownloader.modelDownloader()
    downloaderUnderTest.numberOfRetries = 2
    let downloadConditions = ModelDownloadConditions()
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let mockDownloader = MockModelFileDownloader()
    let scratchFileURL = tempFile()
    let expiredResponse = HTTPURLResponse(url: fakeFileURL, statusCode: 400,
                                          httpVersion: nil, headerFields: nil)!
    // Answers the most recently captured download attempt with the "expired" response.
    let answerWithExpiredURL: () -> Void = {
      mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: expiredResponse,
                                                                 fileURL: scratchFileURL)))
    }

    let initialAttempt = expectation(description: "File downloader before retry")
    let errorDelivered = expectation(description: "Completion handler")
    mockDownloader.downloadFileHandler = { requestedURL in
      initialAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    downloaderUnderTest.downloadInfoAndModel(modelName: fakeModelName,
                                             modelInfoRetriever: retriever,
                                             downloader: mockDownloader,
                                             conditions: downloadConditions) { result in
      errorDelivered.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      case let .failure(.internalError(description: message)):
        // Both retries hit expired URLs.
        XCTAssertEqual(message, "Unable to update expired model info.")
      case .failure:
        XCTFail("Expected failure due to expired model info.")
      }
    }
    wait(for: [initialAttempt], timeout: 0.5)

    // Each pass re-arms the hook, answers the outstanding attempt with an expired URL,
    // and waits for the downloader to issue the next retry.
    for retryDescription in ["file downloader after first retry",
                             "File downloader after second retry"] {
      let retryAttempt = expectation(description: retryDescription)
      mockDownloader.downloadFileHandler = { requestedURL in
        retryAttempt.fulfill()
        XCTAssertEqual(requestedURL, self.fakeDownloadURL)
      }
      answerWithExpiredURL()
      wait(for: [retryAttempt], timeout: 0.5)
    }
    // Answer the last retry with an expired URL as well.
    answerWithExpiredURL()
    wait(for: [errorDelivered], timeout: 0.5)
    try? ModelFileManager.removeFile(at: scratchFileURL)
  }
  /// Get model if a retry finally returns an unexpired download url.
  ///
  /// With `numberOfRetries = 4`, the initial download and the first retry answer
  /// HTTP 400 and the second retry answers HTTP 200, so a valid model must be returned.
  func testGetModelMultipleRetriesWithSuccess() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false

    let downloaderUnderTest = ModelDownloader.modelDownloader()
    downloaderUnderTest.numberOfRetries = 4
    let downloadConditions = ModelDownloadConditions()
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let mockDownloader = MockModelFileDownloader()
    let scratchFileURL = tempFile()
    let expiredResponse = HTTPURLResponse(url: fakeFileURL, statusCode: 400,
                                          httpVersion: nil, headerFields: nil)!
    let okResponse = HTTPURLResponse(url: fakeFileURL, statusCode: 200,
                                     httpVersion: nil, headerFields: nil)!
    // Answers the download attempt the mock captured most recently.
    let deliver: (HTTPURLResponse) -> Void = { response in
      mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: response,
                                                                 fileURL: scratchFileURL)))
    }

    let initialAttempt = expectation(description: "File downloader before retry")
    let modelDelivered = expectation(description: "Completion handler")
    mockDownloader.downloadFileHandler = { requestedURL in
      initialAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    downloaderUnderTest.downloadInfoAndModel(modelName: fakeModelName,
                                             modelInfoRetriever: retriever,
                                             downloader: mockDownloader,
                                             conditions: downloadConditions) { result in
      modelDelivered.fulfill()
      switch result {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        // A later retry succeeded, so a usable model must come back.
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      }
    }
    wait(for: [initialAttempt], timeout: 0.5)

    // Each pass re-arms the hook, answers the outstanding attempt with an expired URL,
    // and waits for the downloader to issue the next retry.
    for retryDescription in ["File downloader after first retry",
                             "File downloader after second retry"] {
      let retryAttempt = expectation(description: retryDescription)
      mockDownloader.downloadFileHandler = { requestedURL in
        retryAttempt.fulfill()
        XCTAssertEqual(requestedURL, self.fakeDownloadURL)
      }
      deliver(expiredResponse)
      wait(for: [retryAttempt], timeout: 0.5)
    }
    deliver(okResponse)
    wait(for: [modelDelivered], timeout: 0.5)
    try? ModelFileManager.removeFile(at: scratchFileURL)
  }
  1151. }
  /// Mock URL session for testing.
  ///
  /// Never touches the network: every `getModelInfo` call ignores the request and
  /// synchronously hands the configured `data`, `response` and `error` to `completion`.
  class MockModelInfoRetrieverSession: ModelInfoRetrieverSession {
    /// Payload returned to the caller (`nil` for no body).
    var data: Data?
    /// Response returned to the caller.
    var response: URLResponse?
    /// Transport error returned to the caller (`nil` for none).
    var error: Error?

    func getModelInfo(with request: URLRequest,
                      completion: @escaping (Data?, URLResponse?, Error?) -> Void) {
      let (cannedData, cannedResponse, cannedError) = (data, response, error)
      completion(cannedData, cannedResponse, cannedError)
    }
  }
  /// Test double for `FileDownloader`.
  ///
  /// `downloadFile` performs no I/O. It stores the callbacks of the most recent call,
  /// so a test can finish the download manually through `completion`, and reports the
  /// requested URL to the optional `downloadFileHandler` hook.
  class MockModelFileDownloader: FileDownloader {
    var progressHandler: ProgressHandler?
    var completion: CompletionHandler?
    /// Hook invoked with each URL passed to `downloadFile`.
    var downloadFileHandler: ((_ url: URL) -> Void)?

    func downloadFile(with url: URL, progressHandler: @escaping ProgressHandler,
                      completion: @escaping CompletionHandler) {
      // Callbacks are stored before the hook runs, so the hook can already reach them.
      self.progressHandler = progressHandler
      self.completion = completion
      if let onDownloadRequested = downloadFileHandler {
        onDownloadRequested(url)
      }
    }
  }
  extension UserDefaults {
    /// Suite name used by both unit-test factories; each test gets its own suite.
    private static func unitTestSuiteName(for testName: String) -> String {
      return "com.google.firebase.ml.test.\(testName)"
    }

    /// Returns a new cleared instance of user defaults.
    ///
    /// Any values persisted in the `testName` suite by an earlier run are removed
    /// before the instance is handed back.
    static func createUnitTestInstance(testName: String) -> UserDefaults {
      let name = unitTestSuiteName(for: testName)
      let freshDefaults = UserDefaults(suiteName: name)!
      freshDefaults.removePersistentDomain(forName: name)
      return freshDefaults
    }

    /// Returns the existing user defaults instance for `testName`, without clearing it.
    static func getUnitTestInstance(testName: String) -> UserDefaults {
      return UserDefaults(suiteName: unitTestSuiteName(for: testName))!
    }
  }
  // MARK: - Helpers

  extension ModelDownloaderUnitTests {
    /// Builds a mock model-info session that answers every request with `customJSON`
    /// (or `fakeModelJSON` when `nil`), the given HTTP `statusCode`, and an `Etag` header.
    func fakeModelInfoSessionWithURL(_ url: URL, statusCode: Int,
                                     customJSON: String? = nil) -> MockModelInfoRetrieverSession {
      let fakeSession = MockModelInfoRetrieverSession()
      let fakeJSON = customJSON ?? fakeModelJSON
      fakeSession.data = Data(fakeJSON.utf8)
      fakeSession.response = HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: ["Etag": "fakeModelHash"]
      )
      fakeSession.error = nil
      return fakeSession
    }

    /// Creates a real `ModelInfoRetriever` for the fake model/project that talks to
    /// `fakeSession`, optionally seeded with `fakeLocalModelInfo`.
    func fakeModelRetriever(fakeSession: MockModelInfoRetrieverSession,
                            fakeLocalModelInfo: LocalModelInfo? = nil) -> ModelInfoRetriever {
      // TODO: Replace with fake to check if it was used correctly by the download task.
      let modelInfoRetriever = ModelInfoRetriever(
        modelName: fakeModelName,
        projectID: fakeProjectID,
        apiKey: fakeAPIKey,
        appName: fakeAppName, authTokenProvider: successAuthTokenProvider,
        session: fakeSession, localModelInfo: fakeLocalModelInfo
      )
      return modelInfoRetriever
    }

    /// Asserts that `model` carries the fake name, size and hash, and that its file exists.
    func checkModel(model: CustomModel) {
      XCTAssertEqual(model.name, fakeModelName)
      XCTAssertEqual(model.size, fakeModelSize)
      XCTAssertEqual(model.hash, fakeModelHash)
      let modelURL = URL(fileURLWithPath: model.path)
      XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
    }

    /// Writes fake model bytes to `name` in the system temporary directory and returns its URL.
    func tempFile(name: String = "fake-model-file.tmp") -> URL {
      let tempDirectoryURL = URL(fileURLWithPath: NSTemporaryDirectory(),
                                 isDirectory: true)
      let tempFileURL = tempDirectoryURL.appendingPathComponent(name)
      writeFakeModelData(to: tempFileURL)
      return tempFileURL
    }

    /// Writes fake model bytes into the models directory and returns the file's URL.
    ///
    /// - Parameter invalid: When `true`, uses a file name that does not follow the
    ///   `fbml_model@@<app>@@<model>` pattern used for the valid case.
    func tempModelFile(invalid: Bool = false) -> URL {
      // Traps if the models directory is unavailable; presumably it always exists in tests.
      let modelDirectoryURL = ModelFileManager.modelsDirectory!
      let modelFileName = invalid ? "fbml_model_fake-model-file.tmp" :
        "fbml_model@@__FIRAPP_DEFAULT@@\(fakeModelName)"
      let tempFileURL = modelDirectoryURL.appendingPathComponent(modelFileName)
        .appendingPathExtension("tflite")
      writeFakeModelData(to: tempFileURL)
      return tempFileURL
    }

    /// Writes the fixed fake model payload to `url`.
    ///
    /// A failed write is recorded as a test failure instead of being silently ignored,
    /// since later assertions on the file would otherwise fail for an unclear reason.
    private func writeFakeModelData(to url: URL) {
      do {
        try Data("fakeModelData".utf8).write(to: url)
      } catch {
        XCTFail("Failed to write fake model data to \(url.path): \(error)")
      }
    }

    /// Deletes the file at `path`, recording a test failure (including the cause) if that fails.
    func deleteFile(at path: String) {
      let url = URL(fileURLWithPath: path)
      do {
        try ModelFileManager.removeFile(at: url)
      } catch {
        XCTFail("Failed to delete file at \(path): \(error)")
      }
    }

    /// Local model info for the fake model; `mismatch` swaps in a hash that differs from
    /// `fakeModelHash`. NOTE(review): size is a fixed 20, not `fakeModelSize` — confirm intent.
    func fakeLocalModelInfo(mismatch: Bool = false) -> LocalModelInfo {
      let hash = mismatch ? "wrongModelHash" : fakeModelHash
      return LocalModelInfo(
        name: fakeModelName,
        modelHash: hash,
        size: 20
      )
    }

    /// Remote model info pointing at `fakeDownloadURL`.
    ///
    /// - Parameters:
    ///   - expired: When `true`, the URL expiry time is "now"; otherwise 1000 seconds ahead.
    ///   - oversized: When `true`, reports a size of `Int.max / 4` instead of `fakeModelSize`.
    func fakeModelInfo(expired: Bool, oversized: Bool) -> RemoteModelInfo {
      let expiryTime = expired ? Date() : Date().addingTimeInterval(1000)
      // `Int(Int64.max)` traps on 32-bit targets; `Int.max` is the same value on 64-bit.
      let size = oversized ? Int.max / 4 : fakeModelSize
      return RemoteModelInfo(name: fakeModelName,
                             downloadURL: fakeDownloadURL,
                             modelHash: fakeModelHash,
                             size: size,
                             urlExpiryTime: expiryTime)
    }
  }
  1266. #endif // !targetEnvironment(macCatalyst) && !os(macOS)