ModelDownloaderUnitTests.swift 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Skip keychain tests on Catalyst and macOS. Tests are skipped because they
  15. // involve interactions with the keychain that require a provisioning profile.
  16. // This is because MLModelDownloader depends on Installations.
  17. // See go/firebase-macos-keychain-popups for more details.
  18. #if !targetEnvironment(macCatalyst) && !os(macOS)
  19. import XCTest
  20. @testable import FirebaseCore
  21. @testable import FirebaseInstallations
  22. @testable import FirebaseMLModelDownloader
  /// Placeholder option values used to configure the default Firebase app in these tests.
  private enum MockOptions {
    /// Google app ID handed to `FirebaseOptions`.
    static let appID: String = "1:123:ios:123abc"
    /// GCM sender ID handed to `FirebaseOptions`.
    static let gcmSenderID: String = "mock-sender-id"
    /// Project ID assigned to the options after creation.
    static let projectID: String = "mock-project-id"
    /// API key assigned to the options after creation.
    static let apiKey: String = "ABcdEf-APIKeyWithValidFormat_0123456789"
  }
  30. // TODO: Create separate files for each class tested.
  31. final class ModelDownloaderUnitTests: XCTestCase {
  // Fixture values shared by the tests in this class.
  let fakeModelName: String = "fakeModelName"
  let fakemodelPath: String = "fakemodelPath"
  let fakeModelHash: String = "fakeModelHash"
  let fakeDownloadURL: URL = URL(string: "www.fake-download-url.com")!
  let fakeFileURL: URL = URL(string: "www.fake-model-file.com")!
  let fakeExpiryTime: String = "2021-01-20T04:20:10.220Z"
  let fakeModelSize: Int = 20
  let fakeProjectID: String = "fakeProjectID"
  let fakeAPIKey: String = "fakeAPIKey"
  let fakeAppName: String = "fakeAppName"
  /// JSON text carrying the fake download URL, expiry time and size in bytes.
  var fakeModelJSON: String {
    let json = """
    {
    "downloadUri":"\(fakeDownloadURL)",
    "expireTime":"\(fakeExpiryTime)",
    "sizeBytes":"\(fakeModelSize)"
    }
    """
    return json
  }
  /// Auth token provider that immediately completes with a fixed fake token.
  let successAuthTokenProvider: (@escaping (Result<String, DownloadError>) -> Void) -> Void = {
    completion in
    completion(.success("fakeFISToken"))
  }
  /// Configures the default Firebase app with mock options and enables debug logging.
  override class func setUp() {
    let mockOptions = FirebaseOptions(
      googleAppID: MockOptions.appID,
      gcmSenderID: MockOptions.gcmSenderID
    )
    mockOptions.projectID = MockOptions.projectID
    mockOptions.apiKey = MockOptions.apiKey

    FirebaseApp.configure(options: mockOptions)
    FirebaseConfiguration.shared.setLoggerLevel(.debug)
  }
  /// Empties the models directory before each test; records a failure if that throws.
  override func setUp() {
    guard (try? ModelFileManager.emptyModelsDirectory()) != nil else {
      XCTFail("Could not empty models directory.")
      return
    }
  }
  /// Removes the temporary directory (best effort) and deletes the default app, if any.
  override class func tearDown() {
    let temporaryDirectoryPath = NSTemporaryDirectory()
    try? FileManager.default.removeItem(atPath: temporaryDirectoryPath)

    if let defaultApp = FirebaseApp.app() {
      defaultApp.delete { _ in }
    }
  }
  /// Verifies that asking for the default app's downloader explicitly and implicitly
  /// returns the same instance.
  func testModelDownloaderInit() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    defaultApp.isDataCollectionDefaultEnabled = false

    let explicitAppDownloader = ModelDownloader.modelDownloader(app: defaultApp)
    let implicitAppDownloader = ModelDownloader.modelDownloader()
    XCTAssertTrue(explicitAppDownloader === implicitAppDownloader)
  }
  /// Verifies that deleting and reconfiguring the default app yields a different
  /// downloader instance.
  func testAppDeletion() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    defaultApp.isDataCollectionDefaultEnabled = false

    let downloaderBeforeDeletion = ModelDownloader.modelDownloader()
    defaultApp.delete { didSucceed in
      XCTAssertTrue(didSucceed)
    }

    // Configure the default app again via the class-level setup.
    ModelDownloaderUnitTests.setUp()
    let downloaderAfterDeletion = ModelDownloader.modelDownloader()
    XCTAssertTrue(downloaderBeforeDeletion !== downloaderAfterDeletion)
  }
  /// Verifies that deleting a model that was never downloaded reports a failure.
  ///
  /// The completion handler is awaited with an expectation so the assertions inside it
  /// are guaranteed to run before the test method returns.
  func testDeleteInvalidModel() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    let modelDownloader = ModelDownloader.modelDownloader()
    let completionExpectation = expectation(description: "Completion handler")
    modelDownloader.deleteDownloadedModel(name: fakeModelName) { result in
      completionExpectation.fulfill()
      switch result {
      case .success: XCTFail("Unexpected successful model deletion.")
      case let .failure(error):
        XCTAssertNotNil(error)
      }
    }
    // Without this wait the test could finish before the handler is ever invoked.
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies model name extraction from a file path: a path using the `@@` delimiters
  /// yields the model name, while one using `--` yields `nil`.
  func testGetModelNameFromFilePath() {
    let wellFormedURL = URL(fileURLWithPath: "modelDirectory/@@appName@@\(fakeModelName).tflite")
    let extractedName = ModelFileManager.getModelNameFromFilePath(wellFormedURL)
    XCTAssertEqual(extractedName, fakeModelName)

    let malformedURL = URL(fileURLWithPath: "modelDirectory/--appName--\(fakeModelName).tflite")
    XCTAssertNil(ModelFileManager.getModelNameFromFilePath(malformedURL))
  }
  /// Verifies that listing succeeds and returns the temp model file's path once local
  /// model info has been written to user defaults.
  func testListModels() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    defaultApp.isDataCollectionDefaultEnabled = false

    // `#function` ends with "()", which is dropped here.
    let currentTestName = String(#function.dropLast(2))
    let downloader = ModelDownloader.modelDownloaderWithDefaults(
      .createUnitTestInstance(testName: currentTestName),
      app: defaultApp
    )
    let modelFileURL = tempModelFile()
    // Best-effort cleanup of the temp model file when the test ends.
    defer { try? ModelFileManager.removeFile(at: modelFileURL) }

    let listCompleted = expectation(description: "Completion handler")
    let storedInfo = fakeLocalModelInfo()
    storedInfo.writeToDefaults(
      .getUnitTestInstance(testName: currentTestName),
      appName: defaultApp.name
    )

    downloader.listDownloadedModels { outcome in
      listCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(models):
        // Success is expected because model info was written to user defaults.
        XCTAssertEqual(models.first?.path, modelFileURL.path)
      }
    }
    wait(for: [listCompleted], timeout: 0.5)
  }
  /// Verifies that listing fails with an internal error when the models directory
  /// contains a file with an unexpected name.
  func testListModelsInvalidPath() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    defaultApp.isDataCollectionDefaultEnabled = false

    let downloader = ModelDownloader.modelDownloader()
    let badlyNamedFileURL = tempModelFile(invalid: true)
    // Best-effort cleanup of the temp file when the test ends.
    defer { try? ModelFileManager.removeFile(at: badlyNamedFileURL) }

    let listCompleted = expectation(description: "Completion handler")
    downloader.listDownloadedModels { outcome in
      listCompleted.fulfill()
      switch outcome {
      case .success:
        XCTFail("Unexpected success of list models.")
      case let .failure(error):
        guard case let .internalError(description) = error else {
          XCTFail("Expected failure due to bad model name.")
          return
        }
        XCTAssertTrue(
          description.contains("List models failed due to unexpected model file name")
        )
      }
    }
    wait(for: [listCompleted], timeout: 0.5)
  }
  /// Verifies that listing fails with an internal error when a model file exists but no
  /// local model info was written to user defaults.
  func testListModelsNoLocalInfo() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false
    // `#function` ends with "()", which is dropped here.
    let testName = String(#function.dropLast(2))
    let tempModelFileURL = tempModelFile()
    let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
      .createUnitTestInstance(testName: testName),
      app: testApp
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelDownloader.listDownloadedModels { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected success of list models.")
      case let .failure(error):
        switch error {
        /// Expected failure since no model info was written to user defaults.
        case let .internalError(description):
          XCTAssertTrue(description.contains("List models failed due to no local model info"))
        // The message names the scenario under test (it previously said "bad model name").
        default: XCTFail("Expected failure due to no local model info.")
        }
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempModelFileURL)
  }
  /// Checks that URLs round-trip through their string and path forms, both with a
  /// percent-encoded space in the path and without one.
  func testURLConversion() {
    let spacedURL = URL(string: "file:///fakeDir1/fake%20Dir2/fakeFile")!
    let spacedFromAbsoluteString = URL(string: spacedURL.absoluteString)
    let spacedFromPath = URL(fileURLWithPath: spacedURL.path)
    XCTAssertEqual(spacedURL, spacedFromAbsoluteString)
    XCTAssertEqual(spacedURL, spacedFromPath)

    // Strings without spaces should work fine either way.
    let plainURL = URL(string: "fakeDir1/fakeDir2/fakeFile")!
    let plainFromAbsoluteString = URL(string: plainURL.absoluteString)!
    let plainFromPath = URL(string: plainURL.path)!
    XCTAssertEqual(plainURL, plainFromAbsoluteString)
    XCTAssertEqual(plainURL, plainFromPath)
  }
  /// Logs one error-level test event through `DeviceLogger`; makes no assertions.
  func testDeviceLogging() {
    let testMessage = "This is a test error logged to device."
    DeviceLogger.logEvent(level: .error, message: testMessage, messageCode: .testError)
  }
  /// Compares binary and JSON serialization of `ModelOptions`: both must decode to equal
  /// values and the binary form must be smaller than the JSON form.
  func testTelemetryEncoding() {
    let model = CustomModel(
      name: fakeModelName,
      size: 10,
      path: fakemodelPath,
      hash: fakeModelHash
    )
    var options = ModelOptions()
    options.setModelOptions(modelName: model.name, modelHash: model.hash)

    do {
      let binaryData = try options.serializedData()
      let jsonData = try options.jsonUTF8Data()
      let decodedFromBinary = try ModelOptions(serializedData: binaryData)
      let decodedFromJSON = try ModelOptions(jsonUTF8Data: jsonData)

      XCTAssertNotNil(binaryData)
      XCTAssertNotNil(jsonData)
      XCTAssertLessThan(binaryData.count, jsonData.count)
      XCTAssertEqual(decodedFromBinary, decodedFromJSON)
    } catch {
      // Any encoding or decoding failure is reported with the same message.
      XCTFail("Encoding error.")
    }
  }
  /// Verifies that a 200 response produces new model info whose fields match the fixtures.
  func testGetModelInfoWith200() {
    let okSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: okSession)
    let infoCompleted = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      infoCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      case .success(.notModified):
        XCTFail("Expected new model info from server.")
      case let .success(.modelInfo(remoteModelInfo)):
        XCTAssertEqual(remoteModelInfo.name, self.fakeModelName)
        XCTAssertEqual(remoteModelInfo.downloadURL, self.fakeDownloadURL)
        XCTAssertEqual(remoteModelInfo.size, self.fakeModelSize)
        XCTAssertEqual(remoteModelInfo.modelHash, self.fakeModelHash)
      }
    }
    wait(for: [infoCompleted], timeout: 0.5)
  }
  /// Verifies that a 304 response with matching local model info reports "not modified".
  func testGetModelInfoWith304() {
    let notModifiedSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let matchingLocalInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(
      fakeSession: notModifiedSession,
      fakeLocalModelInfo: matchingLocalInfo
    )
    let infoCompleted = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      infoCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        // Expected outcome; nothing further to check.
        break
      }
    }
    wait(for: [infoCompleted], timeout: 0.5)
  }
  /// Verifies that a 304 response without any local model info fails with an internal
  /// error saying the model info was deleted unexpectedly.
  func testGetModelInfoWith304MissingLocalInfo() {
    let notModifiedSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let retriever = fakeModelRetriever(fakeSession: notModifiedSession)
    let infoCompleted = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      infoCompleted.fulfill()
      switch outcome {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since local model info was not set.")
      case let .failure(.internalError(description: errorMessage)):
        XCTAssertEqual(errorMessage, "Model info was deleted unexpectedly.")
      case .failure:
        XCTFail("Expected failure since local model info was not set.")
      }
    }
    wait(for: [infoCompleted], timeout: 0.5)
  }
  /// Get model info if model info is not modified but the local model hash does not match.
  ///
  /// Expects an internal error with the message "Unexpected model hash value.".
  func testGetModelInfoWith304HashMismatch() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let localModelInfo = fakeLocalModelInfo(mismatch: true)
    let modelInfoRetriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: localModelInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelInfoRetriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(fakemodelInfoResult):
        switch fakemodelInfoResult {
        case .modelInfo: XCTFail("Unexpected new model info from server.")
        case .notModified: XCTFail("Expected failure since model hash does not match.")
        }
      case let .failure(error):
        switch error {
        case let .internalError(description: errorMessage): XCTAssertEqual(errorMessage,
                                                                           "Unexpected model hash value.")
        // The message names the hash-mismatch scenario (it previously referred to
        // missing local model info).
        default: XCTFail("Expected failure since model hash does not match.")
        }
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies that a 400 response to the model info request fails with `.invalidArgument`.
  func testGetModelInfoWith400() {
    let badRequestSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 400)
    let matchingLocalInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(
      fakeSession: badRequestSession,
      fakeLocalModelInfo: matchingLocalInfo
    )
    let infoCompleted = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      infoCompleted.fulfill()
      switch outcome {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since model name is invalid.")
      case let .failure(error):
        XCTAssertEqual(error, .invalidArgument)
      }
    }
    wait(for: [infoCompleted], timeout: 0.5)
  }
  /// Get model info if model name does not exist.
  ///
  /// Expects the 404 response to surface as `.notFound`.
  func testGetModelInfoWith404() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 404)
    let localModelInfo = fakeLocalModelInfo(mismatch: false)
    let modelInfoRetriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: localModelInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelInfoRetriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(fakemodelInfoResult):
        switch fakemodelInfoResult {
        case .modelInfo: XCTFail("Unexpected new model info from server.")
        case .notModified: XCTFail("Expected failure since model name is not available.")
        }
      case let .failure(error):
        XCTAssertEqual(error, .notFound)
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Verifies that a 429 response to the model info request fails with `.resourceExhausted`.
  func testGetModelInfoWith429() {
    let throttledSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 429)
    let matchingLocalInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(
      fakeSession: throttledSession,
      fakeLocalModelInfo: matchingLocalInfo
    )
    let infoCompleted = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      infoCompleted.fulfill()
      switch outcome {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since resource exhausted.")
      case let .failure(error):
        XCTAssertEqual(error, .resourceExhausted)
      }
    }
    wait(for: [infoCompleted], timeout: 0.5)
  }
  /// Verifies a successful model download: the mock downloader is asked for the fake
  /// download URL, two progress callbacks report 0.4, and completion delivers a model.
  func testModelDownloadWith200() {
    let downloadedFileURL = tempFile()
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!

    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let progressReported = expectation(description: "Progress handler")
    progressReported.expectedFulfillmentCount = 2
    let downloadCompleted = expectation(description: "Completion handler")

    let onProgress: ModelDownloadTask.ProgressHandler = { fraction in
      progressReported.fulfill()
      XCTAssertEqual(fraction, 0.4)
    }
    let onCompletion: ModelDownloadTask.Completion = { outcome in
      downloadCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      }
    }

    // `#function` ends with "()", which is dropped here.
    let currentTestName = String(#function.dropLast(2))
    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: "fakeAppName",
      defaults: .createUnitTestInstance(testName: currentTestName),
      downloader: mockDownloader,
      progressHandler: onProgress,
      completion: onCompletion
    )
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)

    // Report the same progress twice, matching the expected fulfillment count.
    for _ in 0 ..< 2 {
      mockDownloader.progressHandler?(100, 250)
    }
    wait(for: [progressReported], timeout: 0.5)

    let downloaderResponse = FileDownloaderResponse(
      urlResponse: okResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [downloadCompleted], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Verifies that a 400 response for model info created with `expired: true` fails with
  /// `.expiredDownloadURL`.
  func testModelDownloadWith400Expired() {
    let downloadedFileURL = tempFile()
    let expiredModelInfo = fakeModelInfo(expired: true, oversized: false)
    let badRequestResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!

    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let downloadCompleted = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { outcome in
      downloadCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTAssertEqual(error, .expiredDownloadURL)
      case .success:
        XCTFail("Unexpected successful model download.")
      }
    }

    // `#function` ends with "()", which is dropped here.
    let currentTestName = String(#function.dropLast(2))
    let task = ModelDownloadTask(
      remoteModelInfo: expiredModelInfo,
      appName: "fakeAppName",
      defaults: .createUnitTestInstance(testName: currentTestName),
      downloader: mockDownloader,
      completion: onCompletion
    )
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)

    let downloaderResponse = FileDownloaderResponse(
      urlResponse: badRequestResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [downloadCompleted], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Verifies that a 400 response for model info created with `expired: false` fails with
  /// `.invalidArgument`.
  func testModelDownloadWith400Invalid() {
    let downloadedFileURL = tempFile()
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let badRequestResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!

    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let downloadCompleted = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { outcome in
      downloadCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTAssertEqual(error, .invalidArgument)
      case .success:
        XCTFail("Unexpected successful model download.")
      }
    }

    // `#function` ends with "()", which is dropped here.
    let currentTestName = String(#function.dropLast(2))
    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: "fakeAppName",
      defaults: .createUnitTestInstance(testName: currentTestName),
      downloader: mockDownloader,
      completion: onCompletion
    )
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)

    let downloaderResponse = FileDownloaderResponse(
      urlResponse: badRequestResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [downloadCompleted], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Verifies that a 403 response carrying an "API Disabled." error body fails with
  /// `.permissionDenied`.
  func testGetModelInfoWith403() {
    let apiDisabledBody =
      """
      {
      "error":
      {
      "message":"API Disabled."
      }
      }
      """
    let forbiddenSession = fakeModelInfoSessionWithURL(
      fakeDownloadURL,
      statusCode: 403,
      customJSON: apiDisabledBody
    )
    let matchingLocalInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(
      fakeSession: forbiddenSession,
      fakeLocalModelInfo: matchingLocalInfo
    )
    let infoCompleted = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      infoCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTAssertEqual(error, .permissionDenied)
      case .success:
        XCTFail("Unexpected new model info from server.")
      }
    }
    wait(for: [infoCompleted], timeout: 0.5)
  }
  /// Verifies that completions merged into one download task are all invoked: the task's
  /// own completion plus one per merged request.
  func testModelDownloadWithMergeRequests() {
    let downloadedFileURL = tempFile()
    let mergedRequestCount = 5
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!

    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    // One fulfillment for the task's own completion plus one per merged completion.
    let allCompleted = expectation(description: "Completion handler")
    allCompleted.expectedFulfillmentCount = mergedRequestCount + 1

    let initialCompletion: ModelDownloadTask.Completion = { outcome in
      allCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        self.checkModel(model: model)
      }
    }

    // `#function` ends with "()", which is dropped here.
    let currentTestName = String(#function.dropLast(2))
    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: fakeAppName,
      defaults: .createUnitTestInstance(testName: currentTestName),
      downloader: mockDownloader,
      completion: initialCompletion
    )

    for requestIndex in 1 ... mergedRequestCount {
      let mergedCompletion: ModelDownloadTask.Completion = { outcome in
        allCompleted.fulfill()
        switch outcome {
        case let .failure(error):
          XCTFail("Error - \(error)")
        case let .success(model):
          self.checkModel(model: model)
          // Only the last merged completion deletes the model file.
          if requestIndex == mergedRequestCount {
            self.deleteFile(at: model.path)
          }
        }
      }
      task.resume()
      task.merge(newCompletion: mergedCompletion)
    }

    let downloaderResponse = FileDownloaderResponse(
      urlResponse: okResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [downloadStarted], timeout: 1)
    wait(for: [allCompleted], timeout: 2)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Issues several identical `downloadInfoAndModel` requests through the downloader and
  /// verifies that every request's completion is invoked with a valid model.
  func testGetModelWithMergeRequests() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    defaultApp.isDataCollectionDefaultEnabled = false

    let modelDownloader = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let okSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: okSession)
    let downloadedFileURL = tempFile()
    let requestCount = 6
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!

    let mockDownloader = MockModelFileDownloader()
    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let allCompleted = expectation(description: "Completion handler")
    allCompleted.expectedFulfillmentCount = requestCount

    for requestIndex in 1 ... requestCount {
      modelDownloader.downloadInfoAndModel(
        modelName: "fakeModelName",
        modelInfoRetriever: retriever,
        downloader: mockDownloader,
        conditions: downloadConditions
      ) { outcome in
        allCompleted.fulfill()
        switch outcome {
        case let .failure(error):
          XCTFail("Error - \(error)")
        case let .success(model):
          self.checkModel(model: model)
          // Only the last request's completion deletes the model file.
          if requestIndex == requestCount {
            self.deleteFile(at: model.path)
          }
        }
      }
    }
    wait(for: [downloadStarted], timeout: 0.5)

    let downloaderResponse = FileDownloaderResponse(
      urlResponse: okResponse,
      fileURL: downloadedFileURL
    )
    mockDownloader.completion?(.success(downloaderResponse))
    wait(for: [allCompleted], timeout: 1)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if server returns an error due to model not found.
  ///
  /// Runs a `ModelDownloadTask` against a mock downloader that answers the
  /// file request with HTTP 404 and expects the task to fail with `.notFound`.
  func testModelDownloadWith404() {
    let downloadedFileURL = tempFile()
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let notFoundResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 404,
      httpVersion: nil,
      headerFields: nil
    )!
    let fileDownloader = MockModelFileDownloader()

    let downloadStarted = expectation(description: "File downloader")
    fileDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let taskFinished = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { result in
      taskFinished.fulfill()
      guard case let .failure(error) = result else {
        XCTFail("Unexpected successful model download.")
        return
      }
      XCTAssertEqual(error, .notFound)
    }

    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: "fakeAppName",
      // `#function` ends in "()"; the last two characters are dropped before
      // the name is used for the per-test defaults suite.
      defaults: .createUnitTestInstance(testName: String(#function.dropLast(2))),
      downloader: fileDownloader,
      completion: onCompletion
    )
    task.resume()

    wait(for: [downloadStarted], timeout: 0.1)
    fileDownloader.completion?(
      .success(FileDownloaderResponse(urlResponse: notFoundResponse,
                                      fileURL: downloadedFileURL))
    )
    wait(for: [taskFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if the file download fails with a network error.
  ///
  /// The mock downloader reports `FileDownloaderError.networkError` instead of
  /// a file, and the task is expected to complete with `.failedPrecondition`.
  func testModelDownloadNetworkError() {
    let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
    let downloader = MockModelFileDownloader()

    let fileDownloaderExpectation = expectation(description: "File downloader")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }

    let completionExpectation = expectation(description: "Completion handler")
    let taskCompletion: ModelDownloadTask.Completion = { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      case let .failure(error):
        XCTAssertEqual(error, .failedPrecondition)
      }
    }

    // `#function` ends in "()"; the last two characters are dropped before the
    // name is used for the per-test defaults suite.
    let testName = String(#function.dropLast(2))
    let modelDownloadTask = ModelDownloadTask(
      remoteModelInfo: remoteModelInfo,
      appName: "fakeAppName",
      defaults: .createUnitTestInstance(testName: testName),
      downloader: downloader,
      completion: taskCompletion
    )
    modelDownloadTask.resume()
    wait(for: [fileDownloaderExpectation], timeout: 0.1)

    // No downloaded-file fixture is created for this test: the failure result
    // delivered below carries only an error, never a file URL.
    let offlineError = DownloadError.internalError(description: "Network appears to be offline.")
    downloader.completion?(.failure(FileDownloaderError.networkError(offlineError)))
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Get model with a valid server response.
  ///
  /// Drives `downloadInfoAndModel` with a mock downloader that answers the
  /// file request with HTTP 200 and expects a valid model in the completion.
  func testGetModelSuccessful() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false

    let sut = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let infoRetriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let fileDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!

    let downloadStarted = expectation(description: "File downloader")
    let downloadCompleted = expectation(description: "Completion handler")

    // Records that the file request reached the downloader.
    fileDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    sut.downloadInfoAndModel(modelName: fakeModelName,
                             modelInfoRetriever: infoRetriever,
                             downloader: fileDownloader,
                             conditions: downloadConditions) { result in
      downloadCompleted.fulfill()
      switch result {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      }
    }

    wait(for: [downloadStarted], timeout: 0.5)
    fileDownloader.completion?(
      .success(FileDownloaderResponse(urlResponse: okResponse, fileURL: downloadedFileURL))
    )
    wait(for: [downloadCompleted], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if download url expired but retry was successful.
  ///
  /// The first file request is answered with HTTP 400 (the "expired"
  /// response of this test); the request that follows is answered with
  /// HTTP 200, and the completion handler is expected to receive a valid model.
  func testGetModelSuccessfulRetry() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false

    let sut = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let infoRetriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let fileDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    let expiredResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!

    // Hands a response to whichever completion the downloader currently holds.
    let respond: (HTTPURLResponse) -> Void = { response in
      fileDownloader.completion?(
        .success(FileDownloaderResponse(urlResponse: response, fileURL: downloadedFileURL))
      )
    }

    let firstAttempt = expectation(description: "File downloader before retry")
    let downloadCompleted = expectation(description: "Completion handler.")

    // Observes the initial file request.
    fileDownloader.downloadFileHandler = { requestedURL in
      firstAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    sut.downloadInfoAndModel(modelName: fakeModelName,
                             modelInfoRetriever: infoRetriever,
                             downloader: fileDownloader,
                             conditions: downloadConditions) { result in
      downloadCompleted.fulfill()
      switch result {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        // The retry went through and produced a model.
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      }
    }
    wait(for: [firstAttempt], timeout: 0.5)

    // Observes the file request issued after the expired response.
    let secondAttempt = expectation(description: "File downloader after retry")
    fileDownloader.downloadFileHandler = { requestedURL in
      secondAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    respond(expiredResponse)
    wait(for: [secondAttempt], timeout: 0.5)

    respond(okResponse)
    wait(for: [downloadCompleted], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if model doesn't exist.
  ///
  /// The mock downloader answers the file request with HTTP 404 and the
  /// completion handler is expected to receive `.notFound`.
  func testGetModelNotFound() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    let fakeResponseNotFound = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 404,
      httpVersion: nil,
      headerFields: nil
    )!

    // This test involves no retry, so the expectation is simply the file request.
    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")

    // Handles the file request.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }

    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      case let .failure(error):
        // Compared directly, as `testModelDownloadWith404` does, so that a
        // mismatch reports the error that was actually received.
        XCTAssertEqual(error, .notFound, "Expected failure due to model not found.")
      }
    }

    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    downloader.completion?(
      .success(FileDownloaderResponse(urlResponse: fakeResponseNotFound,
                                      fileURL: tempFileURL))
    )
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if invalid or insufficient permissions.
  ///
  /// The mock downloader answers the file request with HTTP 403 and the
  /// completion handler is expected to receive `.permissionDenied`.
  func testGetModelUnauthorized() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    // Named for the status it carries (403), not "not found".
    let fakeResponseForbidden = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 403,
      httpVersion: nil,
      headerFields: nil
    )!

    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")

    // Handles the file request.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }

    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      case let .failure(error):
        // Compared directly, as `testModelDownloadWith404` does, so that a
        // mismatch reports the error that was actually received.
        XCTAssertEqual(error, .permissionDenied, "Expected failure due to bad permissions.")
      }
    }

    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    downloader.completion?(
      .success(FileDownloaderResponse(urlResponse: fakeResponseForbidden,
                                      fileURL: tempFileURL))
    )
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if download url expired and retry url also expired.
  ///
  /// Both the initial file request and the one that follows it are answered
  /// with HTTP 400 (the "expired" response of this test). The completion
  /// handler is expected to receive an `.internalError` whose description is
  /// "Unable to update expired model info.".
  func testGetModelFailedRetry() {
    guard let app = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    app.isDataCollectionDefaultEnabled = false

    let sut = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let infoRetriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let fileDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    let expiredResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!

    // Hands the expired response to whichever completion the downloader currently holds.
    let respondExpired: () -> Void = {
      fileDownloader.completion?(
        .success(FileDownloaderResponse(urlResponse: expiredResponse,
                                        fileURL: downloadedFileURL))
      )
    }

    let firstAttempt = expectation(description: "File downloader before retry")
    let downloadCompleted = expectation(description: "Completion handler")

    // Observes the initial file request.
    fileDownloader.downloadFileHandler = { requestedURL in
      firstAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    sut.downloadInfoAndModel(modelName: fakeModelName,
                             modelInfoRetriever: infoRetriever,
                             downloader: fileDownloader,
                             conditions: downloadConditions) { result in
      downloadCompleted.fulfill()
      // The retry also hits an expired url, so an error must come back.
      guard case let .failure(error) = result else {
        XCTFail("Unexpected successful model download.")
        return
      }
      guard case let .internalError(description: message) = error else {
        XCTFail("Expected failure due to expired model info.")
        return
      }
      XCTAssertEqual(message, "Unable to update expired model info.")
    }
    wait(for: [firstAttempt], timeout: 0.5)

    // Observes the file request issued after the first expired response.
    let secondAttempt = expectation(description: "File downloader after retry")
    fileDownloader.downloadFileHandler = { requestedURL in
      secondAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    respondExpired()
    wait(for: [secondAttempt], timeout: 0.5)

    respondExpired()
    wait(for: [downloadCompleted], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  /// Get model if multiple retries all returned expired urls.
  ///
  /// With `numberOfRetries` set to 2, the initial file request and the two
  /// requests that follow it are all answered with HTTP 400 (the "expired"
  /// response of this test). The completion handler is expected to receive an
  /// `.internalError` whose description is "Unable to update expired model info.".
  func testGetModelMultipleFailedRetries() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let modelDownloader = ModelDownloader.modelDownloader()
    // NOTE(review): presumably `ModelDownloader.modelDownloader()` hands back an
    // instance that other tests also receive — confirm. Restoring the retry
    // count keeps this test from changing their behaviour; if the instance is
    // not shared, the restore is harmless.
    let originalNumberOfRetries = modelDownloader.numberOfRetries
    defer { modelDownloader.numberOfRetries = originalNumberOfRetries }
    modelDownloader.numberOfRetries = 2

    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    let fakeResponseExpired = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!

    let fileDownloaderExpectation1 = expectation(description: "File downloader before retry")
    let completionExpectation = expectation(description: "Completion handler")

    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation1.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }

    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      // Two retries failed due to expired urls - error returned.
      case .success:
        XCTFail("Unexpected successful model download.")
      case let .failure(error):
        switch error {
        case let .internalError(description: errorMessage):
          XCTAssertEqual(errorMessage, "Unable to update expired model info.")
        default:
          XCTFail("Expected failure due to expired model info.")
        }
      }
    }
    wait(for: [fileDownloaderExpectation1], timeout: 0.5)

    // Handles response after first retry.
    let fileDownloaderExpectation2 = expectation(description: "file downloader after first retry")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation2.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    downloader.completion?(
      .success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                      fileURL: tempFileURL))
    )
    wait(for: [fileDownloaderExpectation2], timeout: 0.5)

    // Handles response after second retry.
    let fileDownloaderExpectation3 =
      expectation(description: "File downloader after second retry")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation3.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    downloader.completion?(
      .success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                      fileURL: tempFileURL))
    )
    wait(for: [fileDownloaderExpectation3], timeout: 0.5)

    // The third expired response is the one that ends the download in this test.
    downloader.completion?(
      .success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                      fileURL: tempFileURL))
    )
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if a retry finally returns an unexpired download url.
  ///
  /// With `numberOfRetries` set to 4, the initial file request and the one
  /// after it are answered with HTTP 400 (the "expired" response of this
  /// test); the third request is answered with HTTP 200 and the completion
  /// handler is expected to receive a valid model.
  func testGetModelMultipleRetriesWithSuccess() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    testApp.isDataCollectionDefaultEnabled = false

    let modelDownloader = ModelDownloader.modelDownloader()
    // NOTE(review): presumably `ModelDownloader.modelDownloader()` hands back an
    // instance that other tests also receive — confirm. Restoring the retry
    // count keeps this test from changing their behaviour; if the instance is
    // not shared, the restore is harmless.
    let originalNumberOfRetries = modelDownloader.numberOfRetries
    defer { modelDownloader.numberOfRetries = originalNumberOfRetries }
    modelDownloader.numberOfRetries = 4

    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    let fakeResponseExpired = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let fakeResponseSucceeded = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!

    let fileDownloaderExpectation1 = expectation(description: "File downloader before retry")
    let completionExpectation = expectation(description: "Completion handler")

    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation1.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }

    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      // Two expired responses were followed by a successful one - model returned.
      case let .success(model):
        self.checkModel(model: model)
        self.deleteFile(at: model.path)
      case let .failure(error):
        XCTFail("Error - \(error)")
      }
    }
    wait(for: [fileDownloaderExpectation1], timeout: 0.5)

    // Handles response after first retry.
    let fileDownloaderExpectation2 = expectation(description: "File downloader after first retry")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation2.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    downloader.completion?(
      .success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                      fileURL: tempFileURL))
    )
    wait(for: [fileDownloaderExpectation2], timeout: 0.5)

    // Handles response after second retry.
    let fileDownloaderExpectation3 =
      expectation(description: "File downloader after second retry")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation3.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    downloader.completion?(
      .success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                      fileURL: tempFileURL))
    )
    wait(for: [fileDownloaderExpectation3], timeout: 0.5)

    downloader.completion?(
      .success(FileDownloaderResponse(urlResponse: fakeResponseSucceeded,
                                      fileURL: tempFileURL))
    )
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  1149. }
  /// Mock URL session for testing.
  ///
  /// Performs no network access: every request is answered synchronously with
  /// the `data`, `response` and `error` values currently stored on the mock.
  class MockModelInfoRetrieverSession: ModelInfoRetrieverSession {
    /// Body handed to the completion handler.
    var data: Data? = nil
    /// Response handed to the completion handler.
    var response: URLResponse? = nil
    /// Error handed to the completion handler.
    var error: Error? = nil

    /// Ignores the request and calls `completion` with the stored values.
    func getModelInfo(with _: URLRequest,
                      completion: @escaping (Data?, URLResponse?, Error?) -> Void) {
      completion(data, response, error)
    }
  }
  /// Mock file downloader for testing.
  ///
  /// Transfers nothing: `downloadFile` stores the handlers it receives so that
  /// a test can invoke them later, then reports the requested URL to
  /// `downloadFileHandler` if one is set.
  class MockModelFileDownloader: FileDownloader {
    /// Called with the requested URL each time `downloadFile` is invoked.
    var downloadFileHandler: ((_ url: URL) -> Void)?
    /// Progress handler captured from the most recent `downloadFile` call.
    var progressHandler: ProgressHandler?
    /// Completion handler captured from the most recent `downloadFile` call.
    var completion: CompletionHandler?

    func downloadFile(with url: URL, progressHandler: @escaping ProgressHandler,
                      completion: @escaping CompletionHandler) {
      (self.progressHandler, self.completion) = (progressHandler, completion)
      if let notifyTest = downloadFileHandler {
        notifyTest(url)
      }
    }
  }
  extension UserDefaults {
    /// Suite name under which the defaults of the test called `testName` are stored.
    private static func unitTestSuiteName(for testName: String) -> String {
      return "com.google.firebase.ml.test.\(testName)"
    }

    /// Returns the user defaults suite for `testName` after removing its persistent domain.
    static func createUnitTestInstance(testName: String) -> UserDefaults {
      let suite = unitTestSuiteName(for: testName)
      let clearedDefaults = UserDefaults(suiteName: suite)!
      clearedDefaults.removePersistentDomain(forName: suite)
      return clearedDefaults
    }

    /// Returns the user defaults suite for `testName` without clearing it.
    static func getUnitTestInstance(testName: String) -> UserDefaults {
      return UserDefaults(suiteName: unitTestSuiteName(for: testName))!
    }
  }
  1185. // MARK: - Helpers
  extension ModelDownloaderUnitTests {
    /// Builds a mock model-info session whose canned answer is `customJSON`
    /// (or `fakeModelJSON` when none is given) together with an HTTP response
    /// for `url` carrying `statusCode` and a fixed `Etag` header.
    func fakeModelInfoSessionWithURL(_ url: URL, statusCode: Int,
                                     customJSON: String? = nil) -> MockModelInfoRetrieverSession {
      let session = MockModelInfoRetrieverSession()
      session.data = (customJSON ?? fakeModelJSON).data(using: .utf8)
      session.response = HTTPURLResponse(url: url,
                                         statusCode: statusCode,
                                         httpVersion: nil,
                                         headerFields: ["Etag": "fakeModelHash"])
      session.error = nil
      return session
    }

    /// Creates a `ModelInfoRetriever` wired to `fakeSession`, the fake project
    /// constants of this test case, and the optional `fakeLocalModelInfo`.
    func fakeModelRetriever(fakeSession: MockModelInfoRetrieverSession,
                            fakeLocalModelInfo: LocalModelInfo? = nil) -> ModelInfoRetriever {
      // TODO: Replace with fake to check if it was used correctly by the download task.
      return ModelInfoRetriever(modelName: fakeModelName,
                                projectID: fakeProjectID,
                                apiKey: fakeAPIKey,
                                appName: fakeAppName,
                                authTokenProvider: successAuthTokenProvider,
                                session: fakeSession,
                                localModelInfo: fakeLocalModelInfo)
    }

    /// Asserts that `model` carries the fake name, size and hash and that its
    /// file is reachable.
    func checkModel(model: CustomModel) {
      XCTAssertEqual(model.name, fakeModelName)
      XCTAssertEqual(model.size, fakeModelSize)
      XCTAssertEqual(model.hash, fakeModelHash)
      XCTAssertTrue(ModelFileManager.isFileReachable(at: URL(fileURLWithPath: model.path)))
    }

    /// Writes a small placeholder file called `name` into the temporary
    /// directory and returns its URL. A failed write is ignored.
    func tempFile(name: String = "fake-model-file.tmp") -> URL {
      let fileURL = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
        .appendingPathComponent(name)
      try? Data("fakeModelData".utf8).write(to: fileURL)
      return fileURL
    }

    /// Writes a small placeholder `.tflite` file into the models directory and
    /// returns its URL. `invalid` selects the alternative file name that does
    /// not embed `fakeModelName`. A failed write is ignored.
    func tempModelFile(invalid: Bool = false) -> URL {
      let baseName: String
      if invalid {
        baseName = "fbml_model_fake-model-file.tmp"
      } else {
        baseName = "fbml_model@@__FIRAPP_DEFAULT@@\(fakeModelName)"
      }
      let fileURL = ModelFileManager.modelsDirectory!
        .appendingPathComponent(baseName)
        .appendingPathExtension("tflite")
      try? Data("fakeModelData".utf8).write(to: fileURL)
      return fileURL
    }

    /// Removes the file at `path`, recording a test failure if removal throws.
    func deleteFile(at path: String) {
      let fileURL = URL(fileURLWithPath: path)
      do {
        try ModelFileManager.removeFile(at: fileURL)
      } catch {
        XCTFail("Failed to delete file at \(path).")
      }
    }

    /// Returns local model info for the fake model; `mismatch` swaps in a hash
    /// that differs from `fakeModelHash`.
    func fakeLocalModelInfo(mismatch: Bool = false) -> LocalModelInfo {
      return LocalModelInfo(name: fakeModelName,
                            modelHash: mismatch ? "wrongModelHash" : fakeModelHash,
                            size: 20)
    }

    /// Returns remote model info for the fake model. `expired` sets the URL
    /// expiry to the current time instead of 1000 seconds ahead; `oversized`
    /// replaces the fake size with a quarter of `Int64.max`.
    func fakeModelInfo(expired: Bool, oversized: Bool) -> RemoteModelInfo {
      let secondsUntilExpiry: TimeInterval = expired ? 0 : 1000
      let modelSize: Int
      if oversized {
        modelSize = Int(Int64.max) / 4
      } else {
        modelSize = fakeModelSize
      }
      return RemoteModelInfo(name: fakeModelName,
                             downloadURL: fakeDownloadURL,
                             modelHash: fakeModelHash,
                             size: modelSize,
                             urlExpiryTime: Date().addingTimeInterval(secondsUntilExpiry))
    }
  }
  1264. #endif // !targetEnvironment(macCatalyst) && !os(macOS)