ModelDownloaderIntegrationTests.swift 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseInstallations
  17. @testable import FirebaseMLModelDownloader
  18. extension UserDefaults {
  19. /// Returns a new cleared instance of user defaults.
  20. static func createTestInstance(testName: String) -> UserDefaults {
  21. let suiteName = "com.google.firebase.ml.test.\(testName)"
  22. // TODO: reconsider force unwrapping
  23. let defaults = UserDefaults(suiteName: suiteName)!
  24. defaults.removePersistentDomain(forName: suiteName)
  25. return defaults
  26. }
  27. /// Returns the existing user defaults instance.
  28. static func getTestInstance(testName: String) -> UserDefaults {
  29. let suiteName = "com.google.firebase.ml.test.\(testName)"
  30. // TODO: reconsider force unwrapping
  31. return UserDefaults(suiteName: suiteName)!
  32. }
  33. }
  34. // TODO: Use FirebaseApp internal init for testApp
  35. final class ModelDownloaderIntegrationTests: XCTestCase {
  36. override class func setUp() {
  37. super.setUp()
  38. let bundle = Bundle(for: self)
  39. if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
  40. let options = FirebaseOptions(contentsOfFile: plistPath) {
  41. FirebaseApp.configure(options: options)
  42. } else {
  43. XCTFail("Could not locate GoogleService-Info.plist.")
  44. }
  45. FirebaseConfiguration.shared.setLoggerLevel(.debug)
  46. }
  47. /// Test to download model info - makes an actual network call.
  48. func testDownloadModelInfo() {
  49. guard let testApp = FirebaseApp.app() else {
  50. XCTFail("Default app was not configured.")
  51. return
  52. }
  53. let testModelName = "pose-detection"
  54. let modelInfoRetriever = ModelInfoRetriever(
  55. modelName: testModelName,
  56. projectID: testApp.options.projectID!,
  57. apiKey: testApp.options.apiKey!,
  58. installations: Installations.installations(app: testApp),
  59. appName: testApp.name
  60. )
  61. let downloadExpectation = expectation(description: "Wait for model info to download.")
  62. modelInfoRetriever.downloadModelInfo(completion: { result in
  63. switch result {
  64. case let .success(modelInfoResult):
  65. switch modelInfoResult {
  66. case let .modelInfo(modelInfo):
  67. XCTAssertNotNil(modelInfo.urlExpiryTime)
  68. XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
  69. XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
  70. XCTAssertGreaterThan(modelInfo.size, 0)
  71. let localModelInfo = LocalModelInfo(from: modelInfo, path: "mock-valid-path")
  72. localModelInfo.writeToDefaults(
  73. .createTestInstance(testName: #function),
  74. appName: testApp.name
  75. )
  76. case .notModified:
  77. XCTFail("Failed to retrieve model info.")
  78. }
  79. case let .failure(error):
  80. XCTAssertNotNil(error)
  81. XCTFail("Failed to retrieve model info - \(error)")
  82. }
  83. downloadExpectation.fulfill()
  84. })
  85. waitForExpectations(timeout: 5, handler: nil)
  86. if let localInfo = LocalModelInfo(
  87. fromDefaults: .getTestInstance(testName: #function),
  88. name: testModelName,
  89. appName: testApp.name
  90. ) {
  91. XCTAssertNotNil(localInfo)
  92. testRetrieveModelInfo(localInfo: localInfo)
  93. } else {
  94. XCTFail("Could not save model info locally.")
  95. }
  96. }
  97. func testRetrieveModelInfo(localInfo: LocalModelInfo) {
  98. guard let testApp = FirebaseApp.app() else {
  99. XCTFail("Default app was not configured.")
  100. return
  101. }
  102. let testModelName = "pose-detection"
  103. let modelInfoRetriever = ModelInfoRetriever(
  104. modelName: testModelName,
  105. projectID: testApp.options.projectID!,
  106. apiKey: testApp.options.apiKey!,
  107. installations: Installations.installations(app: testApp),
  108. appName: testApp.name,
  109. localModelInfo: localInfo
  110. )
  111. let retrieveExpectation = expectation(description: "Wait for model info to be retrieved.")
  112. modelInfoRetriever.downloadModelInfo(completion: { result in
  113. switch result {
  114. case let .success(modelInfoResult):
  115. switch modelInfoResult {
  116. case .modelInfo:
  117. XCTFail("Local model info is already the latest and should not be set again.")
  118. case .notModified: break
  119. }
  120. case let .failure(error):
  121. XCTAssertNotNil(error)
  122. XCTFail("Failed to retrieve model info - \(error)")
  123. }
  124. retrieveExpectation.fulfill()
  125. })
  126. waitForExpectations(timeout: 5, handler: nil)
  127. }
  128. /// Test to download model file - makes an actual network call.
  129. func testModelDownload() throws {
  130. guard let testApp = FirebaseApp.app() else {
  131. XCTFail("Default app was not configured.")
  132. return
  133. }
  134. let testName = #function.dropLast(2)
  135. let testModelName = "\(testName)-test-model"
  136. let urlString =
  137. "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
  138. let url = URL(string: urlString)!
  139. let remoteModelInfo = RemoteModelInfo(
  140. name: testModelName,
  141. downloadURL: url,
  142. modelHash: "mock-valid-hash",
  143. size: 10,
  144. urlExpiryTime: Date()
  145. )
  146. let conditions = ModelDownloadConditions()
  147. let expectation = self.expectation(description: "Wait for model to download.")
  148. let downloader = ModelFileDownloader(conditions: conditions)
  149. let modelDownloadManager = ModelDownloadTask(
  150. remoteModelInfo: remoteModelInfo,
  151. appName: testApp.name,
  152. defaults: .createTestInstance(testName: #function),
  153. downloader: downloader
  154. )
  155. modelDownloadManager.download(progressHandler: { progress in
  156. XCTAssertLessThanOrEqual(progress, 1)
  157. XCTAssertGreaterThanOrEqual(progress, 0)
  158. }) { result in
  159. switch result {
  160. case let .success(model):
  161. guard let modelPath = URL(string: model.path) else {
  162. XCTFail("Invalid or empty model path.")
  163. return
  164. }
  165. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
  166. /// Remove downloaded model file.
  167. do {
  168. try ModelFileManager.removeFile(at: modelPath)
  169. } catch {
  170. XCTFail("Model removal failed - \(error)")
  171. }
  172. case let .failure(error):
  173. XCTFail("Error: \(error)")
  174. }
  175. expectation.fulfill()
  176. }
  177. waitForExpectations(timeout: 5, handler: nil)
  178. XCTAssertEqual(modelDownloadManager.downloadStatus, .complete)
  179. }
  180. func testGetModel() {
  181. guard let testApp = FirebaseApp.app() else {
  182. XCTFail("Default app was not configured.")
  183. return
  184. }
  185. let testName = String(#function.dropLast(2))
  186. let testModelName = "image-classification"
  187. let conditions = ModelDownloadConditions()
  188. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  189. .createTestInstance(testName: testName),
  190. app: testApp
  191. )
  192. /// Test download type - latest model.
  193. var downloadType: ModelDownloadType = .latestModel
  194. let latestModelExpectation = expectation(description: "Get latest model.")
  195. modelDownloader.getModel(
  196. name: testModelName,
  197. downloadType: downloadType,
  198. conditions: conditions,
  199. progressHandler: { progress in
  200. XCTAssertLessThanOrEqual(progress, 1)
  201. XCTAssertGreaterThanOrEqual(progress, 0)
  202. }
  203. ) { result in
  204. switch result {
  205. case let .success(model):
  206. XCTAssertNotNil(model.path)
  207. guard let modelPath = URL(string: model.path) else {
  208. XCTFail("Invalid model path.")
  209. return
  210. }
  211. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
  212. case let .failure(error):
  213. XCTFail("Failed to download model - \(error)")
  214. }
  215. latestModelExpectation.fulfill()
  216. }
  217. waitForExpectations(timeout: 5, handler: nil)
  218. /// Test download type - local model.
  219. downloadType = .localModel
  220. let localModelExpectation = expectation(description: "Get local model.")
  221. modelDownloader.getModel(
  222. name: testModelName,
  223. downloadType: downloadType,
  224. conditions: conditions,
  225. progressHandler: { progress in
  226. XCTFail("Model is already available on device.")
  227. }
  228. ) { result in
  229. switch result {
  230. case let .success(model):
  231. XCTAssertNotNil(model.path)
  232. guard let modelPath = URL(string: model.path) else {
  233. XCTFail("Invalid model path.")
  234. return
  235. }
  236. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
  237. case let .failure(error):
  238. XCTFail("Failed to download model - \(error)")
  239. }
  240. localModelExpectation.fulfill()
  241. }
  242. waitForExpectations(timeout: 5, handler: nil)
  243. /// Test download type - local model update in background.
  244. downloadType = .localModelUpdateInBackground
  245. let backgroundModelExpectation =
  246. expectation(description: "Get local model and update in background.")
  247. modelDownloader.getModel(
  248. name: testModelName,
  249. downloadType: downloadType,
  250. conditions: conditions,
  251. progressHandler: { progress in
  252. XCTFail("Model is already available on device.")
  253. }
  254. ) { result in
  255. switch result {
  256. case let .success(model):
  257. XCTAssertNotNil(model.path)
  258. guard let modelPath = URL(string: model.path) else {
  259. XCTFail("Invalid model path.")
  260. return
  261. }
  262. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
  263. /// Remove downloaded model file.
  264. do {
  265. try ModelFileManager.removeFile(at: modelPath)
  266. } catch {
  267. XCTFail("Model removal failed - \(error)")
  268. }
  269. case let .failure(error):
  270. XCTFail("Failed to download model - \(error)")
  271. }
  272. backgroundModelExpectation.fulfill()
  273. }
  274. waitForExpectations(timeout: 5, handler: nil)
  275. }
  276. /// Delete previously downloaded model.
  277. func testDeleteModel() {
  278. guard let testApp = FirebaseApp.app() else {
  279. XCTFail("Default app was not configured.")
  280. return
  281. }
  282. let testName = String(#function.dropLast(2))
  283. let testModelName = "pose-detection"
  284. let conditions = ModelDownloadConditions()
  285. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  286. .createTestInstance(testName: testName),
  287. app: testApp
  288. )
  289. let downloadType: ModelDownloadType = .latestModel
  290. let latestModelExpectation = expectation(description: "Get latest model for deletion.")
  291. modelDownloader.getModel(
  292. name: testModelName,
  293. downloadType: downloadType,
  294. conditions: conditions,
  295. progressHandler: { progress in
  296. XCTAssertLessThanOrEqual(progress, 1)
  297. XCTAssertGreaterThanOrEqual(progress, 0)
  298. }
  299. ) { result in
  300. switch result {
  301. case let .success(model):
  302. XCTAssertNotNil(model.path)
  303. guard let filePath = URL(string: model.path) else {
  304. XCTFail("Invalid model path.")
  305. return
  306. }
  307. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  308. case let .failure(error):
  309. XCTFail("Failed to download model - \(error)")
  310. }
  311. latestModelExpectation.fulfill()
  312. }
  313. waitForExpectations(timeout: 5, handler: nil)
  314. modelDownloader.deleteDownloadedModel(name: testModelName) { result in
  315. switch result {
  316. case .success: break
  317. case let .failure(error):
  318. XCTFail("Failed to delete model - \(error)")
  319. }
  320. }
  321. }
  322. /// Test listing models in model directory.
  323. func testListModels() {
  324. guard let testApp = FirebaseApp.app() else {
  325. XCTFail("Default app was not configured.")
  326. return
  327. }
  328. let testName = String(#function.dropLast(2))
  329. let testModelName = "pose-detection"
  330. let conditions = ModelDownloadConditions()
  331. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  332. .createTestInstance(testName: testName),
  333. app: testApp
  334. )
  335. let downloadType: ModelDownloadType = .latestModel
  336. let latestModelExpectation = expectation(description: "Get latest model.")
  337. modelDownloader.getModel(
  338. name: testModelName,
  339. downloadType: downloadType,
  340. conditions: conditions,
  341. progressHandler: { progress in
  342. XCTAssertLessThanOrEqual(progress, 1)
  343. XCTAssertGreaterThanOrEqual(progress, 0)
  344. }
  345. ) { result in
  346. switch result {
  347. case let .success(model):
  348. XCTAssertNotNil(model.path)
  349. guard let filePath = URL(string: model.path) else {
  350. XCTFail("Invalid model path.")
  351. return
  352. }
  353. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  354. case let .failure(error):
  355. XCTFail("Failed to download model - \(error)")
  356. }
  357. latestModelExpectation.fulfill()
  358. }
  359. waitForExpectations(timeout: 5, handler: nil)
  360. modelDownloader.listDownloadedModels { result in
  361. switch result {
  362. case let .success(models):
  363. XCTAssertGreaterThan(models.count, 0)
  364. case let .failure(error):
  365. XCTFail("Failed to list models - \(error)")
  366. }
  367. }
  368. }
  369. /// Test logging telemetry event.
  370. func testLogTelemetryEvent() {
  371. guard let testApp = FirebaseApp.app() else {
  372. XCTFail("Default app was not configured.")
  373. return
  374. }
  375. let testModelName = "image-classification"
  376. let testName = #function
  377. let conditions = ModelDownloadConditions()
  378. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  379. .createTestInstance(testName: testName),
  380. app: testApp
  381. )
  382. let latestModelExpectation = expectation(description: "Test telemetry.")
  383. modelDownloader.getModel(
  384. name: testModelName,
  385. downloadType: .latestModel,
  386. conditions: conditions
  387. ) { result in
  388. switch result {
  389. case let .success(model):
  390. guard let telemetryLogger = TelemetryLogger(app: testApp) else {
  391. XCTFail("Could not initialize logger.")
  392. return
  393. }
  394. // TODO: Remove actual logging and stub out with mocks.
  395. telemetryLogger.logModelDownloadEvent(
  396. eventName: .modelDownload,
  397. status: .succeeded,
  398. model: model,
  399. downloadErrorCode: .noError
  400. )
  401. let modelPath = URL(string: model.path)!
  402. try? ModelFileManager.removeFile(at: modelPath)
  403. case let .failure(error):
  404. XCTFail("Failed to download model - \(error)")
  405. }
  406. latestModelExpectation.fulfill()
  407. }
  408. waitForExpectations(timeout: 5, handler: nil)
  409. }
  410. func testGetModelWithConditions() {
  411. guard let testApp = FirebaseApp.app() else {
  412. XCTFail("Default app was not configured.")
  413. return
  414. }
  415. let testModelName = "pose-detection"
  416. let testName = #function
  417. // TODO: Figure out a better way to test this.
  418. var conditions = ModelDownloadConditions()
  419. conditions.allowsCellularAccess = false
  420. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  421. .createTestInstance(testName: testName),
  422. app: testApp
  423. )
  424. let latestModelExpectation = expectation(description: "Get latest model with conditions.")
  425. modelDownloader.getModel(
  426. name: testModelName,
  427. downloadType: .latestModel,
  428. conditions: conditions,
  429. progressHandler: { progress in
  430. XCTAssertLessThanOrEqual(progress, 1)
  431. XCTAssertGreaterThanOrEqual(progress, 0)
  432. }
  433. ) { result in
  434. switch result {
  435. case let .success(model):
  436. XCTAssertNotNil(model.path)
  437. guard let filePath = URL(string: model.path) else {
  438. XCTFail("Invalid model path.")
  439. return
  440. }
  441. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  442. case let .failure(error):
  443. XCTFail("Failed to download model - \(error)")
  444. }
  445. latestModelExpectation.fulfill()
  446. }
  447. waitForExpectations(timeout: 5, handler: nil)
  448. }
  449. }