// ModelDownloaderUnitTests.swift (4.6 KB)

  1. // Copyright 2020 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseMLModelDownloader
  17. /// Mock options to configure default Firebase app.
  18. private enum MockOptions {
  19. static let appID = "1:123:ios:123abc"
  20. static let gcmSenderID = "mock-sender-id"
  21. static let projectID = "mock-project-id"
  22. static let apiKey = "ABcdEf-APIKeyWithValidFormat_0123456789"
  23. }
  24. extension UserDefaults {
  25. /// For testing: returns a new cleared instance of user defaults.
  26. static func getTestInstance() -> UserDefaults {
  27. let suiteName = "com.google.firebase.ml.test"
  28. // TODO: reconsider force unwrapping
  29. let defaults = UserDefaults(suiteName: suiteName)!
  30. defaults.removePersistentDomain(forName: suiteName)
  31. return defaults
  32. }
  33. }
  34. final class ModelDownloaderUnitTests: XCTestCase {
  35. override class func setUp() {
  36. let options = FirebaseOptions(
  37. googleAppID: MockOptions.appID,
  38. gcmSenderID: MockOptions.gcmSenderID
  39. )
  40. options.apiKey = MockOptions.apiKey
  41. options.projectID = MockOptions.projectID
  42. FirebaseApp.configure(options: options)
  43. }
  44. /// Unit test for reading and writing to user defaults.
  45. func testUserDefaults() {
  46. guard let testApp = FirebaseApp.app() else {
  47. XCTFail("Default app was not configured.")
  48. return
  49. }
  50. let functionName = #function
  51. let testModelName = "\(functionName)-test-model"
  52. let modelInfo = ModelInfo(
  53. app: testApp,
  54. name: testModelName,
  55. defaults: .getTestInstance()
  56. )
  57. XCTAssertEqual(modelInfo.downloadURL, "")
  58. modelInfo.downloadURL = "testurl.com"
  59. XCTAssertEqual(modelInfo.downloadURL, "testurl.com")
  60. XCTAssertEqual(modelInfo.modelHash, "")
  61. XCTAssertEqual(modelInfo.size, 0)
  62. XCTAssertEqual(modelInfo.path, nil)
  63. }
  64. /// Test to download model info.
  65. // TODO: Add unit test with mocks.
  66. func testDownloadModelInfo() {}
  67. /// Unit test to save model info to user defaults.
  68. func testSaveModelInfo() {
  69. guard let testApp = FirebaseApp.app() else {
  70. XCTFail("Default app was not configured.")
  71. return
  72. }
  73. let functionName = #function
  74. let testModelName = "\(functionName)-test-model"
  75. let modelInfoRetriever = ModelInfoRetriever(
  76. app: testApp,
  77. modelName: testModelName
  78. )
  79. let sampleResponse: String = """
  80. {
  81. "downloadUri": "https://storage.googleapis.com",
  82. "expireTime": "2020-11-10T04:58:49.643Z",
  83. "sizeBytes": "562336"
  84. }
  85. """
  86. let data: Data = sampleResponse.data(using: .utf8)!
  87. modelInfoRetriever.saveModelInfo(data: data, modelHash: "test-model-hash")
  88. XCTAssertEqual(modelInfoRetriever.modelInfo?.downloadURL, "https://storage.googleapis.com")
  89. XCTAssertEqual(modelInfoRetriever.modelInfo?.size, 562_336)
  90. }
  91. func testExample() {
  92. // This is an example of a functional test case.
  93. // Use XCTAssert and related functions to verify your tests produce the correct
  94. // results.
  95. let modelDownloader = ModelDownloader()
  96. let conditions = ModelDownloadConditions()
  97. // Download model w/ progress handler
  98. modelDownloader.getModel(
  99. name: "your_model_name",
  100. downloadType: .latestModel,
  101. conditions: conditions,
  102. progressHandler: { progress in
  103. // Handle progress
  104. }
  105. ) { result in
  106. switch result {
  107. case .success:
  108. // Use model with your inference API
  109. // let interpreter = Interpreter(modelPath: customModel.modelPath)
  110. break
  111. case .failure:
  112. // Handle download error
  113. break
  114. }
  115. }
  116. // Access array of downloaded models
  117. modelDownloader.listDownloadedModels { result in
  118. switch result {
  119. case .success:
  120. // Pick model(s) for further use
  121. break
  122. case .failure:
  123. // Handle failure
  124. break
  125. }
  126. }
  127. // Delete downloaded model
  128. modelDownloader.deleteDownloadedModel(name: "your_model_name") { result in
  129. switch result {
  130. case .success():
  131. // Apply any other clean up
  132. break
  133. case .failure:
  134. // Handle failure
  135. break
  136. }
  137. }
  138. }
  139. }