GenerativeModelTestUtil.swift 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. // Copyright 2025 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import Foundation
  18. import XCTest
  19. @testable import FirebaseAI
  20. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  21. enum GenerativeModelTestUtil {
  22. /// Returns an HTTP request handler
  23. static func httpRequestHandler(forResource name: String,
  24. withExtension ext: String,
  25. subdirectory subpath: String,
  26. statusCode: Int = 200,
  27. timeout: TimeInterval = RequestOptions().timeout,
  28. appCheckToken: String? = nil,
  29. authToken: String? = nil,
  30. dataCollection: Bool = true,
  31. isTemplateRequest: Bool = false) throws
  32. -> ((URLRequest) throws -> (
  33. URLResponse,
  34. AsyncLineSequence<URL.AsyncBytes>?
  35. )) {
  36. // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
  37. // https://developer.apple.com/documentation/foundation/urlprotocol for details.
  38. #if os(watchOS)
  39. throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
  40. #else // os(watchOS)
  41. let bundle = BundleTestUtil.bundle()
  42. let fileURL = try XCTUnwrap(
  43. bundle.url(forResource: name, withExtension: ext, subdirectory: subpath)
  44. )
  45. return { request in
  46. let requestURL = try XCTUnwrap(request.url)
  47. if isTemplateRequest {
  48. XCTAssertEqual(
  49. requestURL.path.occurrenceCount(of: "templates/test-template.prompt:template"),
  50. 1
  51. )
  52. } else {
  53. XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
  54. }
  55. XCTAssertEqual(request.timeoutInterval, timeout)
  56. let apiClientTags = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
  57. .components(separatedBy: " ")
  58. XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag))
  59. XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag))
  60. XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
  61. let firebaseAppID = request.value(forHTTPHeaderField: "X-Firebase-AppId")
  62. let appVersion = request.value(forHTTPHeaderField: "X-Firebase-AppVersion")
  63. let expectedAppVersion =
  64. try? XCTUnwrap(Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String)
  65. XCTAssertEqual(firebaseAppID, dataCollection ? "My app ID" : nil)
  66. XCTAssertEqual(appVersion, dataCollection ? expectedAppVersion : nil)
  67. if let authToken {
  68. XCTAssertEqual(
  69. request.value(forHTTPHeaderField: "Authorization"),
  70. "Firebase \(authToken)"
  71. )
  72. } else {
  73. XCTAssertNil(request.value(forHTTPHeaderField: "Authorization"))
  74. }
  75. let response = try XCTUnwrap(HTTPURLResponse(
  76. url: requestURL,
  77. statusCode: statusCode,
  78. httpVersion: nil,
  79. headerFields: nil
  80. ))
  81. return (response, fileURL.lines)
  82. }
  83. #endif // os(watchOS)
  84. }
  85. static func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
  86. URLResponse,
  87. AsyncLineSequence<URL.AsyncBytes>?
  88. )) {
  89. // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
  90. // https://developer.apple.com/documentation/foundation/urlprotocol for details.
  91. #if os(watchOS)
  92. throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
  93. #else // os(watchOS)
  94. return { request in
  95. // This is *not* an HTTPURLResponse
  96. let response = URLResponse(
  97. url: request.url!,
  98. mimeType: nil,
  99. expectedContentLength: 0,
  100. textEncodingName: nil
  101. )
  102. return (response, nil)
  103. }
  104. #endif // os(watchOS)
  105. }
  106. static func testFirebaseInfo(appCheck: AppCheckInterop? = nil,
  107. auth: AuthInterop? = nil,
  108. privateAppID: Bool = false,
  109. useLimitedUseAppCheckTokens: Bool = false) -> FirebaseInfo {
  110. let app = FirebaseApp(instanceWithName: "testApp",
  111. options: FirebaseOptions(googleAppID: "ignore",
  112. gcmSenderID: "ignore"))
  113. app.isDataCollectionDefaultEnabled = !privateAppID
  114. return FirebaseInfo(
  115. appCheck: appCheck,
  116. auth: auth,
  117. projectID: "my-project-id",
  118. apiKey: "API_KEY",
  119. firebaseAppID: "My app ID",
  120. firebaseApp: app,
  121. useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens
  122. )
  123. }
  124. }
  125. private extension String {
  126. /// Returns the number of occurrences of `substring` in the `String`.
  127. func occurrenceCount(of substring: String) -> Int {
  128. return components(separatedBy: substring).count - 1
  129. }
  130. }