ModelDownloaderTests.swift 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseMLModelDownloader
  /// Fixed fixture values shared by the tests in this file.
  enum Constants {
    /// Names and dictionary-key strings associated with `FirebaseApp`.
    enum App {
      static let defaultName: String = "__FIRAPP_DEFAULT"
      static let nameKey: String = "FIRAppNameKey"
      static let googleAppIDKey: String = "FIRGoogleAppIDKey"
      static let isDefaultAppKey: String = "FIRAppIsDefaultAppKey"
    }

    /// Placeholder values used to populate `FirebaseOptions` in tests.
    enum Options {
      // Identifiers.
      static let googleAppID: String = "correct_app_id"
      static let gcmSenderID: String = "correct_gcm_sender_id"
      static let projectID: String = "correct_project_id"
      static let bundleID: String = "com.google.FirebaseSDKTests"
      static let appGroupID: String? = nil

      // Credentials / client IDs.
      static let apiKey: String = "correct_api_key"
      static let clientID: String = "correct_client_id"
      static let androidClientID: String = "correct_android_client_id"
      static let trackingID: String = "correct_tracking_id"

      // Service endpoints.
      static let databaseURL: String = "https://abc-xyz-123.firebaseio.com"
      static let storageBucket: String = "project-id-123.storage.firebase.com"
      static let deepLinkURLScheme: String = "comgoogledeeplinkurl"
    }
  }
  extension UserDefaults {
    /// For testing: returns a `UserDefaults` instance for `suiteName` after removing any
    /// values previously persisted in that suite.
    ///
    /// - Parameter suiteName: The defaults suite to use. Defaults to the shared ML test
    ///   suite, so existing callers are unaffected; pass a distinct name to isolate a test.
    /// - Returns: A cleared `UserDefaults` instance backed by `suiteName`.
    static func getTestInstance(suiteName: String = "com.google.firebase.ml.test") -> UserDefaults {
      // `UserDefaults(suiteName:)` is failable. A nil result here is a test-setup bug, so fail
      // with a descriptive message rather than an anonymous force-unwrap crash.
      guard let defaults = UserDefaults(suiteName: suiteName) else {
        fatalError("Unable to create UserDefaults for test suite '\(suiteName)'.")
      }
      defaults.removePersistentDomain(forName: suiteName)
      return defaults
    }
  }
  /// Tests for model-info persistence and retrieval in `FirebaseMLModelDownloader`.
  final class ModelDownloaderTests: XCTestCase {
    /// Configures the default `FirebaseApp` once for all tests in this class.
    override class func setUp() {
      super.setUp()
      // NOTE(review): these options are built but not yet passed to `FirebaseApp.configure`;
      // the default (plist-based) configuration is still used below.
      // TODO: Replace with custom options.
      let options = FirebaseOptions(googleAppID: Constants.Options.googleAppID,
                                    gcmSenderID: Constants.Options.gcmSenderID)
      options.apiKey = Constants.Options.apiKey
      options.projectID = Constants.Options.projectID
      options.clientID = Constants.Options.clientID
      // Configuring the default app a second time raises, so only configure it if no other
      // test class has done so already.
      if FirebaseApp.app() == nil {
        FirebaseApp.configure()
      }
    }

    /// Checks that a `ModelInfo` backed by cleared test defaults starts with empty values,
    /// and that `downloadURL` can be written and read back.
    func testUserDefaults() throws {
      let testApp = try XCTUnwrap(FirebaseApp.app())
      let testModelName = "\(#function)-test-model"
      let modelInfoRetriever = ModelInfoRetriever(
        app: testApp,
        projectID: Constants.Options.projectID,
        modelName: testModelName
      )
      modelInfoRetriever.modelInfo = ModelInfo(
        app: testApp,
        name: testModelName,
        defaults: .getTestInstance()
      )

      XCTAssertEqual(modelInfoRetriever.modelInfo?.downloadURL, "")
      modelInfoRetriever.modelInfo?.downloadURL = "testurl.com"
      XCTAssertEqual(modelInfoRetriever.modelInfo?.downloadURL, "testurl.com")
      XCTAssertEqual(modelInfoRetriever.modelInfo?.modelHash, "")
      XCTAssertEqual(modelInfoRetriever.modelInfo?.size, 0)
      XCTAssertNil(modelInfoRetriever.modelInfo?.path)
    }

    /// Requests model info for a test-only model name and expects a `.notFound` error.
    ///
    /// NOTE(review): this appears to contact the real backend (10s timeout) — confirm whether
    /// a stubbed transport should be injected instead.
    func testDownloadModelInfo() throws {
      let testApp = try XCTUnwrap(FirebaseApp.app())
      let testModelName = "\(#function)-test-model"
      let modelInfoRetriever = ModelInfoRetriever(
        app: testApp,
        projectID: Constants.Options.projectID,
        modelName: testModelName
      )

      let expectation = self.expectation(description: "Wait for model info to download.")
      modelInfoRetriever.downloadModelInfo(completion: { error in
        // Fulfill on every path so an unexpected success fails immediately with a clear
        // message instead of surfacing as an opaque expectation timeout.
        defer { expectation.fulfill() }
        guard let downloadError = error else {
          XCTFail("Expected a .notFound error, but model info download succeeded.")
          return
        }
        XCTAssertEqual(downloadError, .notFound)
      })
      waitForExpectations(timeout: 10, handler: nil)
    }

    /// Usage sample for the public `ModelDownloader` API.
    ///
    /// NOTE(review): none of the completion handlers are awaited or asserted on, so this test
    /// only checks that the calls can be made without crashing synchronously.
    func testExample() {
      let modelDownloader = ModelDownloader()
      let conditions = ModelDownloadConditions()

      // Download model w/ progress handler.
      modelDownloader.getModel(
        name: "your_model_name",
        downloadType: .latestModel,
        conditions: conditions,
        progressHandler: { _ in
          // Handle progress.
        }
      ) { result in
        switch result {
        case .success:
          // Use model with your inference API, e.g.
          // let interpreter = Interpreter(modelPath: customModel.modelPath)
          break
        case .failure:
          // Handle download error.
          break
        }
      }

      // Access array of downloaded models.
      modelDownloader.listDownloadedModels { result in
        switch result {
        case .success:
          // Pick model(s) for further use.
          break
        case .failure:
          // Handle failure.
          break
        }
      }

      // Delete downloaded model.
      modelDownloader.deleteDownloadedModel(name: "your_model_name") { result in
        switch result {
        case .success:
          // Apply any other clean up.
          break
        case .failure:
          // Handle failure.
          break
        }
      }
    }
  }