ModelDownloaderIntegrationTests.swift 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. // Copyright 2020 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import XCTest
  15. @testable import FirebaseCore
  16. @testable import FirebaseMLModelDownloader
  extension UserDefaults {
    /// For testing: returns a user defaults instance for `suiteName` whose persistent
    /// domain has just been cleared.
    ///
    /// Every call with the same suite name is backed by the same persistent domain, so
    /// each call wipes whatever a previous caller stored there.
    ///
    /// - Parameter suiteName: Name of the defaults suite to use. Defaults to the shared
    ///   ML test suite, so existing `.getTestInstance()` call sites are unaffected.
    /// - Returns: A `UserDefaults` for `suiteName` with an empty persistent domain.
    static func getTestInstance(suiteName: String = "com.google.firebase.ml.test") -> UserDefaults {
      // Per Apple's documentation, `init(suiteName:)` only returns nil for reserved names
      // (the app's bundle identifier or the global domain). Failing here is therefore a
      // misconfigured test, so crash with a message instead of a bare force unwrap.
      guard let defaults = UserDefaults(suiteName: suiteName) else {
        fatalError("Could not create UserDefaults for test suite \(suiteName).")
      }
      defaults.removePersistentDomain(forName: suiteName)
      return defaults
    }
  }
  final class ModelDownloaderIntegrationTests: XCTestCase {
    /// Timeout, in seconds, for every expectation that waits on a real network call.
    private static let networkTimeout: TimeInterval = 5

    /// Configures the default Firebase app once for the whole test class, using the
    /// `GoogleService-Info.plist` bundled with the test target.
    override class func setUp() {
      super.setUp()
      // Another test class in the same process may already have configured the default
      // app. FirebaseCore raises on a second default-app configuration, so skip it.
      guard FirebaseApp.app() == nil else { return }
      let bundle = Bundle(for: self)
      guard let plistPath = bundle.path(forResource: "GoogleService-Info", ofType: "plist"),
        let options = FirebaseOptions(contentsOfFile: plistPath) else {
        XCTFail("Could not locate GoogleService-Info.plist.")
        return
      }
      FirebaseApp.configure(options: options)
    }

    /// Test to retrieve FIS token - makes an actual network call.
    func testGetAuthToken() throws {
      let testApp = try XCTUnwrap(FirebaseApp.app(), "Default app was not configured.")
      let modelInfoRetriever = ModelInfoRetriever(
        app: testApp,
        modelName: "image-classification",
        defaults: .getTestInstance()
      )
      let tokenExpectation = expectation(description: "Wait for FIS auth token.")
      modelInfoRetriever.getAuthToken(completion: { result in
        // Fulfill on every exit path so a failure is reported without a timeout.
        defer { tokenExpectation.fulfill() }
        switch result {
        case let .success(token):
          XCTAssertNotNil(token)
        case let .failure(error):
          XCTFail(error.localizedDescription)
        }
      })
      waitForExpectations(timeout: ModelDownloaderIntegrationTests.networkTimeout, handler: nil)
    }

    /// Test to download model info - makes actual network calls.
    ///
    /// Requests model info twice through the same retriever. Both requests must succeed
    /// and leave non-empty model info on the retriever.
    func testDownloadModelInfo() throws {
      let testApp = try XCTUnwrap(FirebaseApp.app(), "Default app was not configured.")
      let modelInfoRetriever = ModelInfoRetriever(
        app: testApp,
        modelName: "pose-detection",
        defaults: .getTestInstance()
      )
      downloadAndVerifyModelInfo(
        with: modelInfoRetriever,
        description: "Wait for model info to download."
      )
      downloadAndVerifyModelInfo(
        with: modelInfoRetriever,
        description: "Wait for model info to be retrieved."
      )
    }

    /// Calls `downloadModelInfo` on `retriever`, waits for it to complete, and asserts
    /// that it reported no error and populated `retriever.modelInfo` with a non-empty
    /// download URL, a non-empty hash, and a positive size.
    ///
    /// - Parameters:
    ///   - retriever: The retriever to drive.
    ///   - description: Description for the expectation that is waited on.
    ///   - file: Call-site file that assertion failures are attributed to.
    ///   - line: Call-site line that assertion failures are attributed to.
    private func downloadAndVerifyModelInfo(
      with retriever: ModelInfoRetriever,
      description: String,
      file: StaticString = #file,
      line: UInt = #line
    ) {
      let completed = expectation(description: description)
      retriever.downloadModelInfo(completion: { error in
        // Fulfill even on the early-return path below; otherwise an empty model info
        // would additionally stall the test until the wait times out.
        defer { completed.fulfill() }
        XCTAssertNil(error, file: file, line: line)
        guard let modelInfo = retriever.modelInfo else {
          XCTFail("Empty model info.", file: file, line: line)
          return
        }
        XCTAssertNotEqual(modelInfo.downloadURL, "", file: file, line: line)
        XCTAssertNotEqual(modelInfo.modelHash, "", file: file, line: line)
        XCTAssertGreaterThan(modelInfo.size, 0, file: file, line: line)
      })
      waitForExpectations(timeout: ModelDownloaderIntegrationTests.networkTimeout, handler: nil)
    }
  }