ModelDownloaderIntegrationTests.swift 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseInstallations
  17. @testable import FirebaseMLModelDownloader
  extension UserDefaults {
    /// Builds the suite name used for `testName`.
    ///
    /// Shared by `createTestInstance` and `getTestInstance` so both always resolve the same
    /// suite for a given test.
    private static func testSuiteName(for testName: String) -> String {
      "com.google.firebase.ml.test.\(testName)"
    }

    /// Returns a new cleared instance of user defaults for `testName`.
    ///
    /// - Parameter testName: Identifier appended to the test suite prefix.
    /// - Returns: A `UserDefaults` whose persistent domain has just been removed.
    static func createTestInstance(testName: String) -> UserDefaults {
      let suiteName = testSuiteName(for: testName)
      // `init(suiteName:)` returns nil only for reserved names (the main bundle identifier or
      // the global domain), which this prefixed name never matches.
      let defaults = UserDefaults(suiteName: suiteName)!
      defaults.removePersistentDomain(forName: suiteName)
      return defaults
    }

    /// Returns the existing user defaults instance for `testName` without clearing it.
    ///
    /// - Parameter testName: Identifier previously passed to `createTestInstance(testName:)`.
    static func getTestInstance(testName: String) -> UserDefaults {
      UserDefaults(suiteName: testSuiteName(for: testName))!
    }
  }
  final class ModelDownloaderIntegrationTests: XCTestCase {
    /// Timeout, in seconds, applied to every expectation wait in this suite.
    private static let timeout: TimeInterval = 5

    /// Configures the default Firebase app once for the whole suite from the bundled
    /// `GoogleService-Info.plist`, then turns on debug logging.
    override class func setUp() {
      super.setUp()
      // TODO: Use FirebaseApp internal for test app.
      let bundle = Bundle(for: self)
      if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
         let options = FirebaseOptions(contentsOfFile: plistPath) {
        FirebaseApp.configure(options: options)
      } else {
        XCTFail("Could not locate GoogleService-Info.plist.")
      }
      FirebaseConfiguration.shared.setLoggerLevel(.debug)
    }

    /// Empties the models directory before each test so files from earlier tests cannot leak in.
    override func setUp() {
      super.setUp()
      do {
        try ModelFileManager.emptyModelsDirectory()
      } catch {
        XCTFail("Could not empty models directory - \(error)")
      }
    }

    // MARK: - Helpers

    /// Returns the default app configured in `setUp()`, recording a failure if it is missing.
    private func defaultTestApp(file: StaticString = #file, line: UInt = #line) -> FirebaseApp? {
      guard let app = FirebaseApp.app() else {
        XCTFail("Default app was not configured.", file: file, line: line)
        return nil
      }
      return app
    }

    /// Returns the project ID and API key from `app`'s options, recording a failure (instead of
    /// crashing the test run) when either one is nil.
    private func projectCredentials(
      of app: FirebaseApp,
      file: StaticString = #file,
      line: UInt = #line
    ) -> (projectID: String, apiKey: String)? {
      guard let projectID = app.options.projectID, let apiKey = app.options.apiKey else {
        XCTFail("App options are missing a project ID or API key.", file: file, line: line)
        return nil
      }
      return (projectID, apiKey)
    }

    /// Asserts that `path` is non-empty and that `ModelFileManager` can reach a file there.
    private func assertModelFileReachable(
      atPath path: String,
      file: StaticString = #file,
      line: UInt = #line
    ) {
      XCTAssertFalse(path.isEmpty, "Model path is empty.", file: file, line: line)
      XCTAssertTrue(
        ModelFileManager.isFileReachable(at: URL(fileURLWithPath: path)),
        "Model file is not reachable at \(path).",
        file: file,
        line: line
      )
    }

    // MARK: - Tests

    /// Test to download model info - makes an actual network call.
    ///
    /// Persists the downloaded info to test defaults, then re-requests it with that local info
    /// attached and expects a "not modified" answer.
    func testDownloadModelInfo() {
      guard let testApp = defaultTestApp() else { return }
      testApp.isDataCollectionDefaultEnabled = false
      guard let credentials = projectCredentials(of: testApp) else { return }
      let testName = String(#function.dropLast(2))
      let testModelName = "pose-detection"
      let modelInfoRetriever = ModelInfoRetriever(
        modelName: testModelName,
        projectID: credentials.projectID,
        apiKey: credentials.apiKey,
        appName: testApp.name,
        installations: Installations.installations(app: testApp)
      )
      let modelInfoDownloadExpectation = expectation(description: "Wait for model info to download.")
      modelInfoRetriever.downloadModelInfo(completion: { result in
        // Fulfill only after every assertion below has run.
        defer { modelInfoDownloadExpectation.fulfill() }
        switch result {
        case let .success(.modelInfo(modelInfo)):
          XCTAssertNotNil(modelInfo.urlExpiryTime)
          XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
          XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
          XCTAssertGreaterThan(modelInfo.size, 0)
          let localModelInfo = LocalModelInfo(from: modelInfo)
          localModelInfo.writeToDefaults(
            .createTestInstance(testName: testName),
            appName: testApp.name
          )
        case .success(.notModified):
          XCTFail("Failed to retrieve model info.")
        case let .failure(error):
          XCTFail("Failed to retrieve model info - \(error)")
        }
      })
      wait(for: [modelInfoDownloadExpectation], timeout: Self.timeout)

      guard let localInfo = LocalModelInfo(
        fromDefaults: .getTestInstance(testName: testName),
        name: testModelName,
        appName: testApp.name
      ) else {
        XCTFail("Could not save model info locally.")
        return
      }
      testRetrieveModelInfo(localInfo: localInfo)
    }

    /// Re-requests info for "pose-detection" with `localInfo` attached and expects `.notModified`.
    ///
    /// XCTest only discovers parameterless `test…` methods, so this runs only when called from
    /// `testDownloadModelInfo()`.
    func testRetrieveModelInfo(localInfo: LocalModelInfo) {
      guard let testApp = defaultTestApp() else { return }
      guard let credentials = projectCredentials(of: testApp) else { return }
      let testModelName = "pose-detection"
      let modelInfoRetriever = ModelInfoRetriever(
        modelName: testModelName,
        projectID: credentials.projectID,
        apiKey: credentials.apiKey,
        appName: testApp.name,
        installations: Installations.installations(app: testApp),
        localModelInfo: localInfo
      )
      let modelInfoRetrieveExpectation =
        expectation(description: "Wait for model info to be retrieved.")
      modelInfoRetriever.downloadModelInfo(completion: { result in
        defer { modelInfoRetrieveExpectation.fulfill() }
        switch result {
        case .success(.modelInfo):
          XCTFail("Local model info is already the latest and should not be set again.")
        case .success(.notModified):
          break
        case let .failure(error):
          XCTFail("Failed to retrieve model info - \(error)")
        }
      })
      wait(for: [modelInfoRetrieveExpectation], timeout: Self.timeout)
    }

    /// Test to download model file - makes an actual network call.
    func testModelDownload() throws {
      guard let testApp = defaultTestApp() else { return }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let testModelName = "\(testName)-test-model"
      let urlString =
        "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
      let url = try XCTUnwrap(URL(string: urlString))
      let remoteModelInfo = RemoteModelInfo(
        name: testModelName,
        downloadURL: url,
        modelHash: "mock-valid-hash",
        size: 10,
        urlExpiryTime: Date()
      )
      let conditions = ModelDownloadConditions()
      let downloadExpectation = expectation(description: "Wait for model to download.")
      let downloader = ModelFileDownloader(conditions: conditions)
      let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
      let taskCompletion: ModelDownloadTask.Completion = { result in
        defer { downloadExpectation.fulfill() }
        switch result {
        case let .success(model):
          let modelURL = URL(fileURLWithPath: model.path)
          XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
          // Remove downloaded model file.
          do {
            try ModelFileManager.removeFile(at: modelURL)
          } catch {
            XCTFail("Model removal failed - \(error)")
          }
        case let .failure(error):
          XCTFail("Error: \(error)")
        }
      }
      let modelDownloadManager = ModelDownloadTask(
        remoteModelInfo: remoteModelInfo,
        appName: testApp.name,
        defaults: .createTestInstance(testName: testName),
        downloader: downloader,
        progressHandler: taskProgressHandler,
        completion: taskCompletion
      )
      modelDownloadManager.resume()
      wait(for: [downloadExpectation], timeout: Self.timeout)
      XCTAssertEqual(modelDownloadManager.downloadStatus, .complete)
    }

    /// Exercises `getModel` with each download type against the same model - makes actual
    /// network calls.
    func testGetModel() {
      guard let testApp = defaultTestApp() else { return }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let testModelName = "image-classification"
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )

      // Download type - latest model.
      let latestModelExpectation = expectation(description: "Get latest model.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { progress in
          XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        defer { latestModelExpectation.fulfill() }
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case let .success(model):
          self.assertModelFileReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
      }
      wait(for: [latestModelExpectation], timeout: Self.timeout)

      // Download type - local model update in background. The model is already on device from
      // the previous step, so any progress callback is a failure.
      let backgroundModelExpectation =
        expectation(description: "Get local model and update in background.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .localModelUpdateInBackground,
        conditions: conditions,
        progressHandler: { _ in
          XCTFail("Model is already available on device.")
        }
      ) { result in
        defer { backgroundModelExpectation.fulfill() }
        switch result {
        case let .success(model):
          self.assertModelFileReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
      }
      wait(for: [backgroundModelExpectation], timeout: Self.timeout)

      // Download type - local model. Also cleans up the downloaded file afterwards.
      let localModelExpectation = expectation(description: "Get local model.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .localModel,
        conditions: conditions,
        progressHandler: { _ in
          XCTFail("Model is already available on device.")
        }
      ) { result in
        defer { localModelExpectation.fulfill() }
        switch result {
        case let .success(model):
          self.assertModelFileReachable(atPath: model.path)
          // Remove downloaded model file.
          do {
            try ModelFileManager.removeFile(at: URL(fileURLWithPath: model.path))
          } catch {
            XCTFail("Model removal failed - \(error)")
          }
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
      }
      wait(for: [localModelExpectation], timeout: Self.timeout)
    }

    /// An empty model name must fail with `.emptyModelName` without reporting any progress.
    func testGetModelWhenNameIsEmpty() {
      guard let testApp = defaultTestApp() else { return }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let emptyModelName = ""
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )
      let completionExpectation = expectation(description: "getModel")
      // Inverted: the wait below fails if the progress handler is ever invoked.
      let progressExpectation = expectation(description: "progressHandler")
      progressExpectation.isInverted = true
      modelDownloader.getModel(
        name: emptyModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { _ in
          progressExpectation.fulfill()
        }
      ) { result in
        defer { completionExpectation.fulfill() }
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case .failure(.emptyModelName):
          // The expected error.
          break
        default:
          XCTFail("Unexpected result: \(result)")
        }
      }
      wait(for: [completionExpectation, progressExpectation], timeout: Self.timeout)
    }

    /// Delete previously downloaded model.
    func testDeleteModel() {
      guard let testApp = defaultTestApp() else { return }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let testModelName = "pose-detection"
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )
      let latestModelExpectation = expectation(description: "Get latest model for deletion.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { progress in
          XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        defer { latestModelExpectation.fulfill() }
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case let .success(model):
          self.assertModelFileReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
      }
      wait(for: [latestModelExpectation], timeout: Self.timeout)

      let deleteExpectation = expectation(description: "Wait for model deletion.")
      modelDownloader.deleteDownloadedModel(name: testModelName) { result in
        // Fulfill after the assertions rather than before them.
        defer { deleteExpectation.fulfill() }
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case .success:
          break
        case let .failure(error):
          XCTFail("Failed to delete model - \(error)")
        }
      }
      wait(for: [deleteExpectation], timeout: Self.timeout)
    }

    /// Test listing models in model directory.
    func testListModels() {
      guard let testApp = defaultTestApp() else { return }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let testModelName = "pose-detection"
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )
      let latestModelExpectation = expectation(description: "Get latest model.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { progress in
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        defer { latestModelExpectation.fulfill() }
        switch result {
        case let .success(model):
          self.assertModelFileReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
      }
      wait(for: [latestModelExpectation], timeout: Self.timeout)

      let listExpectation = expectation(description: "Wait for list models.")
      modelDownloader.listDownloadedModels { result in
        defer { listExpectation.fulfill() }
        switch result {
        case let .success(models):
          XCTAssertGreaterThan(models.count, 0)
        case let .failure(error):
          XCTFail("Failed to list models - \(error)")
        }
      }
      wait(for: [listExpectation], timeout: Self.timeout)
    }

    /// Test logging telemetry event.
    func testLogTelemetryEvent() {
      guard let testApp = defaultTestApp() else { return }
      // Flip this to `true` to test logging.
      testApp.isDataCollectionDefaultEnabled = false
      let testModelName = "digit-classification"
      let testName = String(#function.dropLast(2))
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )

      let latestModelExpectation = expectation(description: "Test get model telemetry.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions
      ) { result in
        defer { latestModelExpectation.fulfill() }
        if case let .failure(error) = result {
          XCTFail("Failed to download model - \(error)")
        }
      }
      wait(for: [latestModelExpectation], timeout: Self.timeout)

      let deleteModelExpectation = expectation(description: "Test delete model telemetry.")
      modelDownloader.deleteDownloadedModel(name: testModelName) { result in
        defer { deleteModelExpectation.fulfill() }
        if case let .failure(error) = result {
          XCTFail("Failed to delete model - \(error)")
        }
      }
      wait(for: [deleteModelExpectation], timeout: Self.timeout)

      let notFoundModelName = "fakeModelName"
      let notFoundModelExpectation =
        expectation(description: "Test get model telemetry with unknown model.")
      modelDownloader.getModel(
        name: notFoundModelName,
        downloadType: .latestModel,
        conditions: conditions
      ) { result in
        defer { notFoundModelExpectation.fulfill() }
        switch result {
        case .success:
          XCTFail("Unexpected model success.")
        case let .failure(error):
          XCTAssertEqual(error, .notFound)
        }
      }
      wait(for: [notFoundModelExpectation], timeout: Self.timeout)
    }

    /// Downloads the latest model with cellular access disabled - makes an actual network call.
    func testGetModelWithConditions() {
      guard let testApp = defaultTestApp() else { return }
      testApp.isDataCollectionDefaultEnabled = false
      let testModelName = "pose-detection"
      let testName = String(#function.dropLast(2))
      let conditions = ModelDownloadConditions(allowsCellularAccess: false)
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )
      let latestModelExpectation = expectation(description: "Get latest model with conditions.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { progress in
          XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        defer { latestModelExpectation.fulfill() }
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case let .success(model):
          self.assertModelFileReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
      }
      wait(for: [latestModelExpectation], timeout: Self.timeout)
    }
  }