ModelDownloaderUnitTests.swift 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseInstallations
  17. @testable import FirebaseMLModelDownloader
  /// Placeholder option values used to configure the default Firebase app in these tests.
  private enum MockOptions {
    /// Google app ID handed to `FirebaseOptions`.
    static let appID: String = "1:123:ios:123abc"
    /// GCM sender ID handed to `FirebaseOptions`.
    static let gcmSenderID: String = "mock-sender-id"
    /// Project ID assigned to the options before the app is configured.
    static let projectID: String = "mock-project-id"
    /// API key assigned to the options before the app is configured.
    static let apiKey: String = "ABcdEf-APIKeyWithValidFormat_0123456789"
  }
  25. // TODO: Create separate files for each class tested.
  26. final class ModelDownloaderUnitTests: XCTestCase {
  // Fixture values shared by the tests in this class.

  /// Model name used when requesting and verifying fake models.
  let fakeModelName: String = "fakeModelName"
  /// Model hash that successful results are compared against.
  let fakeModelHash: String = "fakeModelHash"
  /// URL the mock file downloader is expected to be asked for.
  let fakeDownloadURL: URL = URL(string: "www.fake-download-url.com")!
  /// URL used when building fake HTTP responses for the model file.
  let fakeFileURL: URL = URL(string: "www.fake-model-file.com")!
  /// Expiry timestamp string embedded in `fakeModelJSON`.
  let fakeExpiryTime: String = "2021-01-20T04:20:10.220Z"
  /// Model size that successful results are compared against.
  let fakeModelSize: Int = 20
  let fakeProjectID: String = "fakeProjectID"
  let fakeAPIKey: String = "fakeAPIKey"
  /// JSON object string carrying the fake download URL, expiry time and model size,
  /// each encoded as a string value.
  var fakeModelJSON: String {
    let json = """
    {
    "downloadUri":"\(fakeDownloadURL)",
    "expireTime":"\(fakeExpiryTime)",
    "sizeBytes":"\(fakeModelSize)"
    }
    """
    return json
  }
  /// Token provider stub that immediately invokes its completion with a successful fake token.
  let successAuthTokenProvider: (@escaping (Result<String, DownloadError>) -> Void) -> Void = { deliver in
    deliver(.success("fakeFISToken"))
  }
  /// Configures the default Firebase app from `MockOptions` and sets the logger level to debug.
  override class func setUp() {
    let mockOptions = FirebaseOptions(googleAppID: MockOptions.appID,
                                      gcmSenderID: MockOptions.gcmSenderID)
    mockOptions.projectID = MockOptions.projectID
    mockOptions.apiKey = MockOptions.apiKey

    FirebaseApp.configure(options: mockOptions)
    FirebaseConfiguration.shared.setLoggerLevel(.debug)
  }
  /// Deletes the default Firebase app if one exists; the deletion result is ignored.
  override class func tearDown() {
    guard let defaultApp = FirebaseApp.app() else { return }
    defaultApp.delete { _ in }
  }
  /// Verifies that requesting a downloader for the default app, explicitly or implicitly,
  /// yields the very same instance.
  func testModelDownloaderInit() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }

    let explicitAppDownloader = ModelDownloader.modelDownloader(app: defaultApp)
    let implicitAppDownloader = ModelDownloader.modelDownloader()

    XCTAssertTrue(explicitAppDownloader === implicitAppDownloader)
  }
  /// Verifies that a downloader requested after the app was deleted and reconfigured
  /// is a different instance from the one requested before deletion.
  func testAppDeletion() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }

    let downloaderBeforeDeletion = ModelDownloader.modelDownloader()

    defaultApp.delete { didDelete in
      XCTAssertTrue(didDelete)
    }

    // Configure the default app again so a downloader can be requested once more.
    ModelDownloaderUnitTests.setUp()
    let downloaderAfterDeletion = ModelDownloader.modelDownloader()

    XCTAssertFalse(downloaderBeforeDeletion === downloaderAfterDeletion)
  }
  /// Test invalid model file deletion.
  ///
  /// Deleting a model name that was never downloaded is expected to report a failure.
  /// The test waits on an expectation so the assertions inside the completion handler
  /// are guaranteed to have run before the test method returns.
  func testDeleteInvalidModel() {
    let modelDownloader = ModelDownloader.modelDownloader()
    let completionExpectation = expectation(description: "Completion handler")

    modelDownloader.deleteDownloadedModel(name: "fake-model-name") { result in
      completionExpectation.fulfill()
      switch result {
      case .success: XCTFail("Unexpected successful model deletion.")
      case let .failure(error):
        XCTAssertNotNil(error)
      }
    }

    // Same timeout the other completion-handler tests in this class use.
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Calls `listDownloadedModels` on the default downloader; no outcome is asserted yet.
  // TODO: Add unit test.
  func testListModels() {
    let downloader = ModelDownloader.modelDownloader()
    downloader.listDownloadedModels { outcome in
      // Either outcome is currently accepted.
      switch outcome {
      case .success, .failure:
        break
      }
    }
  }
  /// Walks through the get / list / delete calls as sample usage.
  /// Only the identity of the two downloader instances is asserted; the handlers are placeholders.
  func testGetModel() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }

    let defaultDownloader = ModelDownloader.modelDownloader()
    let appDownloader = ModelDownloader.modelDownloader(app: defaultApp)

    // Both lookups must resolve to one and the same instance.
    XCTAssertTrue(defaultDownloader === appDownloader)

    let downloadConditions = ModelDownloadConditions()

    // Download a model, observing progress.
    defaultDownloader.getModel(
      name: "your_model_name",
      downloadType: .latestModel,
      conditions: downloadConditions,
      progressHandler: { _ in
        // Progress updates would be handled here.
      }
    ) { outcome in
      switch outcome {
      case .success:
        // The model would be handed to an inference API here, e.g.
        // let interpreter = Interpreter(modelPath: customModel.modelPath)
        break
      case .failure:
        // A download error would be handled here.
        break
      }
    }

    // Retrieve the downloaded models.
    appDownloader.listDownloadedModels { outcome in
      switch outcome {
      case .success:
        // Model(s) would be picked for further use here.
        break
      case .failure:
        // A listing failure would be handled here.
        break
      }
    }

    // Remove a downloaded model.
    defaultDownloader.deleteDownloadedModel(name: "your_model_name") { outcome in
      switch outcome {
      case .success:
        // Any additional clean up would go here.
        break
      case .failure:
        // A deletion failure would be handled here.
        break
      }
    }
  }
  /// Logs a single error-level test event through `DeviceLogger`; nothing is asserted.
  func testDeviceLogging() {
    DeviceLogger.logEvent(
      level: .error,
      message: "This is a test error logged to device.",
      messageCode: .testError
    )
  }
  /// Round-trips `ModelOptions` through both binary and JSON serialization and checks that
  /// the binary form is smaller and that both decode to equal values.
  func testTelemetryEncoding() {
    let model = CustomModel(
      name: fakeModelName,
      size: 10,
      path: "fakeModelPath",
      hash: fakeModelHash
    )
    var options = ModelOptions()
    options.setModelOptions(model: model)

    do {
      let binaryData = try options.serializedData()
      let jsonData = try options.jsonUTF8Data()
      let binaryEvent = try ModelOptions(serializedData: binaryData)
      let jsonEvent = try ModelOptions(jsonUTF8Data: jsonData)

      XCTAssertNotNil(binaryData)
      XCTAssertNotNil(jsonData)
      XCTAssertLessThan(binaryData.count, jsonData.count)
      XCTAssertEqual(binaryEvent, jsonEvent)
    } catch {
      // Any encoding or decoding step failing is reported the same way.
      XCTFail("Encoding error.")
    }
  }
  /// Model info request answered with status 200: the retriever must hand back new model
  /// info whose fields match the fixture values.
  func testGetModelInfoWith200() {
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: fakeSession)
    let completed = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      completed.fulfill()
      switch outcome {
      case let .success(.modelInfo(info)):
        XCTAssertEqual(info.name, self.fakeModelName)
        XCTAssertEqual(info.downloadURL, self.fakeDownloadURL)
        XCTAssertEqual(info.size, self.fakeModelSize)
        XCTAssertEqual(info.modelHash, self.fakeModelHash)
      case .success(.notModified):
        XCTFail("Expected new model info from server.")
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      }
    }

    wait(for: [completed], timeout: 0.5)
  }
  /// Model info request answered with status 304 while matching local model info exists:
  /// the retriever must report "not modified".
  func testGetModelInfoWith304() {
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let matchingLocalInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(fakeSession: fakeSession,
                                       fakeLocalModelInfo: matchingLocalInfo)
    let completed = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      completed.fulfill()
      switch outcome {
      case .success(.notModified):
        break
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      }
    }

    wait(for: [completed], timeout: 0.5)
  }
  /// Model info request answered with status 304 while no local model info was supplied:
  /// the retriever must fail with the "deleted unexpectedly" internal error.
  func testGetModelInfoWith304MissingLocalInfo() {
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let retriever = fakeModelRetriever(fakeSession: fakeSession)
    let completed = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      completed.fulfill()
      switch outcome {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since local model info was not set.")
      case let .failure(.internalError(description: message)):
        XCTAssertEqual(message, "Model info was deleted unexpectedly.")
      case .failure:
        XCTFail("Expected failure since local model info was not set.")
      }
    }

    wait(for: [completed], timeout: 0.5)
  }
  /// Get model info if model info is not modified but the local model hash does not match.
  ///
  /// The session answers with status 304 and the local model info is created with
  /// `mismatch: true`; the retriever is expected to fail with the internal error
  /// "Unexpected model hash value.".
  func testGetModelInfoWith304HashMismatch() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let localModelInfo = fakeLocalModelInfo(mismatch: true)
    let modelInfoRetriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: localModelInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelInfoRetriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(fakemodelInfoResult):
        switch fakemodelInfoResult {
        case .modelInfo: XCTFail("Unexpected new model info from server.")
        case .notModified: XCTFail("Expected failure since model hash does not match.")
        }
      case let .failure(error):
        switch error {
        case let .internalError(description: errorMessage):
          XCTAssertEqual(errorMessage, "Unexpected model hash value.")
        // Any other error kind is the wrong failure for a hash mismatch.
        default: XCTFail("Expected failure since model hash does not match.")
        }
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Model info request answered with status 400: the retriever must fail with
  /// `.invalidArgument`.
  func testGetModelInfoWith400() {
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 400)
    let matchingLocalInfo = fakeLocalModelInfo(mismatch: false)
    let retriever = fakeModelRetriever(fakeSession: fakeSession,
                                       fakeLocalModelInfo: matchingLocalInfo)
    let completed = expectation(description: "Completion handler")

    retriever.downloadModelInfo { outcome in
      completed.fulfill()
      switch outcome {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since model name is invalid.")
      case let .failure(error):
        XCTAssertEqual(error, .invalidArgument)
      }
    }

    wait(for: [completed], timeout: 0.5)
  }
  /// Download task whose file download is answered with status 200: progress is reported
  /// twice and the completion delivers a model matching the fixture values.
  func testModelDownloadWith200() {
    let downloadedFileURL = tempFile()
    // Remove the temporary file once the test body has finished.
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }

    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()

    let downloadRequested = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let progressUpdateCount = 2
    let progressReported = expectation(description: "Progress handler")
    progressReported.expectedFulfillmentCount = progressUpdateCount
    let downloadCompleted = expectation(description: "Completion handler")

    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: "fakeAppName",
      defaults: .createTestInstance(testName: #function),
      downloader: mockDownloader
    )
    task.download(progressHandler: { fraction in
      progressReported.fulfill()
      XCTAssertEqual(fraction, 0.4)
    }) { outcome in
      downloadCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        XCTAssertEqual(model.name, self.fakeModelName)
        XCTAssertEqual(model.size, self.fakeModelSize)
        XCTAssertEqual(model.hash, self.fakeModelHash)
        let modelFileURL = URL(string: model.path)!
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelFileURL))
        try? self.deleteFile(at: modelFileURL)
      }
    }

    wait(for: [downloadRequested], timeout: 0.1)

    // Feed the same progress values once per expected fulfillment.
    for _ in 0 ..< progressUpdateCount {
      mockDownloader.progressHandler?(100, 250)
    }
    wait(for: [progressReported], timeout: 0.5)

    mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: okResponse,
                                                               fileURL: downloadedFileURL)))
    wait(for: [downloadCompleted], timeout: 0.5)
  }
  /// Download task built from expired model info whose file download is answered with
  /// status 400: the completion must fail with `.expiredDownloadURL`.
  func testModelDownloadWith400Expired() {
    let downloadedFileURL = tempFile()
    // Remove the temporary file once the test body has finished.
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }

    let expiredModelInfo = fakeModelInfo(expired: true, oversized: false)
    let badRequestResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()

    let downloadRequested = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    let downloadCompleted = expectation(description: "Completion handler")

    let task = ModelDownloadTask(
      remoteModelInfo: expiredModelInfo,
      appName: "fakeAppName",
      defaults: .createTestInstance(testName: #function),
      downloader: mockDownloader
    )
    task.download(progressHandler: nil) { outcome in
      downloadCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTAssertEqual(error, .expiredDownloadURL)
      case .success:
        XCTFail("Unexpected successful model download.")
      }
    }

    wait(for: [downloadRequested], timeout: 0.1)
    mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: badRequestResponse,
                                                               fileURL: downloadedFileURL)))
    wait(for: [downloadCompleted], timeout: 0.5)
  }
  /// Download task built from unexpired model info whose file download is answered with
  /// status 400: the completion must fail with `.invalidArgument`.
  func testModelDownloadWith400Invalid() {
    let downloadedFileURL = tempFile()
    // Remove the temporary file once the test body has finished.
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }

    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let badRequestResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()

    let downloadRequested = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    let downloadCompleted = expectation(description: "Completion handler")

    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: "fakeAppName",
      defaults: .createTestInstance(testName: #function),
      downloader: mockDownloader
    )
    task.download(progressHandler: nil) { outcome in
      downloadCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTAssertEqual(error, .invalidArgument)
      case .success:
        XCTFail("Unexpected successful model download.")
      }
    }

    wait(for: [downloadRequested], timeout: 0.1)
    mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: badRequestResponse,
                                                               fileURL: downloadedFileURL)))
    wait(for: [downloadCompleted], timeout: 0.5)
  }
  /// Download task whose file download is answered with status 404: the completion must
  /// fail with `.notFound`.
  func testModelDownloadWith404() {
    let downloadedFileURL = tempFile()
    // Remove the temporary file once the test body has finished.
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }

    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let notFoundResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 404,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()

    let downloadRequested = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    let downloadCompleted = expectation(description: "Completion handler")

    let task = ModelDownloadTask(
      remoteModelInfo: modelInfo,
      appName: "fakeAppName",
      defaults: .createTestInstance(testName: #function),
      downloader: mockDownloader
    )
    task.download(progressHandler: nil) { outcome in
      downloadCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTAssertEqual(error, .notFound)
      case .success:
        XCTFail("Unexpected successful model download.")
      }
    }

    wait(for: [downloadRequested], timeout: 0.1)
    mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: notFoundResponse,
                                                               fileURL: downloadedFileURL)))
    wait(for: [downloadCompleted], timeout: 0.5)
  }
  /// First file download is answered with status 400, the retried download with status 200:
  /// the completion must deliver a model matching the fixture values.
  func testGetModelSuccessfulRetry() {
    let sharedDownloader = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: fakeSession)
    let mockDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    // Remove the temporary file once the test body has finished.
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }

    let expiredResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!

    /// Completes the pending mock download with the given HTTP response.
    func respond(with response: HTTPURLResponse) {
      mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: response,
                                                                 fileURL: downloadedFileURL)))
    }

    let firstDownloadRequested = expectation(description: "File downloader before retry")
    let downloadCompleted = expectation(description: "Completion handler.")

    // Observes the initial download request.
    mockDownloader.downloadFileHandler = { requestedURL in
      firstDownloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    sharedDownloader.downloadInfoAndModel(
      modelName: fakeModelName,
      modelInfoRetriever: retriever,
      downloader: mockDownloader,
      conditions: downloadConditions
    ) { outcome in
      downloadCompleted.fulfill()
      switch outcome {
      case let .failure(error):
        XCTFail("Error - \(error)")
      case let .success(model):
        // The retry succeeded, so a model is delivered.
        XCTAssertEqual(model.name, self.fakeModelName)
        XCTAssertEqual(model.size, self.fakeModelSize)
        XCTAssertEqual(model.hash, self.fakeModelHash)
        let modelFileURL = URL(string: model.path)!
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelFileURL))
        try? self.deleteFile(at: modelFileURL)
      }
    }
    wait(for: [firstDownloadRequested], timeout: 0.5)

    // Observes the download request issued by the retry.
    let retryDownloadRequested = expectation(description: "File downloader after retry")
    mockDownloader.downloadFileHandler = { requestedURL in
      retryDownloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    respond(with: expiredResponse)
    wait(for: [retryDownloadRequested], timeout: 0.5)

    respond(with: okResponse)
    wait(for: [downloadCompleted], timeout: 0.5)
  }
  /// File download is answered with status 404: the completion must fail with `.notFound`.
  func testGetModelNotFound() {
    let sharedDownloader = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: fakeSession)
    let mockDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    // Remove the temporary file once the test body has finished.
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }

    let notFoundResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 404,
      httpVersion: nil,
      headerFields: nil
    )!
    let downloadRequested = expectation(description: "File downloader before retry")
    let downloadCompleted = expectation(description: "Completion handler")

    // Observes the download request.
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    sharedDownloader.downloadInfoAndModel(
      modelName: fakeModelName,
      modelInfoRetriever: retriever,
      downloader: mockDownloader,
      conditions: downloadConditions
    ) { outcome in
      downloadCompleted.fulfill()
      // Only a "not found" failure is acceptable here.
      switch outcome {
      case .success:
        XCTFail("Unexpected successful model download.")
      case .failure(.notFound):
        break
      case .failure:
        XCTFail("Expected failure due to model not found.")
      }
    }
    wait(for: [downloadRequested], timeout: 0.5)

    mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: notFoundResponse,
                                                               fileURL: downloadedFileURL)))
    wait(for: [downloadCompleted], timeout: 0.5)
  }
  /// File download is answered with status 403: the completion must fail with
  /// `.permissionDenied`.
  func testGetModelUnauthorized() {
    let sharedDownloader = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: fakeSession)
    let mockDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    // Remove the temporary file once the test body has finished.
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }

    let forbiddenResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 403,
      httpVersion: nil,
      headerFields: nil
    )!
    let downloadRequested = expectation(description: "File downloader before retry")
    let downloadCompleted = expectation(description: "Completion handler")

    // Observes the download request.
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    sharedDownloader.downloadInfoAndModel(
      modelName: fakeModelName,
      modelInfoRetriever: retriever,
      downloader: mockDownloader,
      conditions: downloadConditions
    ) { outcome in
      downloadCompleted.fulfill()
      // Only a "permission denied" failure is acceptable here.
      switch outcome {
      case .success:
        XCTFail("Unexpected successful model download.")
      case .failure(.permissionDenied):
        break
      case .failure:
        XCTFail("Expected failure due to bad permissions.")
      }
    }
    wait(for: [downloadRequested], timeout: 0.5)

    mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: forbiddenResponse,
                                                               fileURL: downloadedFileURL)))
    wait(for: [downloadCompleted], timeout: 0.5)
  }
  /// Both the initial file download and the retried download are answered with status 400:
  /// the completion must fail with the "Unable to update expired model info." internal error.
  func testGetModelFailedRetry() {
    let sharedDownloader = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let fakeSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: fakeSession)
    let mockDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    // Remove the temporary file once the test body has finished.
    defer { try? ModelFileManager.removeFile(at: downloadedFileURL) }

    let expiredResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!

    /// Completes the pending mock download with the expired-URL response.
    func respondExpired() {
      mockDownloader.completion?(.success(FileDownloaderResponse(urlResponse: expiredResponse,
                                                                 fileURL: downloadedFileURL)))
    }

    let firstDownloadRequested = expectation(description: "File downloader before retry")
    let downloadCompleted = expectation(description: "Completion handler")

    // Observes the initial download request.
    mockDownloader.downloadFileHandler = { requestedURL in
      firstDownloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    sharedDownloader.downloadInfoAndModel(
      modelName: fakeModelName,
      modelInfoRetriever: retriever,
      downloader: mockDownloader,
      conditions: downloadConditions
    ) { outcome in
      downloadCompleted.fulfill()
      // The retry hit another expired URL, so an error must be returned.
      switch outcome {
      case .success:
        XCTFail("Unexpected successful model download.")
      case let .failure(.internalError(description: message)):
        XCTAssertEqual(message, "Unable to update expired model info.")
      case .failure:
        XCTFail("Expected failure due to expired model info.")
      }
    }
    wait(for: [firstDownloadRequested], timeout: 0.5)

    // Observes the download request issued by the retry.
    let retryDownloadRequested = expectation(description: "File downloader after retry")
    mockDownloader.downloadFileHandler = { requestedURL in
      retryDownloadRequested.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    respondExpired()
    wait(for: [retryDownloadRequested], timeout: 0.5)

    respondExpired()
    wait(for: [downloadCompleted], timeout: 0.5)
  }
  /// Get model if multiple retries all returned expired urls.
  ///
  /// The downloader is configured for two retries; the initial download and both retried
  /// downloads are answered with status 400, after which the completion is expected to fail
  /// with the "Unable to update expired model info." internal error.
  ///
  /// `ModelDownloader.modelDownloader()` hands back the same instance for the default app
  /// (asserted in `testModelDownloaderInit`), so the retry count changed here is restored on
  /// exit to avoid leaking into other tests.
  func testGetModelMultipleFailedRetries() {
    let modelDownloader = ModelDownloader.modelDownloader()
    let originalNumberOfRetries = modelDownloader.numberOfRetries
    modelDownloader.numberOfRetries = 2
    // Put the shared downloader back into the state it was found in.
    defer { modelDownloader.numberOfRetries = originalNumberOfRetries }

    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    let fakeResponseExpired = HTTPURLResponse(url: fakeFileURL,
                                              statusCode: 400,
                                              httpVersion: nil,
                                              headerFields: nil)!
    let fileDownloaderExpectation1 = expectation(description: "File downloader before retry")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation1.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      // Two retries failed due to expired urls - error returned.
      case .success: XCTFail("Unexpected successful model download.")
      case let .failure(error):
        switch error {
        case let .internalError(description: errorMessage):
          XCTAssertEqual(errorMessage, "Unable to update expired model info.")
        default: XCTFail("Expected failure due to expired model info.")
        }
      }
    }
    wait(for: [fileDownloaderExpectation1], timeout: 0.5)

    let fileDownloaderExpectation2 = expectation(description: "file downloader after first retry")
    // Handles response after first retry.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation2.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                                   fileURL: tempFileURL)))
    wait(for: [fileDownloaderExpectation2], timeout: 0.5)

    let fileDownloaderExpectation3 = expectation(description: "File downloader after second retry")
    // Handles response after second retry.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation3.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                                   fileURL: tempFileURL)))
    wait(for: [fileDownloaderExpectation3], timeout: 0.5)

    // Third expired response exhausts the configured retries.
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if a retry finally returns an unexpired download url.
  ///
  /// The downloader is allowed four retries. The test answers the initial attempt and the first
  /// retry with an HTTP 400 response, answers the second retry with an HTTP 200 response, and
  /// then checks that the completion handler receives the expected model backed by a reachable
  /// file.
  func testGetModelMultipleRetriesWithSuccess() {
    let modelDownloader = ModelDownloader.modelDownloader()
    modelDownloader.numberOfRetries = 4
    let downloadConditions = ModelDownloadConditions()
    let infoSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let infoRetriever = fakeModelRetriever(fakeSession: infoSession)
    let mockDownloader = MockModelFileDownloader()
    let downloadedFileURL = tempFile()
    let expiredResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let succeededResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!

    // Creates an expectation for the next download attempt and installs a handler on the mock
    // that fulfills it and verifies the requested URL.
    func expectDownloadAttempt(_ description: String) -> XCTestExpectation {
      let attempt = expectation(description: description)
      mockDownloader.downloadFileHandler = { requestedURL in
        attempt.fulfill()
        XCTAssertEqual(requestedURL, self.fakeDownloadURL)
      }
      return attempt
    }

    // Completes the download currently captured by the mock with the given HTTP response.
    func reply(with httpResponse: HTTPURLResponse) {
      mockDownloader
        .completion?(.success(FileDownloaderResponse(urlResponse: httpResponse,
                                                     fileURL: downloadedFileURL)))
    }

    // Handles initial response.
    let initialAttempt = expectDownloadAttempt("File downloader before retry")
    let downloadFinished = expectation(description: "Completion handler")

    modelDownloader.downloadInfoAndModel(
      modelName: fakeModelName,
      modelInfoRetriever: infoRetriever,
      downloader: mockDownloader,
      conditions: downloadConditions
    ) { outcome in
      downloadFinished.fulfill()
      switch outcome {
      case let .failure(error):
        XCTFail("Error - \(error)")
      // One retry failed but the next retry succeeded - model returned.
      case let .success(model):
        XCTAssertEqual(model.name, self.fakeModelName)
        XCTAssertEqual(model.size, self.fakeModelSize)
        XCTAssertEqual(model.hash, self.fakeModelHash)
        let storedModelURL = URL(string: model.path)!
        XCTAssertTrue(ModelFileManager.isFileReachable(at: storedModelURL))
        try? self.deleteFile(at: storedModelURL)
      }
    }
    wait(for: [initialAttempt], timeout: 0.5)

    // Handles response after first retry.
    let firstRetry = expectDownloadAttempt("File downloader after first retry")
    reply(with: expiredResponse)
    wait(for: [firstRetry], timeout: 0.5)

    // Handles response after second retry.
    let secondRetry = expectDownloadAttempt("File downloader after second retry")
    reply(with: expiredResponse)
    wait(for: [secondRetry], timeout: 0.5)

    // The second retry is answered successfully, which ends the download.
    reply(with: succeededResponse)
    wait(for: [downloadFinished], timeout: 0.5)

    try? ModelFileManager.removeFile(at: downloadedFileURL)
  }
  782. }
  /// Mock URL session for testing.
  ///
  /// No request is performed: `getModelInfo(with:completion:)` synchronously hands the canned
  /// `data`, `response` and `error` values stored on the mock to the completion handler.
  class MockModelInfoRetrieverSession: ModelInfoRetrieverSession {
    /// Canned body passed to the completion handler.
    var data: Data?
    /// Canned URL response passed to the completion handler.
    var response: URLResponse?
    /// Canned error passed to the completion handler.
    var error: Error?

    /// Ignores `request` and immediately calls the completion handler with the canned values.
    func getModelInfo(with request: URLRequest,
                      completion handler: @escaping (Data?, URLResponse?, Error?) -> Void) {
      handler(self.data, self.response, self.error)
    }
  }
  /// Mock file downloader for testing.
  ///
  /// `downloadFile(with:progressHandler:completion:)` downloads nothing. It stores the handlers
  /// it receives so a test can invoke them later, then reports the requested URL through
  /// `downloadFileHandler`.
  class MockModelFileDownloader: FileDownloader {
    /// Progress handler captured from the most recent `downloadFile` call.
    var progressHandler: ProgressHandler?
    /// Completion handler captured from the most recent `downloadFile` call.
    var completion: CompletionHandler?
    /// Optional hook called with the requested URL on every `downloadFile` call.
    var downloadFileHandler: ((_ url: URL) -> Void)?

    /// Captures both handlers, then notifies `downloadFileHandler` (if set) with `url`.
    func downloadFile(with url: URL,
                      progressHandler onProgress: @escaping ProgressHandler,
                      completion onCompletion: @escaping CompletionHandler) {
      progressHandler = onProgress
      completion = onCompletion
      if let notifyTest = downloadFileHandler {
        notifyTest(url)
      }
    }
  }
  extension UserDefaults {
    /// Returns a new cleared instance of user defaults.
    ///
    /// The instance is backed by the suite `com.google.firebase.ml.test.<testName>`; any values
    /// previously persisted for that suite are removed before the instance is returned.
    ///
    /// - Parameter testName: Suffix appended to the suite name, used to keep the defaults of
    ///   different tests apart.
    /// - Returns: User defaults for the test suite with its persistent domain removed.
    static func createTestInstance(testName: String) -> UserDefaults {
      let suiteName = "com.google.firebase.ml.test.\(testName)"
      // `UserDefaults(suiteName:)` is failable. Failing here is a test set-up error, so stop with
      // a message that names the offending suite instead of an anonymous force-unwrap crash.
      guard let defaults = UserDefaults(suiteName: suiteName) else {
        fatalError("Unable to create UserDefaults for suite \"\(suiteName)\".")
      }
      defaults.removePersistentDomain(forName: suiteName)
      return defaults
    }
  }
  814. // MARK: - Helpers
  extension ModelDownloaderUnitTests {
    /// Creates a mock model info session that answers with the fake model JSON.
    ///
    /// - Parameters:
    ///   - url: URL reported by the canned HTTP response.
    ///   - statusCode: HTTP status code reported by the canned HTTP response.
    /// - Returns: A session whose response carries an `Etag` header of `fakeModelHash` and whose
    ///   error is `nil`.
    func fakeModelInfoSessionWithURL(_ url: URL, statusCode: Int) -> MockModelInfoRetrieverSession {
      let fakeSession = MockModelInfoRetrieverSession()
      fakeSession.data = fakeModelJSON.data(using: .utf8)
      fakeSession.response = HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: ["Etag": "fakeModelHash"]
      )
      fakeSession.error = nil
      return fakeSession
    }

    /// Creates a model info retriever that talks to the given mock session.
    ///
    /// - Parameters:
    ///   - fakeSession: Session that supplies the canned model info response.
    ///   - fakeLocalModelInfo: Local model info handed to the retriever; defaults to `nil`.
    /// - Returns: A retriever configured with fixed fake model, project, API key and app names.
    func fakeModelRetriever(fakeSession: MockModelInfoRetrieverSession,
                            fakeLocalModelInfo: LocalModelInfo? = nil) -> ModelInfoRetriever {
      // TODO: Replace with a fake one so we can check that it was used correctly by the download task.
      return ModelInfoRetriever(
        modelName: "fakeModelName",
        projectID: "fakeProjectID",
        apiKey: "fakeAPIKey",
        authTokenProvider: successAuthTokenProvider,
        appName: "fakeAppName",
        localModelInfo: fakeLocalModelInfo,
        session: fakeSession
      )
    }

    /// Writes a small fake model file into the temporary directory.
    ///
    /// The file is always named `fake-model-file.tmp`. A failure to write it is recorded as a
    /// test failure, so a missing fixture is reported where it happens.
    ///
    /// - Returns: The URL of the temporary file.
    func tempFile() -> URL {
      let tempDirectoryURL = URL(fileURLWithPath: NSTemporaryDirectory(),
                                 isDirectory: true)
      let tempFileURL = tempDirectoryURL.appendingPathComponent("fake-model-file.tmp")
      // Building the data from the UTF-8 view cannot fail, so no force unwrap is needed.
      let tempData = Data("fakeModelData".utf8)
      do {
        try tempData.write(to: tempFileURL)
      } catch {
        XCTFail("Unable to write temporary model file at \(tempFileURL.path): \(error)")
      }
      return tempFileURL
    }

    /// Removes the file at `url` via `ModelFileManager`.
    ///
    /// - Throws: Whatever `ModelFileManager.removeFile(at:)` throws.
    func deleteFile(at url: URL) throws {
      try ModelFileManager.removeFile(at: url)
    }

    /// Creates fake local model info.
    ///
    /// - Parameter mismatch: When `true`, the hash is `wrongModelHash` instead of
    ///   `fakeModelHash`.
    /// - Returns: Local model info with fixed name, size and path.
    func fakeLocalModelInfo(mismatch: Bool) -> LocalModelInfo {
      let hash = mismatch ? "wrongModelHash" : "fakeModelHash"
      return LocalModelInfo(
        name: "fakeModelName",
        modelHash: hash,
        size: 20,
        path: "fakeModelPath"
      )
    }

    /// Creates fake remote model info.
    ///
    /// - Parameters:
    ///   - expired: When `true`, the URL expiry time is the current date; otherwise it lies 1000
    ///     seconds in the future.
    ///   - oversized: When `true`, the size is `Int(Int64.max) / 4` instead of `fakeModelSize`.
    /// - Returns: Remote model info built from the fake model name, download URL and hash.
    func fakeModelInfo(expired: Bool, oversized: Bool) -> RemoteModelInfo {
      let now = Date()
      let expiryTime = expired ? now : now.addingTimeInterval(1000)
      let size = oversized ? Int(Int64.max) / 4 : fakeModelSize
      return RemoteModelInfo(name: fakeModelName,
                             downloadURL: fakeDownloadURL,
                             modelHash: fakeModelHash,
                             size: size,
                             urlExpiryTime: expiryTime)
    }
  }