ModelDownloaderIntegrationTests.swift 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseInstallations
  17. @testable import FirebaseMLModelDownloader
  extension UserDefaults {
    /// Builds the suite name used by both `createTestInstance` and `getTestInstance`, so the two
    /// helpers always address the same persistent domain.
    private static func testSuiteName(for testName: String) -> String {
      return "com.google.firebase.ml.test.\(testName)"
    }

    /// Opens the user defaults suite called `suiteName`.
    ///
    /// `UserDefaults(suiteName:)` is failable. A nil result here means the test setup itself is
    /// broken, so this stops with a message naming the suite instead of a bare force-unwrap trap.
    private static func testSuite(named suiteName: String) -> UserDefaults {
      guard let defaults = UserDefaults(suiteName: suiteName) else {
        fatalError("Could not create UserDefaults suite '\(suiteName)'.")
      }
      return defaults
    }

    /// Returns a new cleared instance of user defaults.
    /// - Parameter testName: Identifier appended to the test suite name; use the same value with
    ///   `getTestInstance(testName:)` to reopen this suite later.
    /// - Returns: A defaults instance whose persistent domain has just been removed.
    static func createTestInstance(testName: String) -> UserDefaults {
      let suiteName = testSuiteName(for: testName)
      let defaults = testSuite(named: suiteName)
      // Drop anything persisted by earlier runs so the caller starts from an empty suite.
      defaults.removePersistentDomain(forName: suiteName)
      return defaults
    }

    /// Returns the existing user defaults instance.
    /// - Parameter testName: Identifier previously passed to `createTestInstance(testName:)`.
    /// - Returns: The suite for `testName`, with its current contents left untouched.
    static func getTestInstance(testName: String) -> UserDefaults {
      return testSuite(named: testSuiteName(for: testName))
    }
  }
  // TODO: Use FirebaseApp internal init for testApp
  /// Integration tests for `FirebaseMLModelDownloader`.
  ///
  /// These tests make live network calls, using the project described by the test bundle's
  /// `GoogleService-Info.plist`. They expect the models named below (`pose-detection`,
  /// `image-classification`) to be available to that project.
  final class ModelDownloaderIntegrationTests: XCTestCase {
    /// Constants shared by the tests in this class.
    private enum TestConstants {
      /// Seconds to wait for each network-backed callback.
      static let timeout: TimeInterval = 5
      /// Defaults suite shared by the `ModelDownloader`-level tests.
      static let sharedDefaultsName = "model-downloader-test"
    }

    /// Configures the default `FirebaseApp` once for the whole class, from the test bundle's
    /// `GoogleService-Info.plist`.
    override class func setUp() {
      super.setUp()
      // `FirebaseApp.configure(options:)` raises if the default app has already been configured
      // (for example by another test class in the same bundle), so only configure it when missing.
      guard FirebaseApp.app() == nil else { return }
      let bundle = Bundle(for: self)
      guard
        let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
        let options = FirebaseOptions(contentsOfFile: plistPath)
      else {
        XCTFail("Could not locate GoogleService-Info.plist.")
        return
      }
      FirebaseApp.configure(options: options)
    }

    /// Test to download model info - makes an actual network call.
    ///
    /// On success, the fetched info is persisted to a per-test defaults suite and read back. It is
    /// then used to check that a second fetch reports `.notModified`.
    func testDownloadModelInfo() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      let testModelName = "pose-detection"
      // Create (and clear) the suite before fetching. If the fetch fails or returns
      // `.notModified`, the read-back below then finds nothing, rather than info left behind by
      // an earlier run.
      let defaults = UserDefaults.createTestInstance(testName: #function)
      let modelInfoRetriever = ModelInfoRetriever(
        modelName: testModelName,
        options: testApp.options,
        installations: Installations.installations(app: testApp),
        appName: testApp.name
      )
      let downloadExpectation = expectation(description: "Wait for model info to download.")
      modelInfoRetriever.downloadModelInfo(completion: { result in
        defer { downloadExpectation.fulfill() }
        switch result {
        case let .success(modelInfoResult):
          switch modelInfoResult {
          case let .modelInfo(modelInfo):
            XCTAssertNotNil(modelInfo.urlExpiryTime)
            XCTAssertFalse(modelInfo.downloadURL.absoluteString.isEmpty)
            XCTAssertFalse(modelInfo.modelHash.isEmpty)
            XCTAssertGreaterThan(modelInfo.size, 0)
            let localModelInfo = LocalModelInfo(from: modelInfo, path: "mock-valid-path")
            localModelInfo.writeToDefaults(defaults, appName: testApp.name)
          case .notModified:
            XCTFail("Failed to retrieve model info.")
          }
        case let .failure(error):
          XCTFail("Failed to retrieve model info - \(error)")
        }
      })
      waitForExpectations(timeout: TestConstants.timeout, handler: nil)
      guard let localInfo = LocalModelInfo(
        fromDefaults: defaults,
        name: testModelName,
        appName: testApp.name
      ) else {
        XCTFail("Could not save model info locally.")
        return
      }
      testRetrieveModelInfo(localInfo: localInfo)
    }

    /// Re-fetches model info while passing in the locally stored copy, and expects `.notModified`.
    ///
    /// This method takes a parameter, so XCTest does not run it as a test on its own. It is called
    /// from `testDownloadModelInfo()`.
    /// - Parameter localInfo: Model info previously saved by `testDownloadModelInfo()`.
    func testRetrieveModelInfo(localInfo: LocalModelInfo) {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      let testModelName = "pose-detection"
      let modelInfoRetriever = ModelInfoRetriever(
        modelName: testModelName,
        options: testApp.options,
        installations: Installations.installations(app: testApp),
        appName: testApp.name,
        localModelInfo: localInfo
      )
      let retrieveExpectation = expectation(description: "Wait for model info to be retrieved.")
      modelInfoRetriever.downloadModelInfo(completion: { result in
        defer { retrieveExpectation.fulfill() }
        switch result {
        case let .success(modelInfoResult):
          switch modelInfoResult {
          case .modelInfo:
            XCTFail("Local model info is already the latest and should not be set again.")
          case .notModified:
            break
          }
        case let .failure(error):
          XCTFail("Failed to retrieve model info - \(error)")
        }
      })
      waitForExpectations(timeout: TestConstants.timeout, handler: nil)
    }

    /// Test to download model file - makes an actual network call.
    func testResumeModelDownload() throws {
      let testApp = try XCTUnwrap(FirebaseApp.app(), "Default app was not configured.")
      let functionName = #function.dropLast(2)
      let testModelName = "\(functionName)-test-model"
      let url = try XCTUnwrap(
        URL(string:
          "https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite")
      )
      let remoteModelInfo = RemoteModelInfo(
        name: testModelName,
        downloadURL: url,
        modelHash: "mock-valid-hash",
        size: 10,
        urlExpiryTime: Date()
      )
      let modelInfoRetriever = ModelInfoRetriever(
        modelName: testModelName,
        options: testApp.options,
        installations: Installations.installations(app: testApp),
        appName: testApp.name
      )
      let downloadExpectation = expectation(description: "Wait for model to download.")
      let downloadTask = ModelDownloadTask(
        remoteModelInfo: remoteModelInfo,
        appName: testApp.name,
        defaults: .createTestInstance(testName: #function),
        modelInfoRetriever: modelInfoRetriever,
        progressHandler: { progress in
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        // Fulfil on every path, including the early return below. A failure is then reported
        // directly instead of also surfacing as an expectation timeout.
        defer { downloadExpectation.fulfill() }
        switch result {
        case let .success(model):
          guard let modelPath = URL(string: model.path) else {
            XCTFail("Invalid or empty model path.")
            return
          }
          XCTAssertTrue(ModelFileManager.isFileReachable(at: modelPath))
          // Remove the downloaded model file so it does not linger after the test.
          do {
            try ModelFileManager.removeFile(at: modelPath)
          } catch {
            XCTFail("Model removal failed - \(error)")
          }
        case let .failure(error):
          XCTFail("Error: \(error)")
        }
      }
      downloadTask.resumeModelDownload()
      waitForExpectations(timeout: TestConstants.timeout, handler: nil)
      XCTAssertEqual(downloadTask.downloadStatus, .successful)
    }

    /// Exercises each `ModelDownloadType` against a freshly cleared defaults suite.
    ///
    /// The first call downloads the latest model. The later calls must be served from the device,
    /// so any progress callback during them is a failure.
    func testGetModel() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      let testModelName = "image-classification"
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: TestConstants.sharedDefaultsName),
        app: testApp
      )
      // Test download type - latest model.
      assertGetModelReturnsReachableFile(
        modelDownloader: modelDownloader,
        modelName: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        expectsProgress: true,
        description: "Get latest model."
      )
      // Test download type - local model.
      assertGetModelReturnsReachableFile(
        modelDownloader: modelDownloader,
        modelName: testModelName,
        downloadType: .localModel,
        conditions: conditions,
        expectsProgress: false,
        description: "Get local model."
      )
      // Test download type - local model update in background.
      assertGetModelReturnsReachableFile(
        modelDownloader: modelDownloader,
        modelName: testModelName,
        downloadType: .localModelUpdateInBackground,
        conditions: conditions,
        expectsProgress: false,
        description: "Get local model and update in background."
      )
    }

    /// Delete previously downloaded model.
    ///
    /// Downloads the model first, so the delete always has something to remove. It then waits for
    /// the delete callback, so a deletion error is reported by this test.
    func testDeleteModel() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      let testModelName = "pose-detection"
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .getTestInstance(testName: TestConstants.sharedDefaultsName),
        app: testApp
      )
      assertGetModelReturnsReachableFile(
        modelDownloader: modelDownloader,
        modelName: testModelName,
        downloadType: .latestModel,
        conditions: ModelDownloadConditions(),
        expectsProgress: true,
        description: "Get latest model."
      )
      let deleteExpectation = expectation(description: "Delete downloaded model.")
      modelDownloader.deleteDownloadedModel(name: testModelName) { result in
        defer { deleteExpectation.fulfill() }
        switch result {
        case .success:
          break
        case let .failure(error):
          XCTFail("Failed to delete model - \(error)")
        }
      }
      waitForExpectations(timeout: TestConstants.timeout, handler: nil)
    }

    /// Test listing models in model directory.
    ///
    /// This test does not download a model itself. It asserts that at least one is listed, so it
    /// relies on an earlier test (e.g. `testGetModel`) having left a model on disk.
    func testListModels() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .getTestInstance(testName: TestConstants.sharedDefaultsName),
        app: testApp
      )
      let listExpectation = expectation(description: "List downloaded models.")
      modelDownloader.listDownloadedModels { result in
        defer { listExpectation.fulfill() }
        switch result {
        case let .success(models):
          XCTAssertGreaterThan(models.count, 0)
        case let .failure(error):
          XCTFail("Failed to list models - \(error)")
        }
      }
      waitForExpectations(timeout: TestConstants.timeout, handler: nil)
    }

    /// Test logging telemetry event.
    func testLogTelemetryEvent() {
      guard let testApp = FirebaseApp.app() else {
        XCTFail("Default app was not configured.")
        return
      }
      let testModelName = "image-classification"
      let conditions = ModelDownloadConditions()
      let modelDownloader = ModelDownloader.modelDownloaderWithDefaults(
        .createTestInstance(testName: TestConstants.sharedDefaultsName),
        app: testApp
      )
      // Test download type - latest model.
      let latestModelExpectation = expectation(description: "Get latest model.")
      modelDownloader.getModel(
        name: testModelName,
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { progress in
          XCTAssertLessThanOrEqual(progress, 1)
          XCTAssertGreaterThanOrEqual(progress, 0)
        }
      ) { result in
        defer { latestModelExpectation.fulfill() }
        switch result {
        case let .success(model):
          guard let telemetryLogger = TelemetryLogger(app: testApp) else {
            XCTFail("Could not initialize logger.")
            return
          }
          // TODO: Remove this and stub out with mocks.
          telemetryLogger.logModelDownloadEvent(
            eventName: .modelDownload,
            status: .successful,
            model: model
          )
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
      }
      waitForExpectations(timeout: TestConstants.timeout, handler: nil)
    }

    /// Calls `getModel` and waits for it to finish, then asserts that the returned model path is
    /// non-empty and points at a reachable file.
    /// - Parameters:
    ///   - modelDownloader: Downloader under test.
    ///   - modelName: Name of the model to request.
    ///   - downloadType: Download strategy to exercise.
    ///   - conditions: Download conditions forwarded to `getModel`.
    ///   - expectsProgress: When `true`, each progress value must lie in `[0, 1]`. When `false`,
    ///     any progress callback is a failure, because the model should already be on device.
    ///   - description: Description of the expectation this helper waits on.
    private func assertGetModelReturnsReachableFile(
      modelDownloader: ModelDownloader,
      modelName: String,
      downloadType: ModelDownloadType,
      conditions: ModelDownloadConditions,
      expectsProgress: Bool,
      description: String
    ) {
      let completionExpectation = expectation(description: description)
      modelDownloader.getModel(
        name: modelName,
        downloadType: downloadType,
        conditions: conditions,
        progressHandler: { progress in
          if expectsProgress {
            XCTAssertLessThanOrEqual(progress, 1)
            XCTAssertGreaterThanOrEqual(progress, 0)
          } else {
            XCTFail("Model is already available on device.")
          }
        }
      ) { result in
        // Fulfil on every path so an early return does not turn into an expectation timeout.
        defer { completionExpectation.fulfill() }
        switch result {
        case let .success(model):
          // `model.path` is a non-optional String, so a nil check could never fail; check that
          // it is non-empty instead.
          XCTAssertFalse(model.path.isEmpty, "Model path is empty.")
          guard let filePath = URL(string: model.path) else {
            XCTFail("Invalid model path.")
            return
          }
          XCTAssertTrue(ModelFileManager.isFileReachable(at: filePath))
        case let .failure(error):
          XCTFail("Failed to download model - \(error)")
        }
      }
      waitForExpectations(timeout: TestConstants.timeout, handler: nil)
    }
  }