GenerativeModelTestUtil.swift 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. // Copyright 2025 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import Foundation
  18. import XCTest
  19. @testable import FirebaseAILogic
  20. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  21. enum GenerativeModelTestUtil {
  22. /// Returns an HTTP request handler
  23. static func httpRequestHandler(forResource name: String,
  24. withExtension ext: String,
  25. subdirectory subpath: String,
  26. statusCode: Int = 200,
  27. timeout: TimeInterval = RequestOptions().timeout,
  28. appCheckToken: String? = nil,
  29. authToken: String? = nil,
  30. dataCollection: Bool = true,
  31. isTemplateRequest: Bool = false) throws
  32. -> ((URLRequest) throws -> (
  33. URLResponse,
  34. AsyncLineSequence<URL.AsyncBytes>?
  35. )) {
  36. // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
  37. // https://developer.apple.com/documentation/foundation/urlprotocol for details.
  38. #if os(watchOS)
  39. throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
  40. #else // os(watchOS)
  41. let bundle = BundleTestUtil.bundle()
  42. let fileURL = try XCTUnwrap(
  43. bundle.url(forResource: name, withExtension: ext, subdirectory: subpath)
  44. )
  45. return { request in
  46. let requestURL = try XCTUnwrap(request.url)
  47. if isTemplateRequest {
  48. XCTAssertEqual(
  49. requestURL.path.occurrenceCount(of: "templates/test-template:template"),
  50. 1
  51. )
  52. } else {
  53. XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
  54. }
  55. XCTAssertEqual(request.timeoutInterval, timeout)
  56. let apiClientTags = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
  57. .components(separatedBy: " ")
  58. XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag))
  59. XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag))
  60. XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
  61. let firebaseAppID = request.value(forHTTPHeaderField: "X-Firebase-AppId")
  62. let appVersion = request.value(forHTTPHeaderField: "X-Firebase-AppVersion")
  63. let expectedAppVersion =
  64. try? XCTUnwrap(Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String)
  65. XCTAssertEqual(firebaseAppID, dataCollection ? "My app ID" : nil)
  66. XCTAssertEqual(appVersion, dataCollection ? expectedAppVersion : nil)
  67. if let authToken {
  68. XCTAssertEqual(
  69. request.value(forHTTPHeaderField: "Authorization"),
  70. "Firebase \(authToken)"
  71. )
  72. } else {
  73. XCTAssertNil(request.value(forHTTPHeaderField: "Authorization"))
  74. }
  75. let response = try XCTUnwrap(HTTPURLResponse(
  76. url: requestURL,
  77. statusCode: statusCode,
  78. httpVersion: nil,
  79. headerFields: nil
  80. ))
  81. return (response, fileURL.lines)
  82. }
  83. #endif // os(watchOS)
  84. }
  85. static func collectTextFromStream(_ stream: AsyncThrowingStream<
  86. GenerateContentResponse,
  87. Error
  88. >) async throws -> String {
  89. var content = ""
  90. for try await response in stream {
  91. if let text = response.text {
  92. content += text
  93. }
  94. }
  95. return content
  96. }
  97. static func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
  98. URLResponse,
  99. AsyncLineSequence<URL.AsyncBytes>?
  100. )) {
  101. // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
  102. // https://developer.apple.com/documentation/foundation/urlprotocol for details.
  103. #if os(watchOS)
  104. throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
  105. #else // os(watchOS)
  106. return { request in
  107. // This is *not* an HTTPURLResponse
  108. let response = URLResponse(
  109. url: request.url!,
  110. mimeType: nil,
  111. expectedContentLength: 0,
  112. textEncodingName: nil
  113. )
  114. return (response, nil)
  115. }
  116. #endif // os(watchOS)
  117. }
  118. static func testFirebaseInfo(appCheck: AppCheckInterop? = nil,
  119. auth: AuthInterop? = nil,
  120. privateAppID: Bool = false,
  121. useLimitedUseAppCheckTokens: Bool = false) -> FirebaseInfo {
  122. let app = FirebaseApp(instanceWithName: "testApp",
  123. options: FirebaseOptions(googleAppID: "ignore",
  124. gcmSenderID: "ignore"))
  125. app.isDataCollectionDefaultEnabled = !privateAppID
  126. return FirebaseInfo(
  127. appCheck: appCheck,
  128. auth: auth,
  129. projectID: "my-project-id",
  130. apiKey: "API_KEY",
  131. firebaseAppID: "My app ID",
  132. firebaseApp: app,
  133. useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens
  134. )
  135. }
  136. }
  137. private extension String {
  138. /// Returns the number of occurrences of `substring` in the `String`.
  139. func occurrenceCount(of substring: String) -> Int {
  140. return components(separatedBy: substring).count - 1
  141. }
  142. }