ModelDownloaderUnitTests.swift 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseInstallations
  17. @testable import FirebaseMLModelDownloader
  /// Placeholder configuration used to stand up the default `FirebaseApp` for these tests.
  ///
  /// The values only need to be accepted by `FirebaseOptions`; they are not expected to
  /// identify a real Firebase project.
  private enum MockOptions {
    /// Passed as `googleAppID` when building the options.
    static let appID: String = "1:123:ios:123abc"
    /// Passed as `gcmSenderID` when building the options.
    static let gcmSenderID: String = "mock-sender-id"
    /// Assigned to `FirebaseOptions.projectID`.
    static let projectID: String = "mock-project-id"
    /// Assigned to `FirebaseOptions.apiKey`.
    static let apiKey: String = "ABcdEf-APIKeyWithValidFormat_0123456789"
  }
  25. // TODO: Create separate files for each class tested.
  26. final class ModelDownloaderUnitTests: XCTestCase {
  // MARK: - Fixtures

  /// Model name expected in the fake model-info and download results.
  let fakeModelName: String = "fakeModelName"
  /// Model hash expected in the fake results.
  let fakeModelHash: String = "fakeModelHash"
  /// Download URL embedded in `fakeModelJSON`; the download tests assert that the file
  /// downloader is asked for exactly this URL.
  let fakeDownloadURL: URL = URL(string: "www.fake-download-url.com")!
  /// URL attached to the fake `HTTPURLResponse`s that finish a file download.
  let fakeFileURL: URL = URL(string: "www.fake-model-file.com")!
  /// Download-URL expiry timestamp embedded in `fakeModelJSON`.
  let fakeExpiryTime: String = "2021-01-20T04:20:10.220Z"
  /// Model size in bytes, embedded in `fakeModelJSON` as `sizeBytes`.
  let fakeModelSize: Int = 20
  // NOTE(review): not referenced in the visible tests; presumably consumed by the fake
  // factory helpers defined later in this file.
  let fakeProjectID: String = "fakeProjectID"
  let fakeAPIKey: String = "fakeAPIKey"

  /// Body of a fake model-info response assembled from the fixtures above.
  ///
  /// `sizeBytes` is written as a quoted JSON string rather than a bare number.
  var fakeModelJSON: String {
    return """
    {
    "downloadUri":"\(fakeDownloadURL)",
    "expireTime":"\(fakeExpiryTime)",
    "sizeBytes":"\(fakeModelSize)"
    }
    """
  }
  /// Stand-in auth-token provider that synchronously reports success with the fixed token
  /// `"fakeFISToken"`.
  let successAuthTokenProvider: (@escaping (Result<String, DownloadError>) -> Void) -> Void = {
    completion in
    completion(.success("fakeFISToken"))
  }
  /// Class-level setup: configures the default `FirebaseApp` from `MockOptions` and raises the
  /// Firebase logger to debug level.
  ///
  /// `testAppDeletion()` also calls this directly to recreate the default app it deletes.
  override class func setUp() {
    let mockOptions = FirebaseOptions(googleAppID: MockOptions.appID,
                                      gcmSenderID: MockOptions.gcmSenderID)
    mockOptions.projectID = MockOptions.projectID
    mockOptions.apiKey = MockOptions.apiKey
    FirebaseApp.configure(options: mockOptions)
    FirebaseConfiguration.shared.setLoggerLevel(.debug)
  }
  /// Class-level teardown: removes the process temporary directory (best effort, errors are
  /// ignored) and then deletes the default app, if one is configured.
  ///
  /// NOTE(review): this removes `NSTemporaryDirectory()` itself, not just the files these tests
  /// created. That is likely harmless in a sandboxed simulator, but on an unsandboxed host the
  /// directory may be shared with other processes. Confirm this is intended.
  override class func tearDown() {
    let temporaryDirectoryPath = NSTemporaryDirectory()
    try? FileManager.default.removeItem(atPath: temporaryDirectoryPath)
    if let defaultApp = FirebaseApp.app() {
      defaultApp.delete { _ in }
    }
  }
  /// The downloader for the default app must be the same instance whether the app is passed
  /// explicitly or left implicit.
  func testModelDownloaderInit() {
    guard let defaultApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    let explicitDownloader = ModelDownloader.modelDownloader(app: defaultApp)
    let implicitDownloader = ModelDownloader.modelDownloader()
    XCTAssert(explicitDownloader === implicitDownloader)
  }
  /// Test if app deletion removes the associated downloader instance: after the default app is
  /// deleted and reconfigured, `modelDownloader()` must return a different object.
  func testAppDeletion() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    let modelDownloader1 = ModelDownloader.modelDownloader()
    // Wait for the deletion to report back before recreating the default app. Without this,
    // the success assertion may run after the test has finished (if the completion is
    // delivered asynchronously), and reconfiguring could race with the deletion.
    let deletionExpectation = expectation(description: "App deletion")
    testApp.delete { success in
      XCTAssertTrue(success)
      deletionExpectation.fulfill()
    }
    wait(for: [deletionExpectation], timeout: 0.5)
    ModelDownloaderUnitTests.setUp()
    let modelDownloader2 = ModelDownloader.modelDownloader()
    XCTAssert(modelDownloader1 !== modelDownloader2)
  }
  /// Test invalid model file deletion: deleting a model name with no downloaded copy
  /// (`"fake-model-name"`) must report failure.
  func testDeleteInvalidModel() {
    let modelDownloader = ModelDownloader.modelDownloader()
    // The completion may be delivered asynchronously. Wait for it so the outcome is actually
    // checked within this test instead of after it ends (or never).
    let completionExpectation = expectation(description: "Completion handler")
    modelDownloader.deleteDownloadedModel(name: "fake-model-name") { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model deletion.")
      case .failure:
        // Any failure is the expected outcome. The error is non-optional, so the previous
        // `XCTAssertNotNil(error)` could never fail.
        break
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Test listing models in the model directory: the listing call must report back through its
  /// completion handler.
  // TODO: Add unit test assertions on the listed models once downloaded-model fixtures exist.
  func testListModels() {
    let modelDownloader = ModelDownloader.modelDownloader()
    // Without waiting, the completion might never run inside this test and nothing would be
    // verified at all.
    let completionExpectation = expectation(description: "Completion handler")
    modelDownloader.listDownloadedModels { result in
      completionExpectation.fulfill()
      // Either outcome is acceptable for now; this only checks that the call completes.
      switch result {
      case .success, .failure:
        break
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Walks through the public API the way an app would: fetch a model, list downloaded models,
  /// then delete a model.
  ///
  /// Apart from the instance-identity check, nothing is asserted and none of the completions are
  /// awaited. The test mainly documents call-site usage.
  func testGetModel() {
    guard let testApp = FirebaseApp.app() else {
      XCTFail("Default app was not configured.")
      return
    }
    let modelDownloader = ModelDownloader.modelDownloader()
    let modelDownloaderWithApp = ModelDownloader.modelDownloader(app: testApp)
    // The default downloader and the downloader for the default app are the same object.
    XCTAssert(modelDownloader === modelDownloaderWithApp)

    // 1. Download the latest version of a model while observing progress.
    let conditions = ModelDownloadConditions()
    modelDownloader.getModel(
      name: "your_model_name",
      downloadType: .latestModel,
      conditions: conditions,
      progressHandler: { _ in
        // Report download progress here.
      }
    ) { result in
      guard case .success = result else {
        // Handle the download error here.
        return
      }
      // Use the model with your inference API, e.g.
      // let interpreter = Interpreter(modelPath: customModel.modelPath)
    }

    // 2. Inspect the models already on the device.
    modelDownloaderWithApp.listDownloadedModels { result in
      guard case .success = result else {
        // Handle the failure here.
        return
      }
      // Pick model(s) for further use.
    }

    // 3. Remove a downloaded model.
    modelDownloader.deleteDownloadedModel(name: "your_model_name") { result in
      guard case .success = result else {
        // Handle the failure here.
        return
      }
      // Apply any other clean-up.
    }
  }
  /// Test on-device logging: writes one error-level event through `DeviceLogger`. The test
  /// passes as long as logging does not crash.
  func testDeviceLogging() {
    let testMessage = "This is a test error logged to device."
    DeviceLogger.logEvent(level: .error, message: testMessage, messageCode: .testError)
  }
  /// Compare proto serialization methods: encodes `ModelOptions` as binary protobuf and as JSON
  /// and checks that both produce bytes, that binary is the smaller encoding, and that both
  /// decode back to the original message.
  func testTelemetryEncoding() {
    let fakeModel = CustomModel(
      name: "fakeModelName",
      size: 10,
      path: "fakeModelPath",
      hash: "fakeModelHash"
    )
    var modelOptions = ModelOptions()
    modelOptions.setModelOptions(model: fakeModel)
    guard let binaryData = try? modelOptions.serializedData(),
          let jsonData = try? modelOptions.jsonUTF8Data(),
          let binaryEvent = try? ModelOptions(serializedData: binaryData),
          let jsonEvent = try? ModelOptions(jsonUTF8Data: jsonData) else {
      XCTFail("Encoding error.")
      return
    }
    // Both values are non-optional here, so `XCTAssertNotNil` on them could never fail.
    // Check that each encoding actually produced bytes instead.
    XCTAssertFalse(binaryData.isEmpty)
    XCTAssertFalse(jsonData.isEmpty)
    XCTAssertLessThan(binaryData.count, jsonData.count)
    // Both round trips must reproduce the message that was encoded.
    XCTAssertEqual(binaryEvent, modelOptions)
    XCTAssertEqual(binaryEvent, jsonEvent)
  }
  /// Get model info if the server returns new model info: a 200 yields `.modelInfo` whose
  /// fields match the fixtures.
  func testGetModelInfoWith200() {
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    )
    let responseReceived = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      responseReceived.fulfill()
      switch result {
      case let .success(.modelInfo(info)):
        XCTAssertEqual(info.name, self.fakeModelName)
        XCTAssertEqual(info.downloadURL, self.fakeDownloadURL)
        XCTAssertEqual(info.size, self.fakeModelSize)
        XCTAssertEqual(info.modelHash, self.fakeModelHash)
      case .success(.notModified):
        XCTFail("Expected new model info from server.")
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      }
    }
    wait(for: [responseReceived], timeout: 0.5)
  }
  /// Get model info if model info is not modified: a 304 with matching local model info
  /// reports `.notModified`.
  func testGetModelInfoWith304() {
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304),
      fakeLocalModelInfo: fakeLocalModelInfo(mismatch: false)
    )
    let responseReceived = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      responseReceived.fulfill()
      switch result {
      case .success(.notModified):
        break
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case let .failure(error):
        XCTFail("Failed with error - \(error)")
      }
    }
    wait(for: [responseReceived], timeout: 0.5)
  }
  /// Get model info if model info is not modified but no local model info is stored: a 304 must
  /// fail with an internal error instead of reporting `.notModified`.
  func testGetModelInfoWith304MissingLocalInfo() {
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    )
    let responseReceived = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      responseReceived.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since local model info was not set.")
      case let .failure(.internalError(description: message)):
        XCTAssertEqual(message, "Model info was deleted unexpectedly.")
      case .failure:
        XCTFail("Expected failure since local model info was not set.")
      }
    }
    wait(for: [responseReceived], timeout: 0.5)
  }
  /// Get model info if model info is not modified but the stored local model info has a
  /// mismatched hash (`fakeLocalModelInfo(mismatch: true)`): a 304 must fail with an internal
  /// error about the hash instead of reporting `.notModified`.
  func testGetModelInfoWith304HashMismatch() {
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 304)
    let localModelInfo = fakeLocalModelInfo(mismatch: true)
    let modelInfoRetriever = fakeModelRetriever(
      fakeSession: session,
      fakeLocalModelInfo: localModelInfo
    )
    let completionExpectation = expectation(description: "Completion handler")
    modelInfoRetriever.downloadModelInfo { result in
      completionExpectation.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since model hash does not match.")
      case let .failure(.internalError(description: errorMessage)):
        XCTAssertEqual(errorMessage, "Unexpected model hash value.")
      case .failure:
        // Any other error means the hash mismatch was not detected as expected.
        XCTFail("Expected failure since model hash does not match.")
      }
    }
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Get model info if the model name is invalid: a 400 from the model-info endpoint fails with
  /// `.invalidArgument`.
  func testGetModelInfoWith400() {
    let retriever = fakeModelRetriever(
      fakeSession: fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 400),
      fakeLocalModelInfo: fakeLocalModelInfo(mismatch: false)
    )
    let responseReceived = expectation(description: "Completion handler")
    retriever.downloadModelInfo { result in
      responseReceived.fulfill()
      switch result {
      case .success(.modelInfo):
        XCTFail("Unexpected new model info from server.")
      case .success(.notModified):
        XCTFail("Expected failure since model name is invalid.")
      case let .failure(error):
        XCTAssertEqual(error, .invalidArgument)
      }
    }
    wait(for: [responseReceived], timeout: 0.5)
  }
  /// Get model if the server returns a model file. After two progress callbacks, a 200 from the
  /// file download yields a `CustomModel` that matches the fixtures and whose file is reachable.
  func testModelDownloadWith200() {
    let downloadedFile = tempFile()
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()

    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let progressReported = expectation(description: "Progress handler")
    progressReported.expectedFulfillmentCount = 2
    let downloadFinished = expectation(description: "Completion handler")

    let onProgress: ModelDownloadTask.ProgressHandler = { fraction in
      progressReported.fulfill()
      // Each simulated callback below reports 100 of 250 bytes.
      XCTAssertEqual(fraction, 0.4)
    }
    let onCompletion: ModelDownloadTask.Completion = { result in
      downloadFinished.fulfill()
      do {
        let model = try result.get()
        XCTAssertEqual(model.name, self.fakeModelName)
        XCTAssertEqual(model.size, self.fakeModelSize)
        XCTAssertEqual(model.hash, self.fakeModelHash)
        let modelFileURL = URL(string: model.path)!
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelFileURL))
        try? self.deleteFile(at: modelFileURL)
      } catch {
        XCTFail("Error - \(error)")
      }
    }

    let task = ModelDownloadTask(remoteModelInfo: modelInfo,
                                 appName: "fakeAppName",
                                 defaults: .createTestInstance(testName: #function),
                                 downloader: mockDownloader,
                                 progressHandler: onProgress,
                                 completion: onCompletion)
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)

    // Simulate two partial-progress callbacks, then finish the download.
    for _ in 1 ... 2 {
      mockDownloader.progressHandler?(100, 250)
    }
    wait(for: [progressReported], timeout: 0.5)
    mockDownloader
      .completion?(.success(FileDownloaderResponse(urlResponse: okResponse,
                                                   fileURL: downloadedFile)))
    wait(for: [downloadFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFile)
  }
  /// Get model if the server returns an error due to an expired URL. A 400 for model info whose
  /// download URL has expired must fail with `.expiredDownloadURL`.
  func testModelDownloadWith400Expired() {
    let downloadedFile = tempFile()
    let expiredModelInfo = fakeModelInfo(expired: true, oversized: false)
    let badRequest = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()

    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let downloadFinished = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { result in
      downloadFinished.fulfill()
      guard case let .failure(error) = result else {
        XCTFail("Unexpected successful model download.")
        return
      }
      XCTAssertEqual(error, .expiredDownloadURL)
    }

    let task = ModelDownloadTask(remoteModelInfo: expiredModelInfo,
                                 appName: "fakeAppName",
                                 defaults: .createTestInstance(testName: #function),
                                 downloader: mockDownloader,
                                 completion: onCompletion)
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)
    mockDownloader
      .completion?(.success(FileDownloaderResponse(urlResponse: badRequest,
                                                   fileURL: downloadedFile)))
    wait(for: [downloadFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFile)
  }
  /// Get model if the server returns an error due to an invalid model name. A 400 for model info
  /// whose download URL has not expired must fail with `.invalidArgument`.
  func testModelDownloadWith400Invalid() {
    let downloadedFile = tempFile()
    let validModelInfo = fakeModelInfo(expired: false, oversized: false)
    let badRequest = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()

    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let downloadFinished = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { result in
      downloadFinished.fulfill()
      guard case let .failure(error) = result else {
        XCTFail("Unexpected successful model download.")
        return
      }
      XCTAssertEqual(error, .invalidArgument)
    }

    let task = ModelDownloadTask(remoteModelInfo: validModelInfo,
                                 appName: "fakeAppName",
                                 defaults: .createTestInstance(testName: #function),
                                 downloader: mockDownloader,
                                 completion: onCompletion)
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)
    mockDownloader
      .completion?(.success(FileDownloaderResponse(urlResponse: badRequest,
                                                   fileURL: downloadedFile)))
    wait(for: [downloadFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFile)
  }
  /// Get model if multiple duplicate requests are made. Merging completions into one
  /// `ModelDownloadTask` must start a single file download and deliver the model to the task's
  /// own completion plus every merged one.
  func testModelDownloadWithMergeRequests() {
    let tempFileURL = tempFile()
    let numberOfRequests = 5
    let remoteModelInfo = fakeModelInfo(expired: false, oversized: false)
    let fakeResponse = HTTPURLResponse(url: fakeFileURL,
                                       statusCode: 200,
                                       httpVersion: nil,
                                       headerFields: nil)!
    let downloader = MockModelFileDownloader()
    // Expected once, even though `resume()` is called repeatedly below. Expectations vended by
    // XCTestCase fail on over-fulfillment.
    let fileDownloaderExpectation = expectation(description: "File downloader")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    // One fulfillment for the task's own completion plus one per merged completion.
    let completionExpectation = expectation(description: "Completion handler")
    completionExpectation.expectedFulfillmentCount = numberOfRequests + 1
    let taskCompletion: ModelDownloadTask.Completion = { result in
      completionExpectation.fulfill()
      switch result {
      case let .success(model):
        XCTAssertEqual(model.name, self.fakeModelName)
        XCTAssertEqual(model.size, self.fakeModelSize)
        XCTAssertEqual(model.hash, self.fakeModelHash)
        let modelPath = URL(string: model.path)!
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
      case let .failure(error):
        XCTFail("Error - \(error)")
      }
    }
    let modelDownloadTask = ModelDownloadTask(remoteModelInfo: remoteModelInfo,
                                              appName: "fakeAppName",
                                              defaults: .createTestInstance(testName: #function),
                                              downloader: downloader,
                                              completion: taskCompletion)
    for i in 1 ... numberOfRequests {
      let newTaskCompletion: ModelDownloadTask.Completion = { result in
        completionExpectation.fulfill()
        switch result {
        case let .success(model):
          XCTAssertEqual(model.name, self.fakeModelName)
          XCTAssertEqual(model.size, self.fakeModelSize)
          XCTAssertEqual(model.hash, self.fakeModelHash)
          let modelPath = URL(string: model.path)!
          XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
          // Only the completion merged last (`i == numberOfRequests`) deletes the model file.
          if i == numberOfRequests {
            try? self.deleteFile(at: modelPath)
          }
        case let .failure(error):
          XCTFail("Error - \(error)")
        }
      }
      modelDownloadTask.resume()
      modelDownloadTask.merge(newCompletion: newTaskCompletion)
    }
    // Wait for the download to start before finishing it, as the other download tests do.
    // Calling `downloader.completion` first relied on `resume()` having handed its completion to
    // the mock synchronously. If it had not, the optional call was a silent no-op and the test
    // could only time out.
    wait(for: [fileDownloaderExpectation], timeout: 1)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponse,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 2)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if multiple duplicate requests are made, driven through `downloadInfoAndModel`
  /// rather than a `ModelDownloadTask` directly. Six requests for the same model trigger one file
  /// download, and every request receives the model.
  func testGetModelWithMergeRequests() {
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let infoSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let infoRetriever = fakeModelRetriever(fakeSession: infoSession)
    let downloadedFile = tempFile()
    let requestCount = 6
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()

    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let everyRequestFinished = expectation(description: "Completion handler")
    everyRequestFinished.expectedFulfillmentCount = requestCount

    for requestNumber in 1 ... requestCount {
      modelDownloader.downloadInfoAndModel(modelName: "fakeModelName",
                                           modelInfoRetriever: infoRetriever,
                                           downloader: mockDownloader,
                                           conditions: conditions) { result in
        everyRequestFinished.fulfill()
        do {
          let model = try result.get()
          XCTAssertEqual(model.name, self.fakeModelName)
          XCTAssertEqual(model.size, self.fakeModelSize)
          XCTAssertEqual(model.hash, self.fakeModelHash)
          let modelFileURL = URL(string: model.path)!
          XCTAssertTrue(ModelFileManager.isFileReachable(at: modelFileURL))
          // Only the request issued last cleans up the model file.
          if requestNumber == requestCount {
            try? self.deleteFile(at: modelFileURL)
          }
        } catch {
          XCTFail("Error - \(error)")
        }
      }
    }
    wait(for: [downloadStarted], timeout: 0.5)
    mockDownloader
      .completion?(.success(FileDownloaderResponse(urlResponse: okResponse,
                                                   fileURL: downloadedFile)))
    wait(for: [everyRequestFinished], timeout: 1)
    try? ModelFileManager.removeFile(at: downloadedFile)
  }
  /// Get model if the server returns an error because the model was not found. A 404 from the
  /// file download fails the task with `.notFound`.
  func testModelDownloadWith404() {
    let downloadedFile = tempFile()
    let modelInfo = fakeModelInfo(expired: false, oversized: false)
    let notFoundResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 404,
      httpVersion: nil,
      headerFields: nil
    )!
    let mockDownloader = MockModelFileDownloader()

    let downloadStarted = expectation(description: "File downloader")
    mockDownloader.downloadFileHandler = { requestedURL in
      downloadStarted.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }

    let downloadFinished = expectation(description: "Completion handler")
    let onCompletion: ModelDownloadTask.Completion = { result in
      downloadFinished.fulfill()
      guard case let .failure(error) = result else {
        XCTFail("Unexpected successful model download.")
        return
      }
      XCTAssertEqual(error, .notFound)
    }

    let task = ModelDownloadTask(remoteModelInfo: modelInfo,
                                 appName: "fakeAppName",
                                 defaults: .createTestInstance(testName: #function),
                                 downloader: mockDownloader,
                                 completion: onCompletion)
    task.resume()
    wait(for: [downloadStarted], timeout: 0.1)
    mockDownloader
      .completion?(.success(FileDownloaderResponse(urlResponse: notFoundResponse,
                                                   fileURL: downloadedFile)))
    wait(for: [downloadFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFile)
  }
  /// Get model if the download URL expired but the retry succeeded. The first file download is
  /// answered with a 400 (the "expired URL" response). The downloader is then asked for the same
  /// URL again, and the second attempt's 200 delivers the model.
  func testGetModelSuccessfulRetry() {
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let infoSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let infoRetriever = fakeModelRetriever(fakeSession: infoSession)
    let mockDownloader = MockModelFileDownloader()
    let downloadedFile = tempFile()
    let expiredResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 400,
      httpVersion: nil,
      headerFields: nil
    )!
    let okResponse = HTTPURLResponse(
      url: fakeFileURL,
      statusCode: 200,
      httpVersion: nil,
      headerFields: nil
    )!
    let firstAttempt = expectation(description: "File downloader before retry")
    let downloadFinished = expectation(description: "Completion handler.")

    // First download attempt.
    mockDownloader.downloadFileHandler = { requestedURL in
      firstAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: infoRetriever,
                                         downloader: mockDownloader,
                                         conditions: conditions) { result in
      downloadFinished.fulfill()
      do {
        // The retry succeeded, so a model is expected.
        let model = try result.get()
        XCTAssertEqual(model.name, self.fakeModelName)
        XCTAssertEqual(model.size, self.fakeModelSize)
        XCTAssertEqual(model.hash, self.fakeModelHash)
        let modelFileURL = URL(string: model.path)!
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelFileURL))
        try? self.deleteFile(at: modelFileURL)
      } catch {
        XCTFail("Error - \(error)")
      }
    }
    wait(for: [firstAttempt], timeout: 0.5)

    // Install the retry handler before failing the first attempt, so the retry's download
    // request is observed by the second expectation.
    let retryAttempt = expectation(description: "File downloader after retry")
    mockDownloader.downloadFileHandler = { requestedURL in
      retryAttempt.fulfill()
      XCTAssertEqual(requestedURL, self.fakeDownloadURL)
    }
    mockDownloader
      .completion?(.success(FileDownloaderResponse(urlResponse: expiredResponse,
                                                   fileURL: downloadedFile)))
    wait(for: [retryAttempt], timeout: 0.5)
    mockDownloader
      .completion?(.success(FileDownloaderResponse(urlResponse: okResponse,
                                                   fileURL: downloadedFile)))
    wait(for: [downloadFinished], timeout: 0.5)
    try? ModelFileManager.removeFile(at: downloadedFile)
  }
  /// Get model if the model doesn't exist. A 404 from the file download must surface from
  /// `downloadInfoAndModel` as `.notFound`.
  func testGetModelNotFound() {
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    let fakeResponseNotFound = HTTPURLResponse(url: fakeFileURL,
                                               statusCode: 404,
                                               httpVersion: nil,
                                               headerFields: nil)!
    // This scenario involves no retry, so the description does not mention one.
    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      case .success:
        XCTFail("Unexpected successful model download.")
      // The model does not exist, so the download must fail with `.notFound`.
      case .failure(.notFound):
        break
      case .failure:
        XCTFail("Expected failure due to model not found.")
      }
    }
    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseNotFound,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
    try? ModelFileManager.removeFile(at: tempFileURL)
  }
  /// Get model if invalid or insufficient permissions.
  ///
  /// The model info request succeeds, but the file download responds with HTTP 403,
  /// which must surface from `downloadInfoAndModel` as a `.permissionDenied` error.
  func testGetModelUnauthorized() {
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    // Remove the placeholder file however the test exits.
    defer { try? ModelFileManager.removeFile(at: tempFileURL) }
    let fakeResponseForbidden = HTTPURLResponse(url: fakeFileURL,
                                                statusCode: 403,
                                                httpVersion: nil,
                                                headerFields: nil)!
    let fileDownloaderExpectation = expectation(description: "File downloader")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles the only download request; no retry happens for a 403.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      // Download forbidden - permission denied error returned.
      case .success: XCTFail("Unexpected successful model download.")
      case let .failure(error):
        switch error {
        case .permissionDenied: break
        default: XCTFail("Expected failure due to bad permissions.")
        }
      }
    }
    wait(for: [fileDownloaderExpectation], timeout: 0.5)
    // Finish the pending download with the 403 response.
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseForbidden,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Get model if download url expired and retry url also expired.
  ///
  /// Both the initial download and the one retry respond with HTTP 400 (used here as the
  /// "expired URL" response), so the download must fail with an internal error.
  func testGetModelFailedRetry() {
    let downloaderUnderTest = ModelDownloader.modelDownloader()
    let downloadConditions = ModelDownloadConditions()
    let infoSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: infoSession)
    let fileDownloader = MockModelFileDownloader()
    let modelFileURL = tempFile()
    defer { try? ModelFileManager.removeFile(at: modelFileURL) }
    let expiredResponse = HTTPURLResponse(url: fakeFileURL,
                                          statusCode: 400,
                                          httpVersion: nil,
                                          headerFields: nil)!

    /// Completes whichever download is currently pending with the expired response.
    func respondWithExpiredURL() {
      fileDownloader.completion?(.success(FileDownloaderResponse(urlResponse: expiredResponse,
                                                                fileURL: modelFileURL)))
    }

    let initialDownload = expectation(description: "File downloader before retry")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    fileDownloader.downloadFileHandler = { url in
      initialDownload.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    downloaderUnderTest.downloadInfoAndModel(modelName: fakeModelName,
                                             modelInfoRetriever: retriever,
                                             downloader: fileDownloader,
                                             conditions: downloadConditions) { result in
      completionExpectation.fulfill()
      // Retry failed due to another expired url - error returned.
      guard case let .failure(error) = result else {
        XCTFail("Unexpected successful model download.")
        return
      }
      if case let .internalError(description: message) = error {
        XCTAssertEqual(message, "Unable to update expired model info.")
      } else {
        XCTFail("Expected failure due to expired model info.")
      }
    }
    wait(for: [initialDownload], timeout: 0.5)

    // Handles response after one retry.
    let retriedDownload = expectation(description: "File downloader after retry")
    fileDownloader.downloadFileHandler = { url in
      retriedDownload.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    respondWithExpiredURL()
    wait(for: [retriedDownload], timeout: 0.5)

    // The retry also comes back expired, which ends the download with an error.
    respondWithExpiredURL()
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Get model if multiple retries all returned expired urls.
  ///
  /// With `numberOfRetries` set to 2, the initial download and both retries respond with
  /// HTTP 400 (used here as the "expired URL" response). The download must therefore fail
  /// with an internal error.
  func testGetModelMultipleFailedRetries() {
    let downloaderUnderTest = ModelDownloader.modelDownloader()
    downloaderUnderTest.numberOfRetries = 2
    let downloadConditions = ModelDownloadConditions()
    let infoSession = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let retriever = fakeModelRetriever(fakeSession: infoSession)
    let fileDownloader = MockModelFileDownloader()
    let modelFileURL = tempFile()
    defer { try? ModelFileManager.removeFile(at: modelFileURL) }
    let expiredResponse = HTTPURLResponse(url: fakeFileURL,
                                          statusCode: 400,
                                          httpVersion: nil,
                                          headerFields: nil)!

    /// Completes whichever download is currently pending with the expired response.
    func respondWithExpiredURL() {
      fileDownloader.completion?(.success(FileDownloaderResponse(urlResponse: expiredResponse,
                                                                fileURL: modelFileURL)))
    }

    let initialDownload = expectation(description: "File downloader before retry")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    fileDownloader.downloadFileHandler = { url in
      initialDownload.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    downloaderUnderTest.downloadInfoAndModel(modelName: fakeModelName,
                                             modelInfoRetriever: retriever,
                                             downloader: fileDownloader,
                                             conditions: downloadConditions) { result in
      completionExpectation.fulfill()
      // Two retries failed due to expired urls - error returned.
      guard case let .failure(error) = result else {
        XCTFail("Unexpected successful model download.")
        return
      }
      if case let .internalError(description: message) = error {
        XCTAssertEqual(message, "Unable to update expired model info.")
      } else {
        XCTFail("Expected failure due to expired model info.")
      }
    }
    wait(for: [initialDownload], timeout: 0.5)

    // Each pass answers the pending download with an expired URL and waits for the retry
    // that follows it (first retry, then second retry).
    for retryDescription in ["file downloader after first retry",
                             "File downloader after second retry"] {
      let retriedDownload = expectation(description: retryDescription)
      fileDownloader.downloadFileHandler = { url in
        retriedDownload.fulfill()
        XCTAssertEqual(url, self.fakeDownloadURL)
      }
      respondWithExpiredURL()
      wait(for: [retriedDownload], timeout: 0.5)
    }

    // The second retry also comes back expired, which ends the download with an error.
    respondWithExpiredURL()
    wait(for: [completionExpectation], timeout: 0.5)
  }
  /// Get model if a retry finally returns an unexpired download url.
  ///
  /// The initial download and the first retry both respond with HTTP 400 (expired URL).
  /// The second retry responds with HTTP 200, so the model must be returned.
  /// `numberOfRetries` is 4, so the two retries used stay within the limit.
  func testGetModelMultipleRetriesWithSuccess() {
    let modelDownloader = ModelDownloader.modelDownloader()
    modelDownloader.numberOfRetries = 4
    let conditions = ModelDownloadConditions()
    let session = fakeModelInfoSessionWithURL(fakeDownloadURL, statusCode: 200)
    let modelInfoRetriever = fakeModelRetriever(fakeSession: session)
    let downloader = MockModelFileDownloader()
    let tempFileURL = tempFile()
    // Remove the placeholder file however the test exits.
    defer { try? ModelFileManager.removeFile(at: tempFileURL) }
    let fakeResponseExpired = HTTPURLResponse(url: fakeFileURL,
                                              statusCode: 400,
                                              httpVersion: nil,
                                              headerFields: nil)!
    let fakeResponseSucceeded = HTTPURLResponse(url: fakeFileURL,
                                                statusCode: 200,
                                                httpVersion: nil,
                                                headerFields: nil)!
    let fileDownloaderExpectation1 = expectation(description: "File downloader before retry")
    let completionExpectation = expectation(description: "Completion handler")
    // Handles initial response.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation1.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    modelDownloader.downloadInfoAndModel(modelName: fakeModelName,
                                         modelInfoRetriever: modelInfoRetriever,
                                         downloader: downloader,
                                         conditions: conditions) { result in
      completionExpectation.fulfill()
      switch result {
      // Initial download and first retry expired, second retry succeeded - model returned.
      case let .success(model):
        XCTAssertEqual(model.name, self.fakeModelName)
        XCTAssertEqual(model.size, self.fakeModelSize)
        XCTAssertEqual(model.hash, self.fakeModelHash)
        // Report a malformed path as a test failure instead of crashing the test run.
        guard let modelPath = URL(string: model.path) else {
          XCTFail("Invalid model path: \(model.path)")
          return
        }
        XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
        try? self.deleteFile(at: modelPath)
      case let .failure(error):
        XCTFail("Error - \(error)")
      }
    }
    wait(for: [fileDownloaderExpectation1], timeout: 0.5)
    let fileDownloaderExpectation2 = expectation(description: "File downloader after first retry")
    // Handles response after first retry.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation2.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    // Initial download comes back expired, triggering the first retry.
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                                   fileURL: tempFileURL)))
    wait(for: [fileDownloaderExpectation2], timeout: 0.5)
    let fileDownloaderExpectation3 = expectation(description: "File downloader after second retry")
    // Handles response after second retry.
    downloader.downloadFileHandler = { url in
      fileDownloaderExpectation3.fulfill()
      XCTAssertEqual(url, self.fakeDownloadURL)
    }
    // First retry also comes back expired, triggering the second retry.
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseExpired,
                                                   fileURL: tempFileURL)))
    wait(for: [fileDownloaderExpectation3], timeout: 0.5)
    // Second retry succeeds.
    downloader
      .completion?(.success(FileDownloaderResponse(urlResponse: fakeResponseSucceeded,
                                                   fileURL: tempFileURL)))
    wait(for: [completionExpectation], timeout: 0.5)
  }
  903. }
  /// Mock URL session for testing.
  ///
  /// Every `getModelInfo(with:completion:)` call ignores the request. It immediately invokes
  /// `completion` with whatever `data`, `response` and `error` currently hold.
  class MockModelInfoRetrieverSession: ModelInfoRetrieverSession {
    /// Body passed back to the caller.
    var data: Data?
    /// Response passed back to the caller.
    var response: URLResponse?
    /// Error passed back to the caller.
    var error: Error?

    func getModelInfo(with request: URLRequest,
                      completion: @escaping (Data?, URLResponse?, Error?) -> Void) {
      let (cannedData, cannedResponse, cannedError) = (data, response, error)
      completion(cannedData, cannedResponse, cannedError)
    }
  }
  /// File downloader double that never downloads anything.
  ///
  /// It records the handlers it is given so a test can finish a "download" manually by
  /// calling `completion`.
  class MockModelFileDownloader: FileDownloader {
    /// Progress handler from the most recent `downloadFile` call.
    var progressHandler: ProgressHandler?
    /// Completion handler from the most recent `downloadFile` call.
    var completion: CompletionHandler?
    /// Optional hook invoked with the requested URL on each `downloadFile` call.
    var downloadFileHandler: ((_ url: URL) -> Void)?

    func downloadFile(with url: URL, progressHandler: @escaping ProgressHandler,
                      completion: @escaping CompletionHandler) {
      // Store both handlers before running the hook so they are already in place when it runs.
      self.progressHandler = progressHandler
      self.completion = completion
      if let onDownloadRequested = downloadFileHandler {
        onDownloadRequested(url)
      }
    }
  }
  extension UserDefaults {
    /// Returns a new cleared instance of user defaults.
    ///
    /// - Parameter testName: Suffix for the suite name `com.google.firebase.ml.test.<testName>`,
    ///   so each test can use its own defaults suite.
    /// - Returns: A `UserDefaults` for that suite whose persistent domain has been removed.
    static func createTestInstance(testName: String) -> UserDefaults {
      let suiteName = "com.google.firebase.ml.test.\(testName)"
      // `UserDefaults(suiteName:)` returns nil only for an invalid suite name, such as the app's
      // main bundle identifier or the global domain. Fail loudly with context rather than via
      // a bare force-unwrap.
      guard let defaults = UserDefaults(suiteName: suiteName) else {
        fatalError("Unable to create UserDefaults for test suite \(suiteName).")
      }
      defaults.removePersistentDomain(forName: suiteName)
      return defaults
    }
  }
  935. // MARK: - Helpers
  extension ModelDownloaderUnitTests {
    /// Builds a mock model-info session that replays `fakeModelJSON`.
    ///
    /// - Parameters:
    ///   - url: URL reported by the canned `HTTPURLResponse`.
    ///   - statusCode: HTTP status code of the canned response.
    /// - Returns: A session whose response carries an `Etag` header of `"fakeModelHash"`
    ///   and whose error is nil.
    func fakeModelInfoSessionWithURL(_ url: URL, statusCode: Int) -> MockModelInfoRetrieverSession {
      let fakeSession = MockModelInfoRetrieverSession()
      fakeSession.data = fakeModelJSON.data(using: .utf8)
      fakeSession.response = HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: ["Etag": "fakeModelHash"]
      )
      fakeSession.error = nil
      return fakeSession
    }

    /// Builds a real `ModelInfoRetriever` that uses `fakeSession` and fake credentials.
    ///
    /// - Parameters:
    ///   - fakeSession: Session that answers the retriever's model-info requests.
    ///   - fakeLocalModelInfo: Passed to the retriever as its `localModelInfo`.
    func fakeModelRetriever(fakeSession: MockModelInfoRetrieverSession,
                            fakeLocalModelInfo: LocalModelInfo? = nil) -> ModelInfoRetriever {
      // TODO: Replace with a fake one so we can check that it was used correctly by the download task.
      return ModelInfoRetriever(
        modelName: "fakeModelName",
        projectID: "fakeProjectID",
        apiKey: "fakeAPIKey",
        authTokenProvider: successAuthTokenProvider,
        appName: "fakeAppName",
        localModelInfo: fakeLocalModelInfo,
        session: fakeSession
      )
    }

    /// Writes a small placeholder file into the temporary directory and returns its URL.
    ///
    /// Tests hand this URL back as the "downloaded" model file, so a failed write is
    /// reported as a test failure instead of being silently ignored.
    ///
    /// - Parameter name: File name inside `NSTemporaryDirectory()`.
    /// - Returns: URL of the placeholder file.
    func tempFile(name: String = "fake-model-file.tmp") -> URL {
      let tempDirectoryURL = URL(fileURLWithPath: NSTemporaryDirectory(),
                                 isDirectory: true)
      let tempFileURL = tempDirectoryURL.appendingPathComponent(name)
      do {
        try Data("fakeModelData".utf8).write(to: tempFileURL)
      } catch {
        XCTFail("Failed to write temp file at \(tempFileURL): \(error)")
      }
      return tempFileURL
    }

    /// Removes the file at `url` via `ModelFileManager`, propagating any error.
    func deleteFile(at url: URL) throws {
      try ModelFileManager.removeFile(at: url)
    }

    /// Returns local model info for `"fakeModelName"` with size 20 and path `"fakeModelPath"`.
    ///
    /// - Parameter mismatch: When `true`, the hash is `"wrongModelHash"`; otherwise
    ///   it is `"fakeModelHash"`.
    func fakeLocalModelInfo(mismatch: Bool) -> LocalModelInfo {
      let hash = mismatch ? "wrongModelHash" : "fakeModelHash"
      return LocalModelInfo(
        name: "fakeModelName",
        modelHash: hash,
        size: 20,
        path: "fakeModelPath"
      )
    }

    /// Returns remote model info for the fake model and download URL.
    ///
    /// - Parameters:
    ///   - expired: When `true`, the URL expiry time is the current time. Otherwise it is
    ///     1000 seconds in the future.
    ///   - oversized: When `true`, the size is `Int(Int64.max) / 4` instead of `fakeModelSize`.
    func fakeModelInfo(expired: Bool, oversized: Bool) -> RemoteModelInfo {
      let now = Date()
      let expiryTime = expired ? now : now.addingTimeInterval(1000)
      // NOTE(review): `Int(Int64.max)` traps on platforms where `Int` is 32 bits wide.
      let size = oversized ? Int(Int64.max) / 4 : fakeModelSize
      return RemoteModelInfo(name: fakeModelName,
                             downloadURL: fakeDownloadURL,
                             modelHash: fakeModelHash,
                             size: size,
                             urlExpiryTime: expiryTime)
    }
  }