ModelDownloaderIntegrationTests.swift 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseInstallations
  17. @testable import FirebaseMLModelDownloader
  18. extension UserDefaults {
  19. /// Returns a new cleared instance of user defaults.
  20. static func createTestInstance(testName: String) -> UserDefaults {
  21. let suiteName = "com.google.firebase.ml.test.\(testName)"
  22. // TODO: reconsider force unwrapping
  23. let defaults = UserDefaults(suiteName: suiteName)!
  24. defaults.removePersistentDomain(forName: suiteName)
  25. return defaults
  26. }
  27. /// Returns the existing user defaults instance.
  28. static func getTestInstance(testName: String) -> UserDefaults {
  29. let suiteName = "com.google.firebase.ml.test.\(testName)"
  30. // TODO: reconsider force unwrapping
  31. return UserDefaults(suiteName: suiteName)!
  32. }
  33. }
  34. // TODO: Use FirebaseApp internal init for testApp
  35. final class ModelDownloaderIntegrationTests: XCTestCase {
  36. override class func setUp() {
  37. super.setUp()
  38. let bundle = Bundle(for: self)
  39. if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
  40. let options = FirebaseOptions(contentsOfFile: plistPath) {
  41. FirebaseApp.configure(options: options)
  42. } else {
  43. XCTFail("Could not locate GoogleService-Info.plist.")
  44. }
  45. FirebaseConfiguration.shared.setLoggerLevel(.debug)
  46. }
  47. /// Test to download model info - makes an actual network call.
  48. func testDownloadModelInfo() {
  49. guard let testApp = FirebaseApp.app() else {
  50. XCTFail("Default app was not configured.")
  51. return
  52. }
  53. let testModelName = "pose-detection"
  54. let modelInfoRetriever = ModelInfoRetriever(
  55. modelName: testModelName,
  56. projectID: testApp.options.projectID!,
  57. apiKey: testApp.options.apiKey!,
  58. installations: Installations.installations(app: testApp),
  59. appName: testApp.name
  60. )
  61. let downloadExpectation = expectation(description: "Wait for model info to download.")
  62. modelInfoRetriever.downloadModelInfo(completion: { result in
  63. switch result {
  64. case let .success(modelInfoResult):
  65. switch modelInfoResult {
  66. case let .modelInfo(modelInfo):
  67. XCTAssertNotNil(modelInfo.urlExpiryTime)
  68. XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
  69. XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
  70. XCTAssertGreaterThan(modelInfo.size, 0)
  71. let localModelInfo = LocalModelInfo(from: modelInfo, path: "mock-valid-path")
  72. localModelInfo.writeToDefaults(
  73. .createTestInstance(testName: #function),
  74. appName: testApp.name
  75. )
  76. case .notModified:
  77. XCTFail("Failed to retrieve model info.")
  78. }
  79. case let .failure(error):
  80. XCTAssertNotNil(error)
  81. XCTFail("Failed to retrieve model info - \(error)")
  82. }
  83. downloadExpectation.fulfill()
  84. })
  85. waitForExpectations(timeout: 5, handler: nil)
  86. if let localInfo = LocalModelInfo(
  87. fromDefaults: .getTestInstance(testName: #function),
  88. name: testModelName,
  89. appName: testApp.name
  90. ) {
  91. XCTAssertNotNil(localInfo)
  92. testRetrieveModelInfo(localInfo: localInfo)
  93. } else {
  94. XCTFail("Could not save model info locally.")
  95. }
  96. }
  97. func testRetrieveModelInfo(localInfo: LocalModelInfo) {
  98. guard let testApp = FirebaseApp.app() else {
  99. XCTFail("Default app was not configured.")
  100. return
  101. }
  102. let testModelName = "pose-detection"
  103. let modelInfoRetriever = ModelInfoRetriever(
  104. modelName: testModelName,
  105. projectID: testApp.options.projectID!,
  106. apiKey: testApp.options.apiKey!,
  107. installations: Installations.installations(app: testApp),
  108. appName: testApp.name,
  109. localModelInfo: localInfo
  110. )
  111. let retrieveExpectation = expectation(description: "Wait for model info to be retrieved.")
  112. modelInfoRetriever.downloadModelInfo(completion: { result in
  113. switch result {
  114. case let .success(modelInfoResult):
  115. switch modelInfoResult {
  116. case .modelInfo:
  117. XCTFail("Local model info is already the latest and should not be set again.")
  118. case .notModified: break
  119. }
  120. case let .failure(error):
  121. XCTAssertNotNil(error)
  122. XCTFail("Failed to retrieve model info - \(error)")
  123. }
  124. retrieveExpectation.fulfill()
  125. })
  126. waitForExpectations(timeout: 5, handler: nil)
  127. }
  128. /// Test to download model file - makes an actual network call.
  129. func testModelDownload() throws {
  130. guard let testApp = FirebaseApp.app() else {
  131. XCTFail("Default app was not configured.")
  132. return
  133. }
  134. let testName = #function.dropLast(2)
  135. let testModelName = "\(testName)-test-model"
  136. let urlString =
  137. "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
  138. let url = URL(string: urlString)!
  139. let remoteModelInfo = RemoteModelInfo(
  140. name: testModelName,
  141. downloadURL: url,
  142. modelHash: "mock-valid-hash",
  143. size: 10,
  144. urlExpiryTime: Date()
  145. )
  146. let conditions = ModelDownloadConditions()
  147. let expectation = self.expectation(description: "Wait for model to download.")
  148. let downloader = ModelFileDownloader(conditions: conditions)
  149. let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
  150. XCTAssertLessThanOrEqual(progress, 1)
  151. XCTAssertGreaterThanOrEqual(progress, 0)
  152. }
  153. let taskCompletion: ModelDownloadTask.Completion = { result in
  154. switch result {
  155. case let .success(model):
  156. guard let modelPath = URL(string: model.path) else {
  157. XCTFail("Invalid or empty model path.")
  158. return
  159. }
  160. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
  161. /// Remove downloaded model file.
  162. do {
  163. try ModelFileManager.removeFile(at: modelPath)
  164. } catch {
  165. XCTFail("Model removal failed - \(error)")
  166. }
  167. case let .failure(error):
  168. XCTFail("Error: \(error)")
  169. }
  170. expectation.fulfill()
  171. }
  172. let modelDownloadManager = ModelDownloadTask(
  173. remoteModelInfo: remoteModelInfo,
  174. appName: testApp.name,
  175. defaults: .createTestInstance(testName: #function),
  176. downloader: downloader,
  177. progressHandler: taskProgressHandler,
  178. completion: taskCompletion
  179. )
  180. modelDownloadManager.resume()
  181. waitForExpectations(timeout: 5, handler: nil)
  182. XCTAssertEqual(modelDownloadManager.downloadStatus, .complete)
  183. }
  184. func testGetModel() {
  185. guard let testApp = FirebaseApp.app() else {
  186. XCTFail("Default app was not configured.")
  187. return
  188. }
  189. let testName = String(#function.dropLast(2))
  190. let testModelName = "image-classification"
  191. let conditions = ModelDownloadConditions()
  192. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  193. .createTestInstance(testName: testName),
  194. app: testApp
  195. )
  196. /// Test download type - latest model.
  197. var downloadType: ModelDownloadType = .latestModel
  198. let latestModelExpectation = expectation(description: "Get latest model.")
  199. modelDownloader.getModel(
  200. name: testModelName,
  201. downloadType: downloadType,
  202. conditions: conditions,
  203. progressHandler: { progress in
  204. XCTAssertLessThanOrEqual(progress, 1)
  205. XCTAssertGreaterThanOrEqual(progress, 0)
  206. }
  207. ) { result in
  208. switch result {
  209. case let .success(model):
  210. XCTAssertNotNil(model.path)
  211. guard let modelPath = URL(string: model.path) else {
  212. XCTFail("Invalid model path.")
  213. return
  214. }
  215. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
  216. case let .failure(error):
  217. XCTFail("Failed to download model - \(error)")
  218. }
  219. latestModelExpectation.fulfill()
  220. }
  221. waitForExpectations(timeout: 5, handler: nil)
  222. /// Test download type - local model.
  223. downloadType = .localModel
  224. let localModelExpectation = expectation(description: "Get local model.")
  225. modelDownloader.getModel(
  226. name: testModelName,
  227. downloadType: downloadType,
  228. conditions: conditions,
  229. progressHandler: { progress in
  230. XCTFail("Model is already available on device.")
  231. }
  232. ) { result in
  233. switch result {
  234. case let .success(model):
  235. XCTAssertNotNil(model.path)
  236. guard let modelPath = URL(string: model.path) else {
  237. XCTFail("Invalid model path.")
  238. return
  239. }
  240. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
  241. case let .failure(error):
  242. XCTFail("Failed to download model - \(error)")
  243. }
  244. localModelExpectation.fulfill()
  245. }
  246. waitForExpectations(timeout: 5, handler: nil)
  247. /// Test download type - local model update in background.
  248. downloadType = .localModelUpdateInBackground
  249. let backgroundModelExpectation =
  250. expectation(description: "Get local model and update in background.")
  251. modelDownloader.getModel(
  252. name: testModelName,
  253. downloadType: downloadType,
  254. conditions: conditions,
  255. progressHandler: { progress in
  256. XCTFail("Model is already available on device.")
  257. }
  258. ) { result in
  259. switch result {
  260. case let .success(model):
  261. XCTAssertNotNil(model.path)
  262. guard let modelPath = URL(string: model.path) else {
  263. XCTFail("Invalid model path.")
  264. return
  265. }
  266. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
  267. /// Remove downloaded model file.
  268. do {
  269. try ModelFileManager.removeFile(at: modelPath)
  270. } catch {
  271. XCTFail("Model removal failed - \(error)")
  272. }
  273. case let .failure(error):
  274. XCTFail("Failed to download model - \(error)")
  275. }
  276. backgroundModelExpectation.fulfill()
  277. }
  278. waitForExpectations(timeout: 5, handler: nil)
  279. }
  280. /// Delete previously downloaded model.
  281. func testDeleteModel() {
  282. guard let testApp = FirebaseApp.app() else {
  283. XCTFail("Default app was not configured.")
  284. return
  285. }
  286. let testName = String(#function.dropLast(2))
  287. let testModelName = "pose-detection"
  288. let conditions = ModelDownloadConditions()
  289. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  290. .createTestInstance(testName: testName),
  291. app: testApp
  292. )
  293. let downloadType: ModelDownloadType = .latestModel
  294. let latestModelExpectation = expectation(description: "Get latest model for deletion.")
  295. modelDownloader.getModel(
  296. name: testModelName,
  297. downloadType: downloadType,
  298. conditions: conditions,
  299. progressHandler: { progress in
  300. XCTAssertLessThanOrEqual(progress, 1)
  301. XCTAssertGreaterThanOrEqual(progress, 0)
  302. }
  303. ) { result in
  304. switch result {
  305. case let .success(model):
  306. XCTAssertNotNil(model.path)
  307. guard let filePath = URL(string: model.path) else {
  308. XCTFail("Invalid model path.")
  309. return
  310. }
  311. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  312. case let .failure(error):
  313. XCTFail("Failed to download model - \(error)")
  314. }
  315. latestModelExpectation.fulfill()
  316. }
  317. waitForExpectations(timeout: 5, handler: nil)
  318. modelDownloader.deleteDownloadedModel(name: testModelName) { result in
  319. switch result {
  320. case .success: break
  321. case let .failure(error):
  322. XCTFail("Failed to delete model - \(error)")
  323. }
  324. }
  325. }
  326. /// Test listing models in model directory.
  327. func testListModels() {
  328. guard let testApp = FirebaseApp.app() else {
  329. XCTFail("Default app was not configured.")
  330. return
  331. }
  332. let testName = String(#function.dropLast(2))
  333. let testModelName = "pose-detection"
  334. let conditions = ModelDownloadConditions()
  335. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  336. .createTestInstance(testName: testName),
  337. app: testApp
  338. )
  339. let downloadType: ModelDownloadType = .latestModel
  340. let latestModelExpectation = expectation(description: "Get latest model.")
  341. modelDownloader.getModel(
  342. name: testModelName,
  343. downloadType: downloadType,
  344. conditions: conditions,
  345. progressHandler: { progress in
  346. XCTAssertLessThanOrEqual(progress, 1)
  347. XCTAssertGreaterThanOrEqual(progress, 0)
  348. }
  349. ) { result in
  350. switch result {
  351. case let .success(model):
  352. XCTAssertNotNil(model.path)
  353. guard let filePath = URL(string: model.path) else {
  354. XCTFail("Invalid model path.")
  355. return
  356. }
  357. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  358. case let .failure(error):
  359. XCTFail("Failed to download model - \(error)")
  360. }
  361. latestModelExpectation.fulfill()
  362. }
  363. waitForExpectations(timeout: 5, handler: nil)
  364. modelDownloader.listDownloadedModels { result in
  365. switch result {
  366. case let .success(models):
  367. XCTAssertGreaterThan(models.count, 0)
  368. case let .failure(error):
  369. XCTFail("Failed to list models - \(error)")
  370. }
  371. }
  372. }
  373. /// Test logging telemetry event.
  374. func testLogTelemetryEvent() {
  375. guard let testApp = FirebaseApp.app() else {
  376. XCTFail("Default app was not configured.")
  377. return
  378. }
  379. let testModelName = "image-classification"
  380. let testName = #function
  381. let conditions = ModelDownloadConditions()
  382. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  383. .createTestInstance(testName: testName),
  384. app: testApp
  385. )
  386. let latestModelExpectation = expectation(description: "Test telemetry.")
  387. modelDownloader.getModel(
  388. name: testModelName,
  389. downloadType: .latestModel,
  390. conditions: conditions
  391. ) { result in
  392. switch result {
  393. case let .success(model):
  394. guard let telemetryLogger = TelemetryLogger(app: testApp) else {
  395. XCTFail("Could not initialize logger.")
  396. return
  397. }
  398. // TODO: Remove actual logging and stub out with mocks.
  399. telemetryLogger.logModelDownloadEvent(
  400. eventName: .modelDownload,
  401. status: .succeeded,
  402. model: model,
  403. downloadErrorCode: .noError
  404. )
  405. let modelPath = URL(string: model.path)!
  406. try? ModelFileManager.removeFile(at: modelPath)
  407. case let .failure(error):
  408. XCTFail("Failed to download model - \(error)")
  409. }
  410. latestModelExpectation.fulfill()
  411. }
  412. waitForExpectations(timeout: 5, handler: nil)
  413. }
  414. func testGetModelWithConditions() {
  415. guard let testApp = FirebaseApp.app() else {
  416. XCTFail("Default app was not configured.")
  417. return
  418. }
  419. let testModelName = "pose-detection"
  420. let testName = #function
  421. // TODO: Figure out a better way to test this.
  422. var conditions = ModelDownloadConditions()
  423. conditions.allowsCellularAccess = false
  424. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  425. .createTestInstance(testName: testName),
  426. app: testApp
  427. )
  428. let latestModelExpectation = expectation(description: "Get latest model with conditions.")
  429. modelDownloader.getModel(
  430. name: testModelName,
  431. downloadType: .latestModel,
  432. conditions: conditions,
  433. progressHandler: { progress in
  434. XCTAssertLessThanOrEqual(progress, 1)
  435. XCTAssertGreaterThanOrEqual(progress, 0)
  436. }
  437. ) { result in
  438. switch result {
  439. case let .success(model):
  440. XCTAssertNotNil(model.path)
  441. guard let filePath = URL(string: model.path) else {
  442. XCTFail("Invalid model path.")
  443. return
  444. }
  445. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  446. case let .failure(error):
  447. XCTFail("Failed to download model - \(error)")
  448. }
  449. latestModelExpectation.fulfill()
  450. }
  451. waitForExpectations(timeout: 5, handler: nil)
  452. }
  453. }