ModelDownloaderIntegrationTests.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseInstallations
  17. @testable import FirebaseMLModelDownloader
  18. extension UserDefaults {
  19. /// Returns a new cleared instance of user defaults.
  20. static func createTestInstance(testName: String) -> UserDefaults {
  21. let suiteName = "com.google.firebase.ml.test.\(testName)"
  22. // TODO: reconsider force unwrapping
  23. let defaults = UserDefaults(suiteName: suiteName)!
  24. defaults.removePersistentDomain(forName: suiteName)
  25. return defaults
  26. }
  27. /// Returns the existing user defaults instance.
  28. static func getTestInstance(testName: String) -> UserDefaults {
  29. let suiteName = "com.google.firebase.ml.test.\(testName)"
  30. // TODO: reconsider force unwrapping
  31. return UserDefaults(suiteName: suiteName)!
  32. }
  33. }
  34. // TODO: Use FirebaseApp internal init for testApp
  35. final class ModelDownloaderIntegrationTests: XCTestCase {
  36. override class func setUp() {
  37. super.setUp()
  38. let bundle = Bundle(for: self)
  39. if let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
  40. let options = FirebaseOptions(contentsOfFile: plistPath) {
  41. FirebaseApp.configure(options: options)
  42. } else {
  43. XCTFail("Could not locate GoogleService-Info.plist.")
  44. }
  45. }
  46. /// Test to download model info - makes an actual network call.
  47. func testDownloadModelInfo() {
  48. guard let testApp = FirebaseApp.app() else {
  49. XCTFail("Default app was not configured.")
  50. return
  51. }
  52. let testModelName = "pose-detection"
  53. let modelInfoRetriever = ModelInfoRetriever(
  54. modelName: testModelName,
  55. options: testApp.options,
  56. installations: Installations.installations(app: testApp),
  57. appName: testApp.name
  58. )
  59. let downloadExpectation = expectation(description: "Wait for model info to download.")
  60. modelInfoRetriever.downloadModelInfo(completion: { result in
  61. switch result {
  62. case let .success(modelInfoResult):
  63. switch modelInfoResult {
  64. case let .modelInfo(modelInfo):
  65. XCTAssertGreaterThan(modelInfo.downloadURL.absoluteString.count, 0)
  66. XCTAssertGreaterThan(modelInfo.modelHash.count, 0)
  67. XCTAssertGreaterThan(modelInfo.size, 0)
  68. let localModelInfo = LocalModelInfo(from: modelInfo, path: "mock-valid-path")
  69. localModelInfo.writeToDefaults(
  70. .createTestInstance(testName: #function),
  71. appName: testApp.name
  72. )
  73. case .notModified:
  74. XCTFail("Failed to retrieve model info.")
  75. }
  76. case let .failure(error):
  77. XCTAssertNotNil(error)
  78. XCTFail("Failed to retrieve model info - \(error)")
  79. }
  80. downloadExpectation.fulfill()
  81. })
  82. waitForExpectations(timeout: 5, handler: nil)
  83. if let localInfo = LocalModelInfo(
  84. fromDefaults: .getTestInstance(testName: #function),
  85. name: testModelName,
  86. appName: testApp.name
  87. ) {
  88. XCTAssertNotNil(localInfo)
  89. testRetrieveModelInfo(localInfo: localInfo)
  90. } else {
  91. XCTFail("Could not save model info locally.")
  92. }
  93. }
  94. func testRetrieveModelInfo(localInfo: LocalModelInfo) {
  95. guard let testApp = FirebaseApp.app() else {
  96. XCTFail("Default app was not configured.")
  97. return
  98. }
  99. let testModelName = "pose-detection"
  100. let modelInfoRetriever = ModelInfoRetriever(
  101. modelName: testModelName,
  102. options: testApp.options,
  103. installations: Installations.installations(app: testApp),
  104. appName: testApp.name,
  105. localModelInfo: localInfo
  106. )
  107. // TODO: This check seems to be flaky.
  108. let retrieveExpectation = expectation(description: "Wait for model info to be retrieved.")
  109. modelInfoRetriever.downloadModelInfo(completion: { result in
  110. switch result {
  111. case let .success(modelInfoResult):
  112. switch modelInfoResult {
  113. case .modelInfo:
  114. XCTFail("Local model info is already the latest and should not be set again.")
  115. case .notModified: break
  116. }
  117. case let .failure(error):
  118. XCTAssertNotNil(error)
  119. XCTFail("Failed to retrieve model info - \(error)")
  120. }
  121. retrieveExpectation.fulfill()
  122. })
  123. waitForExpectations(timeout: 5, handler: nil)
  124. }
  125. /// Test to download model file - makes an actual network call.
  126. func testResumeModelDownload() throws {
  127. guard let testApp = FirebaseApp.app() else {
  128. XCTFail("Default app was not configured.")
  129. return
  130. }
  131. let functionName = #function.dropLast(2)
  132. let testModelName = "\(functionName)-test-model"
  133. let urlString =
  134. "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite"
  135. let url = URL(string: urlString)!
  136. let remoteModelInfo = RemoteModelInfo(
  137. name: testModelName,
  138. downloadURL: url,
  139. modelHash: "mock-valid-hash",
  140. size: 10
  141. )
  142. let expectation = self.expectation(description: "Wait for model to download.")
  143. let modelDownloadManager = ModelDownloadTask(
  144. remoteModelInfo: remoteModelInfo,
  145. appName: testApp.name,
  146. defaults: .createTestInstance(testName: #function),
  147. progressHandler: { progress in
  148. XCTAssertLessThanOrEqual(progress, 1)
  149. XCTAssertGreaterThanOrEqual(progress, 0)
  150. }
  151. ) { result in
  152. switch result {
  153. case let .success(model):
  154. guard let modelPath = URL(string: model.path) else {
  155. XCTFail("Invalid or empty model path.")
  156. return
  157. }
  158. XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
  159. /// Remove downloaded model file.
  160. do {
  161. try ModelFileManager.removeFile(at: modelPath)
  162. } catch {
  163. XCTFail("Model removal failed - \(error)")
  164. }
  165. case let .failure(error):
  166. XCTFail("Error: \(error)")
  167. }
  168. expectation.fulfill()
  169. }
  170. modelDownloadManager.resumeModelDownload()
  171. waitForExpectations(timeout: 5, handler: nil)
  172. XCTAssertEqual(modelDownloadManager.downloadStatus, .successful)
  173. }
  174. func testGetModel() {
  175. guard let testApp = FirebaseApp.app() else {
  176. XCTFail("Default app was not configured.")
  177. return
  178. }
  179. let testName = "model-downloader-test"
  180. let testModelName = "image-classification"
  181. let conditions = ModelDownloadConditions()
  182. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  183. .createTestInstance(testName: testName),
  184. app: testApp
  185. )
  186. /// Test download type - latest model.
  187. var downloadType: ModelDownloadType = .latestModel
  188. let latestModelExpectation = expectation(description: "Get latest model.")
  189. modelDownloader.getModel(
  190. name: testModelName,
  191. downloadType: downloadType,
  192. conditions: conditions,
  193. progressHandler: { progress in
  194. XCTAssertLessThanOrEqual(progress, 1)
  195. XCTAssertGreaterThanOrEqual(progress, 0)
  196. }
  197. ) { result in
  198. switch result {
  199. case let .success(model):
  200. XCTAssertNotNil(model.path)
  201. guard let filePath = URL(string: model.path) else {
  202. XCTFail("Invalid model path.")
  203. return
  204. }
  205. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  206. case let .failure(error):
  207. XCTFail("Failed to download model - \(error)")
  208. }
  209. latestModelExpectation.fulfill()
  210. }
  211. waitForExpectations(timeout: 5, handler: nil)
  212. /// Test download type - local model.
  213. downloadType = .localModel
  214. let localModelExpectation = expectation(description: "Get local model.")
  215. modelDownloader.getModel(
  216. name: testModelName,
  217. downloadType: downloadType,
  218. conditions: conditions,
  219. progressHandler: { progress in
  220. XCTFail("Model is already available on device.")
  221. }
  222. ) { result in
  223. switch result {
  224. case let .success(model):
  225. XCTAssertNotNil(model.path)
  226. guard let filePath = URL(string: model.path) else {
  227. XCTFail("Invalid model path.")
  228. return
  229. }
  230. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  231. case let .failure(error):
  232. XCTFail("Failed to download model - \(error)")
  233. }
  234. localModelExpectation.fulfill()
  235. }
  236. waitForExpectations(timeout: 5, handler: nil)
  237. /// Test download type - local model update in background.
  238. downloadType = .localModelUpdateInBackground
  239. let backgroundModelExpectation =
  240. expectation(description: "Get local model and update in background.")
  241. modelDownloader.getModel(
  242. name: testModelName,
  243. downloadType: downloadType,
  244. conditions: conditions,
  245. progressHandler: { progress in
  246. XCTFail("Model is already available on device.")
  247. }
  248. ) { result in
  249. switch result {
  250. case let .success(model):
  251. XCTAssertNotNil(model.path)
  252. guard let filePath = URL(string: model.path) else {
  253. XCTFail("Invalid model path.")
  254. return
  255. }
  256. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  257. case let .failure(error):
  258. XCTFail("Failed to download model - \(error)")
  259. }
  260. backgroundModelExpectation.fulfill()
  261. }
  262. waitForExpectations(timeout: 5, handler: nil)
  263. }
  264. /// Delete previously downloaded model.
  265. func testDeleteModel() {
  266. guard let testApp = FirebaseApp.app() else {
  267. XCTFail("Default app was not configured.")
  268. return
  269. }
  270. let testName = "model-downloader-test"
  271. let testModelName = "pose-detection"
  272. let conditions = ModelDownloadConditions()
  273. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  274. .getTestInstance(testName: testName),
  275. app: testApp
  276. )
  277. let downloadType: ModelDownloadType = .latestModel
  278. let latestModelExpectation = expectation(description: "Get latest model.")
  279. modelDownloader.getModel(
  280. name: testModelName,
  281. downloadType: downloadType,
  282. conditions: conditions,
  283. progressHandler: { progress in
  284. XCTAssertLessThanOrEqual(progress, 1)
  285. XCTAssertGreaterThanOrEqual(progress, 0)
  286. }
  287. ) { result in
  288. switch result {
  289. case let .success(model):
  290. XCTAssertNotNil(model.path)
  291. guard let filePath = URL(string: model.path) else {
  292. XCTFail("Invalid model path.")
  293. return
  294. }
  295. XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
  296. case let .failure(error):
  297. XCTFail("Failed to download model - \(error)")
  298. }
  299. latestModelExpectation.fulfill()
  300. }
  301. waitForExpectations(timeout: 5, handler: nil)
  302. modelDownloader.deleteDownloadedModel(name: testModelName) { result in
  303. switch result {
  304. case .success: break
  305. case let .failure(error):
  306. XCTFail("Failed to delete model - \(error)")
  307. }
  308. }
  309. }
  310. /// Test listing models in model directory.
  311. func testListModels() {
  312. guard let testApp = FirebaseApp.app() else {
  313. XCTFail("Default app was not configured.")
  314. return
  315. }
  316. let testName = "model-downloader-test"
  317. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  318. .getTestInstance(testName: testName),
  319. app: testApp
  320. )
  321. modelDownloader.listDownloadedModels { result in
  322. switch result {
  323. case let .success(models):
  324. XCTAssertGreaterThan(models.count, 0)
  325. case let .failure(error):
  326. XCTFail("Failed to list models - \(error)")
  327. }
  328. }
  329. }
  330. /// Test logging telemetry event.
  331. func testLogTelemetryEvent() {
  332. guard let testApp = FirebaseApp.app() else {
  333. XCTFail("Default app was not configured.")
  334. return
  335. }
  336. let testModelName = "image-classification"
  337. let testName = "model-downloader-test"
  338. let conditions = ModelDownloadConditions()
  339. let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
  340. .createTestInstance(testName: testName),
  341. app: testApp
  342. )
  343. /// Test download type - latest model.
  344. let downloadType: ModelDownloadType = .latestModel
  345. let latestModelExpectation = expectation(description: "Get latest model.")
  346. modelDownloader.getModel(
  347. name: testModelName,
  348. downloadType: downloadType,
  349. conditions: conditions,
  350. progressHandler: { progress in
  351. XCTAssertLessThanOrEqual(progress, 1)
  352. XCTAssertGreaterThanOrEqual(progress, 0)
  353. }
  354. ) { result in
  355. switch result {
  356. case let .success(model):
  357. guard let telemetryLogger = TelemetryLogger(app: testApp) else {
  358. XCTFail("Could not initialize logger.")
  359. return
  360. }
  361. // TODO: Remove this and stub out with mocks.
  362. telemetryLogger.logModelDownloadEvent(
  363. eventName: .modelDownload,
  364. status: .successful,
  365. model: model
  366. )
  367. case let .failure(error):
  368. XCTFail("Failed to download model - \(error)")
  369. }
  370. latestModelExpectation.fulfill()
  371. }
  372. waitForExpectations(timeout: 5, handler: nil)
  373. }
  374. }