ModelDownloaderIntegrationTests.swift 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Skip keychain tests on Catalyst and macOS. Tests are skipped because they
  15. // involve interactions with the keychain that require a provisioning profile.
  16. // This is because MLModelDownloader depends on Installations.
  17. // See go/firebase-macos-keychain-popups for more details.
  18. #if !targetEnvironment(macCatalyst) && !os(macOS)
  19. import XCTest
  20. @testable import FirebaseCore
  21. @testable import FirebaseInstallations
  22. @testable import FirebaseMLModelDownloader
  23. extension UserDefaults {
  24. /// Returns a new cleared instance of user defaults.
  25. static func createTestInstance(testName: String) -> UserDefaults {
  26. let suiteName = "com.google.firebase.ml.test.\(testName)"
  27. let defaults = UserDefaults(suiteName: suiteName)!
  28. defaults.removePersistentDomain(forName: suiteName)
  29. return defaults
  30. }
  31. /// Returns the existing user defaults instance.
  32. static func getTestInstance(testName: String) -> UserDefaults {
  33. let suiteName = "com.google.firebase.ml.test.\(testName)"
  34. return UserDefaults(suiteName: suiteName)!
  35. }
  36. }
  37. final class ModelDownloaderIntegrationTests: XCTestCase {
  38. override class func setUp() {
  39. super.setUp()
  40. // TODO: Use FirebaseApp internal for test app.
  41. let bundle = Bundle(for: self)
  42. if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
  43. let options = FirebaseOptions(contentsOfFile: plistPath) {
  44. FirebaseApp.configure(options: options)
  45. } else {
  46. XCTFail("Could not locate GoogleService-Info.plist.")
  47. }
  48. FirebaseConfiguration.shared.setLoggerLevel(.debug)
  49. }
  50. override func setUp() {
  51. do {
  52. try ModelFileManager.emptyModelsDirectory()
  53. } catch {
  54. XCTFail("Could not empty models directory.")
  55. }
  56. }
  57. /// Test to download model info - makes an actual network call.
  58. func testDownloadModelInfo() {
  59. guard let testApp = FirebaseApp.app() else {
  60. XCTFail("Default app was not configured.")
  61. return
  62. }
  63. testApp.isDataCollectionDefaultEnabled = false
  64. let testName = String(#function.dropLast(2))
  65. let testModelName = "pose-detection"
  66. let modelInfoRetriever = ModelInfoRetriever(
  67. modelName: testModelName,
  68. projectID: testApp.options.projectID!,
  69. apiKey: testApp.options.apiKey!,
  70. appName: testApp.name, installations: Installations.installations(app: testApp)
  71. )
  72. let modelInfoDownloadExpectation =
  73. expectation(description: "Wait for model info to download.")
  74. modelInfoRetriever.downloadModelInfo(completion: { result in
  75. switch result {
  76. case let .success(modelInfoResult):
  77. switch modelInfoResult {
  78. case let .modelInfo(modelInfo):
  79. XCTAssertNotNil(modelInfo.urlExpiryTime)
  80. XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
  81. XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
  82. XCTAssertGreaterThan(modelInfo.size, 0)
  83. let localModelInfo = LocalModelInfo(from: modelInfo)
  84. localModelInfo.writeToDefaults(
  85. .createTestInstance(testName: testName),
  86. appName: testApp.name
  87. )
  88. case .notModified:
  89. XCTFail("Failed to retrieve model info.")
  90. }
  91. case let .failure(error):
  92. XCTAssertNotNil(error)
  93. XCTFail("Failed to retrieve model info - \(error)")
  94. }
  95. modelInfoDownloadExpectation.fulfill()
  96. })
  97. wait(for: [modelInfoDownloadExpectation], timeout: 5)
  98. if let localInfo = LocalModelInfo(
  99. fromDefaults: .getTestInstance(testName: testName),
  100. name: testModelName,
  101. appName: testApp.name
  102. ) {
  103. XCTAssertNotNil(localInfo)
  104. testRetrieveModelInfo(localInfo: localInfo)
  105. } else {
  106. XCTFail("Could not save model info locally.")
  107. }
  108. }
  109. func testRetrieveModelInfo(localInfo: LocalModelInfo) {
  110. guard let testApp = FirebaseApp.app() else {
  111. XCTFail("Default app was not configured.")
  112. return
  113. }
  114. let testModelName = "pose-detection"
  115. let modelInfoRetriever = ModelInfoRetriever(
  116. modelName: testModelName,
  117. projectID: testApp.options.projectID!,
  118. apiKey: testApp.options.apiKey!,
  119. appName: testApp.name, installations: Installations.installations(app: testApp),
  120. localModelInfo: localInfo
  121. )
  122. let modelInfoRetrieveExpectation =
  123. expectation(description: "Wait for model info to be retrieved.")
  124. modelInfoRetriever.downloadModelInfo(completion: { result in
  125. switch result {
  126. case let .success(modelInfoResult):
  127. switch modelInfoResult {
  128. case .modelInfo:
  129. XCTFail("Local model info is already the latest and should not be set again.")
  130. case .notModified: break
  131. }
  132. case let .failure(error):
  133. XCTAssertNotNil(error)
  134. XCTFail("Failed to retrieve model info - \(error)")
  135. }
  136. modelInfoRetrieveExpectation.fulfill()
  137. })
  138. wait(for: [modelInfoRetrieveExpectation], timeout: 5)
  139. }
  140. /// Test to download model file - makes an actual network call.
  141. func testModelDownload() throws {
  142. guard let testApp = FirebaseApp.app() else {
  143. XCTFail("Default app was not configured.")
  144. return
  145. }
  146. testApp.isDataCollectionDefaultEnabled = false
  147. let testName = String(#function.dropLast(2))
  148. let testModelName = "\(testName)-test-model"
  149. let urlString =
  150. "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
  151. let url = URL(string: urlString)!
  152. let remoteModelInfo = RemoteModelInfo(
  153. name: testModelName,
  154. downloadURL: url,
  155. modelHash: "mock-valid-hash",
  156. size: 10,
  157. urlExpiryTime: Date()
  158. )
  159. let conditions = ModelDownloadConditions()
  160. let downloadExpectation = expectation(description: "Wait for model to download.")
  161. let downloader = ModelFileDownloader(conditions: conditions)
  162. let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
  163. XCTAssertLessThanOrEqual(progress, 1)
  164. XCTAssertGreaterThanOrEqual(progress, 0)
  165. }
  166. let taskCompletion: ModelDownloadTask.Completion = { result in
  167. switch result {
  168. case let .success(model):
  169. let modelURL = URL(fileURLWithPath: model.path)
  170. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
  171. // Remove downloaded model file.
  172. do {
  173. try ModelFileManager.removeFile(at: modelURL)
  174. } catch {
  175. XCTFail("Model removal failed - \(error)")
  176. }
  177. case let .failure(error):
  178. XCTFail("Error: \(error)")
  179. }
  180. downloadExpectation.fulfill()
  181. }
  182. let modelDownloadManager = ModelDownloadTask(
  183. remoteModelInfo: remoteModelInfo,
  184. appName: testApp.name,
  185. defaults: .createTestInstance(testName: testName),
  186. downloader: downloader,
  187. progressHandler: taskProgressHandler,
  188. completion: taskCompletion
  189. )
  190. modelDownloadManager.resume()
  191. wait(for: [downloadExpectation], timeout: 5)
  192. XCTAssertEqual(modelDownloadManager.downloadStatus, .complete)
  193. }
  194. func testGetModel() {
  195. guard let testApp = FirebaseApp.app() else {
  196. XCTFail("Default app was not configured.")
  197. return
  198. }
  199. testApp.isDataCollectionDefaultEnabled = false
  200. let testName = String(#function.dropLast(2))
  201. let testModelName = "image-classification"
  202. let conditions = ModelDownloadConditions()
  203. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  204. .createTestInstance(testName: testName),
  205. app: testApp
  206. )
  207. /// Test download type - latest model.
  208. var downloadType: ModelDownloadType = .latestModel
  209. let latestModelExpectation = expectation(description: "Get latest model.")
  210. modelDownloader.getModel(
  211. name: testModelName,
  212. downloadType: downloadType,
  213. conditions: conditions,
  214. progressHandler: { progress in
  215. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  216. XCTAssertLessThanOrEqual(progress, 1)
  217. XCTAssertGreaterThanOrEqual(progress, 0)
  218. }
  219. ) { result in
  220. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  221. switch result {
  222. case let .success(model):
  223. XCTAssertNotNil(model.path)
  224. let modelURL = URL(fileURLWithPath: model.path)
  225. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
  226. case let .failure(error):
  227. XCTFail("Failed to download model - \(error)")
  228. }
  229. latestModelExpectation.fulfill()
  230. }
  231. wait(for: [latestModelExpectation], timeout: 5)
  232. /// Test download type - local model update in background.
  233. downloadType = .localModelUpdateInBackground
  234. let backgroundModelExpectation =
  235. expectation(description: "Get local model and update in background.")
  236. modelDownloader.getModel(
  237. name: testModelName,
  238. downloadType: downloadType,
  239. conditions: conditions,
  240. progressHandler: { progress in
  241. XCTFail("Model is already available on device.")
  242. }
  243. ) { result in
  244. switch result {
  245. case let .success(model):
  246. XCTAssertNotNil(model.path)
  247. let modelURL = URL(fileURLWithPath: model.path)
  248. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
  249. case let .failure(error):
  250. XCTFail("Failed to download model - \(error)")
  251. }
  252. backgroundModelExpectation.fulfill()
  253. }
  254. wait(for: [backgroundModelExpectation], timeout: 5)
  255. /// Test download type - local model.
  256. downloadType = .localModel
  257. let localModelExpectation = expectation(description: "Get local model.")
  258. modelDownloader.getModel(
  259. name: testModelName,
  260. downloadType: downloadType,
  261. conditions: conditions,
  262. progressHandler: { progress in
  263. XCTFail("Model is already available on device.")
  264. }
  265. ) { result in
  266. switch result {
  267. case let .success(model):
  268. XCTAssertNotNil(model.path)
  269. let modelURL = URL(fileURLWithPath: model.path)
  270. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelURL))
  271. // Remove downloaded model file.
  272. do {
  273. try ModelFileManager.removeFile(at: modelURL)
  274. } catch {
  275. XCTFail("Model removal failed - \(error)")
  276. }
  277. case let .failure(error):
  278. XCTFail("Failed to download model - \(error)")
  279. }
  280. localModelExpectation.fulfill()
  281. }
  282. wait(for: [localModelExpectation], timeout: 5)
  283. }
  284. func testGetModelWhenNameIsEmpty() {
  285. guard let testApp = FirebaseApp.app() else {
  286. XCTFail("Default app was not configured.")
  287. return
  288. }
  289. testApp.isDataCollectionDefaultEnabled = false
  290. let testName = String(#function.dropLast(2))
  291. let emptyModelName = ""
  292. let conditions = ModelDownloadConditions()
  293. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  294. .createTestInstance(testName: testName),
  295. app: testApp
  296. )
  297. let completionExpectation = expectation(description: "getModel")
  298. let progressExpectation = expectation(description: "progressHandler")
  299. progressExpectation.isInverted = true
  300. modelDownloader.getModel(
  301. name: emptyModelName,
  302. downloadType: .latestModel,
  303. conditions: conditions,
  304. progressHandler: { progress in
  305. progressExpectation.fulfill()
  306. }
  307. ) { result in
  308. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  309. switch result {
  310. case .failure(.emptyModelName):
  311. // The expected error.
  312. break
  313. default:
  314. XCTFail("Unexpected result: \(result)")
  315. }
  316. completionExpectation.fulfill()
  317. }
  318. wait(for: [completionExpectation, progressExpectation], timeout: 5)
  319. }
  320. /// Delete previously downloaded model.
  321. func testDeleteModel() {
  322. guard let testApp = FirebaseApp.app() else {
  323. XCTFail("Default app was not configured.")
  324. return
  325. }
  326. testApp.isDataCollectionDefaultEnabled = false
  327. let testName = String(#function.dropLast(2))
  328. let testModelName = "pose-detection"
  329. let conditions = ModelDownloadConditions()
  330. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  331. .createTestInstance(testName: testName),
  332. app: testApp
  333. )
  334. let downloadType: ModelDownloadType = .latestModel
  335. let latestModelExpectation = expectation(description: "Get latest model for deletion.")
  336. modelDownloader.getModel(
  337. name: testModelName,
  338. downloadType: downloadType,
  339. conditions: conditions,
  340. progressHandler: { progress in
  341. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  342. XCTAssertLessThanOrEqual(progress, 1)
  343. XCTAssertGreaterThanOrEqual(progress, 0)
  344. }
  345. ) { result in
  346. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  347. switch result {
  348. case let .success(model):
  349. XCTAssertNotNil(model.path)
  350. let filePath = URL(fileURLWithPath: model.path)
  351. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  352. case let .failure(error):
  353. XCTFail("Failed to download model - \(error)")
  354. }
  355. latestModelExpectation.fulfill()
  356. }
  357. wait(for: [latestModelExpectation], timeout: 5)
  358. let deleteExpectation = expectation(description: "Wait for model deletion.")
  359. modelDownloader.deleteDownloadedModel(name: testModelName) { result in
  360. deleteExpectation.fulfill()
  361. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  362. switch result {
  363. case .success: break
  364. case let .failure(error):
  365. XCTFail("Failed to delete model - \(error)")
  366. }
  367. }
  368. wait(for: [deleteExpectation], timeout: 5)
  369. }
  370. /// Test listing models in model directory.
  371. func testListModels() {
  372. guard let testApp = FirebaseApp.app() else {
  373. XCTFail("Default app was not configured.")
  374. return
  375. }
  376. testApp.isDataCollectionDefaultEnabled = false
  377. let testName = String(#function.dropLast(2))
  378. let testModelName = "pose-detection"
  379. let conditions = ModelDownloadConditions()
  380. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  381. .createTestInstance(testName: testName),
  382. app: testApp
  383. )
  384. let downloadType: ModelDownloadType = .latestModel
  385. let latestModelExpectation = expectation(description: "Get latest model.")
  386. modelDownloader.getModel(
  387. name: testModelName,
  388. downloadType: downloadType,
  389. conditions: conditions,
  390. progressHandler: { progress in
  391. XCTAssertLessThanOrEqual(progress, 1)
  392. XCTAssertGreaterThanOrEqual(progress, 0)
  393. }
  394. ) { result in
  395. switch result {
  396. case let .success(model):
  397. XCTAssertNotNil(model.path)
  398. let filePath = URL(fileURLWithPath: model.path)
  399. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  400. case let .failure(error):
  401. XCTFail("Failed to download model - \(error)")
  402. }
  403. latestModelExpectation.fulfill()
  404. }
  405. wait(for: [latestModelExpectation], timeout: 5)
  406. let listExpectation = expectation(description: "Wait for list models.")
  407. modelDownloader.listDownloadedModels { result in
  408. listExpectation.fulfill()
  409. switch result {
  410. case let .success(models):
  411. XCTAssertGreaterThan(models.count, 0)
  412. case let .failure(error):
  413. XCTFail("Failed to list models - \(error)")
  414. }
  415. }
  416. wait(for: [listExpectation], timeout: 5)
  417. }
  418. /// Test logging telemetry event.
  419. func testLogTelemetryEvent() {
  420. guard let testApp = FirebaseApp.app() else {
  421. XCTFail("Default app was not configured.")
  422. return
  423. }
  424. // Flip this to `true` to test logging.
  425. testApp.isDataCollectionDefaultEnabled = false
  426. let testModelName = "digit-classification"
  427. let testName = String(#function.dropLast(2))
  428. let conditions = ModelDownloadConditions()
  429. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  430. .createTestInstance(testName: testName),
  431. app: testApp
  432. )
  433. let latestModelExpectation = expectation(description: "Test get model telemetry.")
  434. modelDownloader.getModel(
  435. name: testModelName,
  436. downloadType: .latestModel,
  437. conditions: conditions
  438. ) { result in
  439. switch result {
  440. case .success: break
  441. case let .failure(error):
  442. XCTFail("Failed to download model - \(error)")
  443. }
  444. latestModelExpectation.fulfill()
  445. }
  446. wait(for: [latestModelExpectation], timeout: 5)
  447. let deleteModelExpectation = expectation(description: "Test delete model telemetry.")
  448. modelDownloader.deleteDownloadedModel(name: testModelName) { result in
  449. switch result {
  450. case .success(()): break
  451. case let .failure(error):
  452. XCTFail("Failed to delete model - \(error)")
  453. }
  454. deleteModelExpectation.fulfill()
  455. }
  456. wait(for: [deleteModelExpectation], timeout: 5)
  457. let notFoundModelName = "fakeModelName"
  458. let notFoundModelExpectation =
  459. expectation(description: "Test get model telemetry with unknown model.")
  460. modelDownloader.getModel(
  461. name: notFoundModelName,
  462. downloadType: .latestModel,
  463. conditions: conditions
  464. ) { result in
  465. switch result {
  466. case .success: XCTFail("Unexpected model success.")
  467. case let .failure(error):
  468. XCTAssertEqual(error, .notFound)
  469. }
  470. notFoundModelExpectation.fulfill()
  471. }
  472. wait(for: [notFoundModelExpectation], timeout: 5)
  473. }
  474. func testGetModelWithConditions() {
  475. guard let testApp = FirebaseApp.app() else {
  476. XCTFail("Default app was not configured.")
  477. return
  478. }
  479. testApp.isDataCollectionDefaultEnabled = false
  480. let testModelName = "pose-detection"
  481. let testName = String(#function.dropLast(2))
  482. let conditions = ModelDownloadConditions(allowsCellularAccess: false)
  483. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  484. .createTestInstance(testName: testName),
  485. app: testApp
  486. )
  487. let latestModelExpectation = expectation(description: "Get latest model with conditions.")
  488. modelDownloader.getModel(
  489. name: testModelName,
  490. downloadType: .latestModel,
  491. conditions: conditions,
  492. progressHandler: { progress in
  493. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  494. XCTAssertLessThanOrEqual(progress, 1)
  495. XCTAssertGreaterThanOrEqual(progress, 0)
  496. }
  497. ) { result in
  498. XCTAssertTrue(Thread.isMainThread, "Completion must be called on the main thread.")
  499. switch result {
  500. case let .success(model):
  501. XCTAssertNotNil(model.path)
  502. let filePath = URL(fileURLWithPath: model.path)
  503. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  504. case let .failure(error):
  505. XCTFail("Failed to download model - \(error)")
  506. }
  507. latestModelExpectation.fulfill()
  508. }
  509. wait(for: [latestModelExpectation], timeout: 5)
  510. }
  511. }
  512. #endif // !targetEnvironment(macCatalyst) && !os(macOS)