ModelDownloaderIntegrationTests.swift 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Skip keychain tests on Catalyst and macOS. Tests are skipped because they
  15. // involve interactions with the keychain that require a provisioning profile.
  16. // This is because MLModelDownloader depends on Installations.
  17. // See go/firebase-macos-keychain-popups for more details.
  18. #if !targetEnvironment(macCatalyst) && !os(macOS)
  19. @testable import FirebaseCore
  20. @testable import FirebaseInstallations
  21. @testable import FirebaseMLModelDownloader
  22. import XCTest
  23. #if SWIFT_PACKAGE
  24. @_implementationOnly import GoogleUtilities_UserDefaults
  25. #else
  26. @_implementationOnly import GoogleUtilities
  27. #endif // SWIFT_PACKAGE
  extension GULUserDefaults {
    /// Builds the user-defaults suite name used by the test named `testName`.
    ///
    /// Both `createTestInstance(testName:)` and `getTestInstance(testName:)` derive their suite
    /// name here so the two always address the same storage.
    private static func testSuiteName(for testName: String) -> String {
      return "com.google.firebase.ml.test.\(testName)"
    }

    /// Returns a user-defaults instance for `testName` whose suite has first been cleared.
    ///
    /// - Parameter testName: Name used to derive a per-test suite.
    /// - Returns: A `GULUserDefaults` backed by the freshly emptied suite.
    static func createTestInstance(testName: String) -> GULUserDefaults {
      let suiteName = testSuiteName(for: testName)
      // Clear the suite (`UserDefaults` and `GULUserDefaults` map to the same
      // storage space and `GULUserDefaults` doesn't offer API to do this.)
      UserDefaults(suiteName: suiteName)!.removePersistentDomain(forName: suiteName)
      return GULUserDefaults(suiteName: suiteName)
    }

    /// Returns the user-defaults instance for `testName` without clearing it, so values written
    /// through `createTestInstance(testName:)` can be read back.
    ///
    /// - Parameter testName: Name used to derive a per-test suite.
    /// - Returns: A `GULUserDefaults` backed by the existing suite.
    static func getTestInstance(testName: String) -> GULUserDefaults {
      return GULUserDefaults(suiteName: testSuiteName(for: testName))
    }
  }
  /// Integration tests for the ML model downloader. Most tests make real network calls and rely
  /// on a `GoogleService-Info.plist` bundled with the test target.
  final class ModelDownloaderIntegrationTests: XCTestCase {
    /// Seconds each asynchronous step may take before its expectation times out.
    private static let networkTimeout: TimeInterval = 5

    /// Configures the default `FirebaseApp` once for the class from the bundled
    /// `GoogleService-Info.plist` and turns on debug logging.
    override class func setUp() {
      super.setUp()
      // TODO: Use FirebaseApp internal for test app.
      let bundle = Bundle(for: self)
      if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
         let options = FirebaseOptions(contentsOfFile: plistPath) {
        FirebaseApp.configure(options: options)
      } else {
        XCTFail("Could not locate GoogleService-Info.plist.")
      }
      FirebaseConfiguration.shared.setLoggerLevel(.debug)
    }

    /// Empties the on-device models directory before every test.
    override func setUp() {
      super.setUp()
      do {
        try ModelFileManager.emptyModelsDirectory()
      } catch {
        XCTFail("Could not empty models directory - \(error)")
      }
    }

    /// Asserts that `path` is non-empty and that a file is reachable at it.
    ///
    /// - Parameters:
    ///   - path: Model file path reported by the downloader.
    ///   - file: Source file reported on failure (defaults to the caller's).
    ///   - line: Source line reported on failure (defaults to the caller's).
    /// - Returns: The file URL for `path`, so callers can remove the file afterwards.
    @discardableResult
    private func assertModelFileIsReachable(atPath path: String,
                                            file: StaticString = #file,
                                            line: UInt = #line) -> URL {
      XCTAssertFalse(path.isEmpty, "Model path should not be empty.", file: file, line: line)
      let modelURL = URL(fileURLWithPath: path)
      XCTAssertTrue(
        ModelFileManager.isFileReachable(at: modelURL),
        "No model file is reachable at \(path).",
        file: file,
        line: line
      )
      return modelURL
    }

    /// Test to download model info - makes an actual network call.
    func testDownloadModelInfo() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      testApp.isDataCollectionDefaultEnabled = false
      guard let projectID = testApp.options.projectID,
            let apiKey = testApp.options.apiKey else {
        XCTFail("Default app options are missing a project ID or API key.")
        return
      }
      let testName = String(#function.dropLast(2))
      let testModelName = "pose-detection"
      let modelInfoRetriever = ModelInfoRetriever(
        modelName: testModelName,
        projectID: projectID,
        apiKey: apiKey,
        appName: testApp.name,
        installations: Installations.installations(app: testApp)
      )
      let modelInfoDownloadExpectation =
        expectation(description: "Wait for model info to download.")
      modelInfoRetriever.downloadModelInfo(completion: { result in
        defer { modelInfoDownloadExpectation.fulfill() }
        switch result {
        case let .success(.modelInfo(modelInfo)):
          XCTAssertNotNil(modelInfo.urlExpiryTime)
          XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
          XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
          XCTAssertGreaterThan(modelInfo.size, 0)
          // Persist the result so the follow-up request below can send it as local state.
          let localModelInfo = LocalModelInfo(from: modelInfo)
          localModelInfo.writeToDefaults(
            .createTestInstance(testName: testName),
            appName: testApp.name
          )
        case .success(.notModified):
          XCTFail("Failed to retrieve model info.")
        case let .failure(error):
          XCTFail("Failed to retrieve model info - \(error)")
        }
      })
      wait(for: [modelInfoDownloadExpectation], timeout: Self.networkTimeout)
      guard let localInfo = LocalModelInfo(
        fromDefaults: .getTestInstance(testName: testName),
        name: testModelName,
        appName: testApp.name
      ) else {
        XCTFail("Could not save model info locally.")
        return
      }
      testRetrieveModelInfo(localInfo: localInfo)
    }

    /// Helper for `testDownloadModelInfo`; XCTest does not run it directly because it takes an
    /// argument. It requests model info again with `localInfo` attached and expects the
    /// `.notModified` result.
    ///
    /// - Parameter localInfo: Model info previously saved to user defaults.
    func testRetrieveModelInfo(localInfo: LocalModelInfo) {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      guard let projectID = testApp.options.projectID,
            let apiKey = testApp.options.apiKey else {
        XCTFail("Default app options are missing a project ID or API key.")
        return
      }
      let testModelName = "pose-detection"
      let modelInfoRetriever = ModelInfoRetriever(
        modelName: testModelName,
        projectID: projectID,
        apiKey: apiKey,
        appName: testApp.name,
        installations: Installations.installations(app: testApp),
        localModelInfo: localInfo
      )
      let modelInfoRetrieveExpectation =
        expectation(description: "Wait for model info to be retrieved.")
      modelInfoRetriever.downloadModelInfo(completion: { result in
        defer { modelInfoRetrieveExpectation.fulfill() }
        switch result {
        case .success(.modelInfo):
          XCTFail("Local model info is already the latest and should not be set again.")
        case .success(.notModified):
          break
        case let .failure(error):
          XCTFail("Failed to retrieve model info - \(error)")
        }
      })
      wait(for: [modelInfoRetrieveExpectation], timeout: Self.networkTimeout)
    }

    /// Test to download model file - makes an actual network call.
    func testModelDownload() throws {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let testModelName = "\(testName)-test-model"
      let urlString =
        "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
      let url = try XCTUnwrap(URL(string: urlString))
      let remoteModelInfo = RemoteModelInfo(
        name: testModelName,
        downloadURL: url,
        modelHash: "mock-valid-hash",
        size: 10,
        urlExpiryTime: Date()
      )
      let conditions = ModelDownloadConditions()
      let downloadExpectation = expectation(description: "Wait for model to download.")
      let downloader = ModelFileDownloader(conditions: conditions)
      let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
        XCTAssertLessThanOrEqual(progress, 1)
        XCTAssertGreaterThanOrEqual(progress, 0)
      }
      let taskCompletion: ModelDownloadTask.Completion = { result in
        switch result {
        case let .success(model):
          let modelURL = self.assertModelFileIsReachable(atPath: model.path)
          // Remove downloaded model file.
          do {
            try ModelFileManager.removeFile(at: modelURL)
          } catch {
            XCTFail("Model removal failed - \(error)")
          }
        case let .failure(error):
          XCTFail("Error: \(error)")
        }
        downloadExpectation.fulfill()
      }
      let downloadTask = ModelDownloadTask(
        remoteModelInfo: remoteModelInfo,
        appName: testApp.name,
        defaults: .createTestInstance(testName: testName),
        downloader: downloader,
        progressHandler: taskProgressHandler,
        completion: taskCompletion
      )
      downloadTask.resume()
      wait(for: [downloadExpectation], timeout: Self.networkTimeout)
      XCTAssertEqual(downloadTask.downloadStatus, .complete)
    }

    /// Calls `getModel` with `.latestModel`, then `.localModelUpdateInBackground`, then
    /// `.localModel`. The last two must not report download progress.
    func testGetModel() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let testModelName = "image-classification"
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )

      // Test download type - latest model.
      let latestModelExpectation = expectation(description: "Get latest model.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { progress in
          XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case let .success(model):
          self.assertModelFileIsReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
        latestModelExpectation.fulfill()
      }
      wait(for: [latestModelExpectation], timeout: Self.networkTimeout)

      // Test download type - local model update in background.
      let backgroundModelExpectation =
        expectation(description: "Get local model and update in background.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .localModelUpdateInBackground,
        conditions: conditions,
        progressHandler: { _ in
          XCTFail("Model is already available on device.")
        }
      ) { result in
        switch result {
        case let .success(model):
          self.assertModelFileIsReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
        backgroundModelExpectation.fulfill()
      }
      wait(for: [backgroundModelExpectation], timeout: Self.networkTimeout)

      // Test download type - local model.
      let localModelExpectation = expectation(description: "Get local model.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .localModel,
        conditions: conditions,
        progressHandler: { _ in
          XCTFail("Model is already available on device.")
        }
      ) { result in
        switch result {
        case let .success(model):
          let modelURL = self.assertModelFileIsReachable(atPath: model.path)
          // Remove downloaded model file.
          do {
            try ModelFileManager.removeFile(at: modelURL)
          } catch {
            XCTFail("Model removal failed - \(error)")
          }
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
        localModelExpectation.fulfill()
      }
      wait(for: [localModelExpectation], timeout: Self.networkTimeout)
    }

    /// An empty model name must fail with `.emptyModelName` on the main thread and never
    /// report progress.
    func testGetModelWhenNameIsEmpty() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let emptyModelName = ""
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )
      let completionExpectation = expectation(description: "getModel")
      // Inverted: the test fails if the progress handler is ever invoked.
      let progressExpectation = expectation(description: "progressHandler")
      progressExpectation.isInverted = true
      modelDownloader.getModel(
        name: emptyModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { _ in
          progressExpectation.fulfill()
        }
      ) { result in
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case .failure(.emptyModelName):
          // The expected error.
          break
        default:
          XCTFail("Unexpected result: \(result)")
        }
        completionExpectation.fulfill()
      }
      wait(for: [completionExpectation, progressExpectation], timeout: Self.networkTimeout)
    }

    /// Delete previously downloaded model.
    func testDeleteModel() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let testModelName = "pose-detection"
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )
      let latestModelExpectation = expectation(description: "Get latest model for deletion.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { progress in
          XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case let .success(model):
          self.assertModelFileIsReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
        latestModelExpectation.fulfill()
      }
      wait(for: [latestModelExpectation], timeout: Self.networkTimeout)

      let deleteExpectation = expectation(description: "Wait for model deletion.")
      modelDownloader.deleteDownloadedModel(name: testModelName) { result in
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case .success:
          break
        case let .failure(error):
          XCTFail("Failed to delete model - \(error)")
        }
        deleteExpectation.fulfill()
      }
      wait(for: [deleteExpectation], timeout: Self.networkTimeout)
    }

    /// Test listing models in model directory.
    func testListModels() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      testApp.isDataCollectionDefaultEnabled = false
      let testName = String(#function.dropLast(2))
      let testModelName = "pose-detection"
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )
      let latestModelExpectation = expectation(description: "Get latest model.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { progress in
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        switch result {
        case let .success(model):
          self.assertModelFileIsReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
        latestModelExpectation.fulfill()
      }
      wait(for: [latestModelExpectation], timeout: Self.networkTimeout)

      let listExpectation = expectation(description: "Wait for list models.")
      modelDownloader.listDownloadedModels { result in
        switch result {
        case let .success(models):
          XCTAssertGreaterThan(models.count, 0)
        case let .failure(error):
          XCTFail("Failed to list models - \(error)")
        }
        listExpectation.fulfill()
      }
      wait(for: [listExpectation], timeout: Self.networkTimeout)
    }

    /// Test logging telemetry event.
    func testLogTelemetryEvent() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      // Flip this to `true` to test logging.
      testApp.isDataCollectionDefaultEnabled = false
      let testModelName = "digit-classification"
      let testName = String(#function.dropLast(2))
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )

      let latestModelExpectation = expectation(description: "Test get model telemetry.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions
      ) { result in
        switch result {
        case .success:
          break
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
        latestModelExpectation.fulfill()
      }
      wait(for: [latestModelExpectation], timeout: Self.networkTimeout)

      let deleteModelExpectation = expectation(description: "Test delete model telemetry.")
      modelDownloader.deleteDownloadedModel(name: testModelName) { result in
        switch result {
        case .success:
          break
        case let .failure(error):
          XCTFail("Failed to delete model - \(error)")
        }
        deleteModelExpectation.fulfill()
      }
      wait(for: [deleteModelExpectation], timeout: Self.networkTimeout)

      let notFoundModelName = "fakeModelName"
      let notFoundModelExpectation =
        expectation(description: "Test get model telemetry with unknown model.")
      modelDownloader.getModel(
        name: notFoundModelName,
        downloadType: .latestModel,
        conditions: conditions
      ) { result in
        switch result {
        case .success:
          XCTFail("Unexpected model success.")
        case let .failure(error):
          XCTAssertEqual(error, .notFound)
        }
        notFoundModelExpectation.fulfill()
      }
      wait(for: [notFoundModelExpectation], timeout: Self.networkTimeout)
    }

    /// Downloads the latest model with cellular access disallowed.
    func testGetModelWithConditions() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      testApp.isDataCollectionDefaultEnabled = false
      let testModelName = "pose-detection"
      let testName = String(#function.dropLast(2))
      let conditions = ModelDownloadConditions(allowsCellularAccess: false)
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: testName),
        app: testApp
      )
      let latestModelExpectation = expectation(description: "Get latest model with conditions.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { progress in
          XCTAssertTrue(Thread.isMainThread, "Progress handler must be called on the main thread.")
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
        switch result {
        case let .success(model):
          self.assertModelFileIsReachable(atPath: model.path)
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
        latestModelExpectation.fulfill()
      }
      wait(for: [latestModelExpectation], timeout: Self.networkTimeout)
    }
  }
  518. #endif // !targetEnvironment(macCatalyst) && !os(macOS)