ChatTests.swift 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseCore
  15. import Foundation
  16. import XCTest
  17. @testable import FirebaseAILogic
  /// Unit tests for `Chat`, checking how `sendMessage`, `sendMessageStream`, and
  /// `startChat(history:)` populate `Chat.history`.
  ///
  /// Every `GenerativeModel` under test is given a `URLSession` whose protocol classes are
  /// replaced with `MockURLProtocol`, so responses come from the handler each test installs.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  final class ChatTests: XCTestCase {
    let modelName = "test-model-name"
    let modelResourceName = "projects/my-project/locations/us-central1/models/test-model-name"

    /// Session routed through `MockURLProtocol`; created in `setUp()`, released in `tearDown()`.
    var urlSession: URLSession!

    override func setUp() {
      super.setUp()
      let configuration = URLSessionConfiguration.default
      configuration.protocolClasses = [MockURLProtocol.self]
      urlSession = URLSession(configuration: configuration)
    }

    override func tearDown() {
      // Clear the shared handler so one test's canned response cannot leak into the next.
      MockURLProtocol.requestHandler = nil
      urlSession = nil
      super.tearDown()
    }

    // MARK: - Helpers

    /// Builds a `GenerativeModel` that uses the test model names and the mocked `urlSession`.
    ///
    /// - Parameter firebaseInfo: Firebase configuration for the model; defaults to
    ///   `GenerativeModelTestUtil.testFirebaseInfo()`.
    /// - Returns: A model configured with `FirebaseAI.defaultVertexAIAPIConfig`, no tools, and
    ///   default `RequestOptions`.
    private func makeModel(
      firebaseInfo: FirebaseInfo = GenerativeModelTestUtil.testFirebaseInfo()
    ) -> GenerativeModel {
      GenerativeModel(
        modelName: modelName,
        modelResourceName: modelResourceName,
        firebaseInfo: firebaseInfo,
        apiConfig: FirebaseAI.defaultVertexAIAPIConfig,
        tools: nil,
        requestOptions: RequestOptions(),
        urlSession: urlSession
      )
    }

    // MARK: - Tests

    /// Streams a reply from the `streaming-success-basic-reply-parts` fixture and checks that
    /// the history ends up with exactly the user input followed by one model entry whose
    /// content equals `"1 2 3 4 5 6 7 8"`.
    func testMergingText() async throws {
      // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
      // https://developer.apple.com/documentation/foundation/urlprotocol for details.
      #if os(watchOS)
        throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
      #else // os(watchOS)
        let fileURL = try XCTUnwrap(BundleTestUtil.bundle().url(
          forResource: "streaming-success-basic-reply-parts",
          withExtension: "txt",
          subdirectory: "mock-responses/vertexai"
        ))

        // Serve the fixture line by line with a 200 response for whatever URL is requested.
        MockURLProtocol.requestHandler = { request in
          let response = HTTPURLResponse(
            url: request.url!,
            statusCode: 200,
            httpVersion: nil,
            headerFields: nil
          )!
          return (response, fileURL.lines)
        }

        let app = FirebaseApp(instanceWithName: "testApp",
                              options: FirebaseOptions(googleAppID: "ignore",
                                                       gcmSenderID: "ignore"))
        let model = makeModel(firebaseInfo: FirebaseInfo(
          projectID: "my-project-id",
          apiKey: "API_KEY",
          firebaseAppID: "My app ID",
          firebaseApp: app,
          useLimitedUseAppCheckTokens: false
        ))
        let chat = Chat(model: model, history: [])
        let input = "Test input"
        let stream = try chat.sendMessageStream(input)

        // Ensure the values are parsed correctly
        for try await value in stream {
          XCTAssertNotNil(value.text)
        }

        XCTAssertEqual(chat.history.count, 2)

        // Use `first`/`last` rather than subscripts: `XCTAssertEqual` above does not stop the
        // test, so an unexpectedly short history must fail here instead of trapping.
        let userContent = try XCTUnwrap(chat.history.first)
        let textPart = try XCTUnwrap(userContent.parts.first as? TextPart)
        XCTAssertEqual(textPart.text, input)

        let finalText = "1 2 3 4 5 6 7 8"
        let assembledExpectation = ModelContent(role: "model", parts: finalText)
        XCTAssertEqual(chat.history.last, assembledExpectation)
      #endif // os(watchOS)
    }

    /// Sends one unary message and checks that the history goes from empty to a `"user"` entry
    /// holding the input text followed by a `"model"` entry holding a non-empty text part.
    func testSendMessage_unary_appendsHistory() async throws {
      let expectedInput = "Test input"
      // NOTE(review): the fixture lives under `mock-responses/googleai` while the model is built
      // with the Vertex AI API config; confirm this pairing is intentional.
      MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-success-basic-reply-short",
        withExtension: "json",
        subdirectory: "mock-responses/googleai"
      )
      let chat = makeModel().startChat()

      // Pre-condition: History should be empty.
      XCTAssertTrue(chat.history.isEmpty)

      let response = try await chat.sendMessage(expectedInput)

      // `XCTUnwrap` already records a failure when `text` is nil.
      let text = try XCTUnwrap(response.text)
      XCTAssertFalse(text.isEmpty)

      // Post-condition: History should have the user's message and the model's response.
      XCTAssertEqual(chat.history.count, 2)

      let userInput = try XCTUnwrap(chat.history.first)
      XCTAssertEqual(userInput.role, "user")
      XCTAssertEqual(userInput.parts.count, 1)
      let userInputText = try XCTUnwrap(userInput.parts.first as? TextPart)
      XCTAssertEqual(userInputText.text, expectedInput)

      let modelResponse = try XCTUnwrap(chat.history.last)
      XCTAssertEqual(modelResponse.role, "model")
      XCTAssertEqual(modelResponse.parts.count, 1)
      let modelResponseText = try XCTUnwrap(modelResponse.parts.first as? TextPart)
      XCTAssertFalse(modelResponseText.text.isEmpty)
    }

    /// Consumes a stream backed by the `streaming-failure-finish-reason-safety` fixture, expects
    /// `GenerateContentError.responseStoppedEarly` with reason `.safety`, and checks that the
    /// history is still empty afterwards.
    func testSendMessageStream_error_doesNotAppendHistory() async throws {
      MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "streaming-failure-finish-reason-safety",
        withExtension: "txt",
        subdirectory: "mock-responses/vertexai"
      )
      let chat = makeModel().startChat()
      let input = "Test input"

      // Pre-condition: History should be empty.
      XCTAssertTrue(chat.history.isEmpty)

      do {
        let stream = try chat.sendMessageStream(input)
        for try await _ in stream {
          // Consume the stream.
        }
        XCTFail("Should have thrown a responseStoppedEarly error.")
      } catch let GenerateContentError.responseStoppedEarly(reason, _) {
        XCTAssertEqual(reason, .safety)
      } catch {
        XCTFail("Unexpected error thrown: \(error)")
      }

      // Post-condition: History should still be empty.
      XCTAssertEqual(chat.history.count, 0)
    }

    /// Starts a chat with two pre-existing entries and checks that `history` equals them.
    func testStartChat_withHistory_initializesCorrectly() async throws {
      let history = [
        ModelContent(role: "user", parts: "Question 1"),
        ModelContent(role: "model", parts: "Answer 1"),
      ]

      let chat = makeModel().startChat(history: history)

      XCTAssertEqual(chat.history.count, 2)
      XCTAssertEqual(chat.history, history)
    }
  }