ModelDownloaderIntegrationTests.swift 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Skip keychain tests on Catalyst and macOS. Tests are skipped because they
  15. // involve interactions with the keychain that require a provisioning profile.
  16. // This is because MLModelDownloader depends on Installations.
  17. // See go/firebase-macos-keychain-popups for more details.
  18. #if !targetEnvironment(macCatalyst) && !os(macOS)
  19. @testable import FirebaseCore
  20. @testable import FirebaseInstallations
  21. @testable import FirebaseMLModelDownloader
  22. import XCTest
  23. #if SWIFT_PACKAGE
  24. @_implementationOnly import GoogleUtilities_UserDefaults
  25. #else
  26. @_implementationOnly import GoogleUtilities
  27. #endif // SWIFT_PACKAGE
  28. extension GULUserDefaults {
  29. /// Returns an instance of user defaults.
  30. static func createTestInstance(testName: String) -> GULUserDefaults {
  31. let suiteName = "com.google.firebase.ml.test.\(testName)"
  32. // Clear the suite (`UserDefaults` and `GULUserDefaults` map to the same
  33. // storage space and `GULUserDefaults` doesn't offer API to do this.)
  34. UserDefaults(suiteName: suiteName)!.removePersistentDomain(forName: suiteName)
  35. return GULUserDefaults(suiteName: suiteName)
  36. }
  37. /// Returns the existing user defaults instance.
  38. static func getTestInstance(testName: String) -> GULUserDefaults {
  39. let suiteName = "com.google.firebase.ml.test.\(testName)"
  40. return GULUserDefaults(suiteName: suiteName)
  41. }
  42. }
  43. final class ModelDownloaderIntegrationTests: XCTestCase {
  44. override class func setUp() {
  45. super.setUp()
  46. // TODO: Use FirebaseApp internal for test app.
  47. let bundle = Bundle(for: self)
  48. if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
  49. let options = FirebaseOptions(contentsOfFile: plistPath) {
  50. FirebaseApp.configure(options: options)
  51. } else {
  52. XCTFail("Could not locate GoogleService-Info.plist.")
  53. }
  54. FirebaseConfiguration.shared.setLoggerLevel(.debug)
  55. }
  56. override func setUp() {
  57. do {
  58. try ModelFileManager.emptyModelsDirectory()
  59. } catch {
  60. XCTFail("Could not empty models directory.")
  61. }
  62. }
  63. /// Test to download model info - makes an actual network call.
  64. func testDownloadModelInfo() {
  65. guard let testApp = FirebaseApp.app() else {
  66. XCTFail("Default app was not configured.")
  67. return
  68. }
  69. testApp.isDataCollectionDefaultEnabled = false
  70. let testName = String(#function.dropLast(2))
  71. let testModelName = "pose-detection"
  72. let modelInfoRetriever = ModelInfoRetriever(
  73. modelName: testModelName,
  74. projectID: testApp.options.projectID!,
  75. apiKey: testApp.options.apiKey!,
  76. appName: testApp.name, installations: Installations.installations(app: testApp)
  77. )
  78. let modelInfoDownloadExpectation =
  79. expectation(description: "Wait for model info to download.")
  80. modelInfoRetriever.downloadModelInfo(completion: { result in
  81. switch result {
  82. case let .success(modelInfoResult):
  83. switch modelInfoResult {
  84. case let .modelInfo(modelInfo):
  85. XCTAssertNotNil(modelInfo.urlExpiryTime)
  86. XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
  87. XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
  88. XCTAssertGreaterThan(modelInfo.size, 0)
  89. let localModelInfo = LocalModelInfo(from: modelInfo)
  90. localModelInfo.writeToDefaults(
  91. .createTestInstance(testName: testName),
  92. appName: testApp.name
  93. )
  94. case .notModified:
  95. XCTFail("Failed to retrieve model info.")
  96. }
  97. case let .failure(error):
  98. XCTAssertNotNil(error)
  99. XCTFail("Failed to retrieve model info - \(error)")
  100. }
  101. modelInfoDownloadExpectation.fulfill()
  102. })
  103. wait(for: [modelInfoDownloadExpectation], timeout: 5)
  104. if let localInfo = LocalModelInfo(
  105. fromDefaults: .getTestInstance(testName: testName),
  106. name: testModelName,
  107. appName: testApp.name
  108. ) {
  109. XCTAssertNotNil(localInfo)
  110. testRetrieveModelInfo(localInfo: localInfo)
  111. } else {
  112. XCTFail("Could not save model info locally.")
  113. }
  114. }
  115. func testRetrieveModelInfo(localInfo: LocalModelInfo) {
  116. guard let testApp = FirebaseApp.app() else {
  117. XCTFail("Default app was not configured.")
  118. return
  119. }
  120. let testModelName = "pose-detection"
  121. let modelInfoRetriever = ModelInfoRetriever(
  122. modelName: testModelName,
  123. projectID: testApp.options.projectID!,
  124. apiKey: testApp.options.apiKey!,
  125. appName: testApp.name, installations: Installations.installations(app: testApp),
  126. localModelInfo: localInfo
  127. )
  128. let modelInfoRetrieveExpectation =
  129. expectation(description: "Wait for model info to be retrieved.")
  130. modelInfoRetriever.downloadModelInfo(completion: { result in
  131. switch result {
  132. case let .success(modelInfoResult):
  133. switch modelInfoResult {
  134. case .modelInfo:
  135. XCTFail("Local model info is already the latest and should not be set again.")
  136. case .notModified: break
  137. }
  138. case let .failure(error):
  139. XCTAssertNotNil(error)
  140. XCTFail("Failed to retrieve model info - \(error)")
  141. }
  142. modelInfoRetrieveExpectation.fulfill()
  143. })
  144. wait(for: [modelInfoRetrieveExpectation], timeout: 5)
  145. }
  146. /// Disabled since this test has been flaky in CI.
  147. /// From https://github.com/firebase/firebase-ios-sdk/actions/runs/9944896811/job/27471928997
  148. /// ModelDownloaderIntegrationTests.swift:190: XCTAssertGreaterThanOrEqual failed: ("-3106.0")
  149. /// is less than ("0.0")
  150. /// ModelDownloaderIntegrationTests.swift:190: XCTAssertGreaterThanOrEqual failed: ("-4902.0")
  151. /// is less than ("0.0")
  152. /// ModelDownloaderIntegrationTests.swift:204: error:
  153. /// failed - Error: internalError(description: "Model download failed with HTTP error code:
  154. /// 500")
  155. ///
  156. /// Test to download model file - makes an actual network call.
  157. func SKIPtestModelDownload() throws {
  158. guard let testApp = FirebaseApp.app() else {
  159. XCTFail("Default app was not configured.")
  160. return
  161. }
  162. testApp.isDataCollectionDefaultEnabled = false
  163. let testName = String(#function.dropLast(2))
  164. let testModelName = "\(testName)-test-model"
  165. let urlString =
  166. "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
  167. let url = URL(string: urlString)!
  168. let remoteModelInfo = RemoteModelInfo(
  169. name: testModelName,
  170. downloadURL: url,
  171. modelHash: "mock-valid-hash",
  172. size: 10,
  173. urlExpiryTime: Date()
  174. )
  175. let conditions = ModelDownloadConditions()
  176. let downloadExpectation = expectation(description: "Wait for model to download.")
  177. let downloader = ModelFileDownloader(conditions: conditions)
  178. let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
  179. XCTAssertLessThanOrEqual(progress, 1)
  180. XCTAssertGreaterThanOrEqual(progress, 0)
  181. }
  182. let taskCompletion: ModelDownloadTask.Completion = { result in
  183. switch result {
  184. case let .success(model):
  185. let modelURL = URL(fileURLWithPath: model.path)
  186. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
  187. // Remove downloaded model file.
  188. do {
  189. try ModelFileManager.removeFile(at: modelURL)
  190. } catch {
  191. XCTFail("Model removal failed - \(error)")
  192. }
  193. case let .failure(error):
  194. XCTFail("Error: \(error)")
  195. }
  196. downloadExpectation.fulfill()
  197. }
  198. let modelDownloadManager = ModelDownloadTask(
  199. remoteModelInfo: remoteModelInfo,
  200. appName: testApp.name,
  201. defaults: .createTestInstance(testName: testName),
  202. downloader: downloader,
  203. progressHandler: taskProgressHandler,
  204. completion: taskCompletion
  205. )
  206. modelDownloadManager.resume()
  207. wait(for: [downloadExpectation], timeout: 5)
  208. XCTAssertEqual(modelDownloadManager.downloadStatus, .complete)
  209. }
  210. func testGetModel() {
  211. guard let testApp = FirebaseApp.app() else {
  212. XCTFail("Default app was not configured.")
  213. return
  214. }
  215. testApp.isDataCollectionDefaultEnabled = false
  216. let testName = String(#function.dropLast(2))
  217. let testModelName = "image-classification"
  218. let conditions = ModelDownloadConditions()
  219. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  220. .createTestInstance(testName: testName),
  221. app: testApp
  222. )
  223. /// Test download type - latest model.
  224. var downloadType: ModelDownloadType = .latestModel
  225. let latestModelExpectation = expectation(description: "Get latest model.")
  226. modelDownloader.getModel(
  227. name: testModelName,
  228. downloadType: downloadType,
  229. conditions: conditions,
  230. progressHandler: { progress in
  231. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  232. XCTAssertLessThanOrEqual(progress, 1)
  233. XCTAssertGreaterThanOrEqual(progress, 0)
  234. }
  235. ) { result in
  236. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  237. switch result {
  238. case let .success(model):
  239. XCTAssertNotNil(model.path)
  240. let modelURL = URL(fileURLWithPath: model.path)
  241. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
  242. case let .failure(error):
  243. XCTFail("Failed to download model - \(error)")
  244. }
  245. latestModelExpectation.fulfill()
  246. }
  247. wait(for: [latestModelExpectation], timeout: 5)
  248. /// Test download type - local model update in background.
  249. downloadType = .localModelUpdateInBackground
  250. let backgroundModelExpectation =
  251. expectation(description: "Get local model and update in background.")
  252. modelDownloader.getModel(
  253. name: testModelName,
  254. downloadType: downloadType,
  255. conditions: conditions,
  256. progressHandler: { progress in
  257. XCTFail("Model is already available on device.")
  258. }
  259. ) { result in
  260. switch result {
  261. case let .success(model):
  262. XCTAssertNotNil(model.path)
  263. let modelURL = URL(fileURLWithPath: model.path)
  264. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
  265. case let .failure(error):
  266. XCTFail("Failed to download model - \(error)")
  267. }
  268. backgroundModelExpectation.fulfill()
  269. }
  270. wait(for: [backgroundModelExpectation], timeout: 5)
  271. /// Test download type - local model.
  272. downloadType = .localModel
  273. let localModelExpectation = expectation(description: "Get local model.")
  274. modelDownloader.getModel(
  275. name: testModelName,
  276. downloadType: downloadType,
  277. conditions: conditions,
  278. progressHandler: { progress in
  279. XCTFail("Model is already available on device.")
  280. }
  281. ) { result in
  282. switch result {
  283. case let .success(model):
  284. XCTAssertNotNil(model.path)
  285. let modelURL = URL(fileURLWithPath: model.path)
  286. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
  287. // Remove downloaded model file.
  288. do {
  289. try ModelFileManager.removeFile(at: modelURL)
  290. } catch {
  291. XCTFail("Model removal failed - \(error)")
  292. }
  293. case let .failure(error):
  294. XCTFail("Failed to download model - \(error)")
  295. }
  296. localModelExpectation.fulfill()
  297. }
  298. wait(for: [localModelExpectation], timeout: 5)
  299. }
  300. func testGetModelWhenNameIsEmpty() {
  301. guard let testApp = FirebaseApp.app() else {
  302. XCTFail("Default app was not configured.")
  303. return
  304. }
  305. testApp.isDataCollectionDefaultEnabled = false
  306. let testName = String(#function.dropLast(2))
  307. let emptyModelName = ""
  308. let conditions = ModelDownloadConditions()
  309. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  310. .createTestInstance(testName: testName),
  311. app: testApp
  312. )
  313. let completionExpectation = expectation(description: "getModel")
  314. let progressExpectation = expectation(description: "progressHandler")
  315. progressExpectation.isInverted = true
  316. modelDownloader.getModel(
  317. name: emptyModelName,
  318. downloadType: .latestModel,
  319. conditions: conditions,
  320. progressHandler: { progress in
  321. progressExpectation.fulfill()
  322. }
  323. ) { result in
  324. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  325. switch result {
  326. case .failure(.emptyModelName):
  327. // The expected error.
  328. break
  329. default:
  330. XCTFail("Unexpected result: \(result)")
  331. }
  332. completionExpectation.fulfill()
  333. }
  334. wait(for: [completionExpectation, progressExpectation], timeout: 5)
  335. }
  336. /// Delete previously downloaded model.
  337. func testDeleteModel() {
  338. guard let testApp = FirebaseApp.app() else {
  339. XCTFail("Default app was not configured.")
  340. return
  341. }
  342. testApp.isDataCollectionDefaultEnabled = false
  343. let testName = String(#function.dropLast(2))
  344. let testModelName = "pose-detection"
  345. let conditions = ModelDownloadConditions()
  346. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  347. .createTestInstance(testName: testName),
  348. app: testApp
  349. )
  350. let downloadType: ModelDownloadType = .latestModel
  351. let latestModelExpectation = expectation(description: "Get latest model for deletion.")
  352. modelDownloader.getModel(
  353. name: testModelName,
  354. downloadType: downloadType,
  355. conditions: conditions,
  356. progressHandler: { progress in
  357. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  358. XCTAssertLessThanOrEqual(progress, 1)
  359. XCTAssertGreaterThanOrEqual(progress, 0)
  360. }
  361. ) { result in
  362. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  363. switch result {
  364. case let .success(model):
  365. XCTAssertNotNil(model.path)
  366. let filePath = URL(fileURLWithPath: model.path)
  367. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  368. case let .failure(error):
  369. XCTFail("Failed to download model - \(error)")
  370. }
  371. latestModelExpectation.fulfill()
  372. }
  373. wait(for: [latestModelExpectation], timeout: 5)
  374. let deleteExpectation = expectation(description: "Wait for model deletion.")
  375. modelDownloader.deleteDownloadedModel(name: testModelName) { result in
  376. deleteExpectation.fulfill()
  377. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  378. switch result {
  379. case .success: break
  380. case let .failure(error):
  381. XCTFail("Failed to delete model - \(error)")
  382. }
  383. }
  384. wait(for: [deleteExpectation], timeout: 5)
  385. }
  386. /// Test listing models in model directory.
  387. func testListModels() {
  388. guard let testApp = FirebaseApp.app() else {
  389. XCTFail("Default app was not configured.")
  390. return
  391. }
  392. testApp.isDataCollectionDefaultEnabled = false
  393. let testName = String(#function.dropLast(2))
  394. let testModelName = "pose-detection"
  395. let conditions = ModelDownloadConditions()
  396. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  397. .createTestInstance(testName: testName),
  398. app: testApp
  399. )
  400. let downloadType: ModelDownloadType = .latestModel
  401. let latestModelExpectation = expectation(description: "Get latest model.")
  402. modelDownloader.getModel(
  403. name: testModelName,
  404. downloadType: downloadType,
  405. conditions: conditions,
  406. progressHandler: { progress in
  407. XCTAssertLessThanOrEqual(progress, 1)
  408. XCTAssertGreaterThanOrEqual(progress, 0)
  409. }
  410. ) { result in
  411. switch result {
  412. case let .success(model):
  413. XCTAssertNotNil(model.path)
  414. let filePath = URL(fileURLWithPath: model.path)
  415. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  416. case let .failure(error):
  417. XCTFail("Failed to download model - \(error)")
  418. }
  419. latestModelExpectation.fulfill()
  420. }
  421. wait(for: [latestModelExpectation], timeout: 5)
  422. let listExpectation = expectation(description: "Wait for list models.")
  423. modelDownloader.listDownloadedModels { result in
  424. listExpectation.fulfill()
  425. switch result {
  426. case let .success(models):
  427. XCTAssertGreaterThan(models.count, 0)
  428. case let .failure(error):
  429. XCTFail("Failed to list models - \(error)")
  430. }
  431. }
  432. wait(for: [listExpectation], timeout: 5)
  433. }
  434. /// Test logging telemetry event.
  435. func testLogTelemetryEvent() {
  436. guard let testApp = FirebaseApp.app() else {
  437. XCTFail("Default app was not configured.")
  438. return
  439. }
  440. // Flip this to `true` to test logging.
  441. testApp.isDataCollectionDefaultEnabled = false
  442. let testModelName = "digit-classification"
  443. let testName = String(#function.dropLast(2))
  444. let conditions = ModelDownloadConditions()
  445. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  446. .createTestInstance(testName: testName),
  447. app: testApp
  448. )
  449. let latestModelExpectation = expectation(description: "Test get model telemetry.")
  450. modelDownloader.getModel(
  451. name: testModelName,
  452. downloadType: .latestModel,
  453. conditions: conditions
  454. ) { result in
  455. switch result {
  456. case .success: break
  457. case let .failure(error):
  458. XCTFail("Failed to download model - \(error)")
  459. }
  460. latestModelExpectation.fulfill()
  461. }
  462. wait(for: [latestModelExpectation], timeout: 5)
  463. let deleteModelExpectation = expectation(description: "Test delete model telemetry.")
  464. modelDownloader.deleteDownloadedModel(name: testModelName) { result in
  465. switch result {
  466. case .success(()): break
  467. case let .failure(error):
  468. XCTFail("Failed to delete model - \(error)")
  469. }
  470. deleteModelExpectation.fulfill()
  471. }
  472. wait(for: [deleteModelExpectation], timeout: 5)
  473. let notFoundModelName = "fakeModelName"
  474. let notFoundModelExpectation =
  475. expectation(description: "Test get model telemetry with unknown model.")
  476. modelDownloader.getModel(
  477. name: notFoundModelName,
  478. downloadType: .latestModel,
  479. conditions: conditions
  480. ) { result in
  481. switch result {
  482. case .success: XCTFail("Unexpected model success.")
  483. case let .failure(error):
  484. XCTAssertEqual(error, .notFound)
  485. }
  486. notFoundModelExpectation.fulfill()
  487. }
  488. wait(for: [notFoundModelExpectation], timeout: 5)
  489. }
  490. func testGetModelWithConditions() {
  491. guard let testApp = FirebaseApp.app() else {
  492. XCTFail("Default app was not configured.")
  493. return
  494. }
  495. testApp.isDataCollectionDefaultEnabled = false
  496. let testModelName = "pose-detection"
  497. let testName = String(#function.dropLast(2))
  498. let conditions = ModelDownloadConditions(allowsCellularAccess: false)
  499. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  500. .createTestInstance(testName: testName),
  501. app: testApp
  502. )
  503. let latestModelExpectation = expectation(description: "Get latest model with conditions.")
  504. modelDownloader.getModel(
  505. name: testModelName,
  506. downloadType: .latestModel,
  507. conditions: conditions,
  508. progressHandler: { progress in
  509. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  510. XCTAssertLessThanOrEqual(progress, 1)
  511. XCTAssertGreaterThanOrEqual(progress, 0)
  512. }
  513. ) { result in
  514. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  515. switch result {
  516. case let .success(model):
  517. XCTAssertNotNil(model.path)
  518. let filePath = URL(fileURLWithPath: model.path)
  519. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  520. case let .failure(error):
  521. XCTFail("Failed to download model - \(error)")
  522. }
  523. latestModelExpectation.fulfill()
  524. }
  525. wait(for: [latestModelExpectation], timeout: 5)
  526. }
  527. }
  528. #endif // !targetEnvironment(macCatalyst) && !os(macOS)