ChatTests.swift 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. import Foundation
  15. import XCTest
  16. @testable import FirebaseVertexAI
  /// Tests for `Chat` that run against canned responses served through `MockURLProtocol`
  /// rather than the network.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  final class ChatTests: XCTestCase {
    /// Session configured with `MockURLProtocol`; recreated in `setUp()` before every test.
    var urlSession: URLSession!

    override func setUp() {
      super.setUp()
      let configuration = URLSessionConfiguration.default
      // Register the mock protocol class so requests made through this session are handed to it.
      configuration.protocolClasses = [MockURLProtocol.self]
      urlSession = URLSession(configuration: configuration)
    }

    override func tearDown() {
      // The handler lives on the `MockURLProtocol` type itself; clear it so a stub installed by
      // one test cannot leak into the next.
      MockURLProtocol.requestHandler = nil
      super.tearDown()
    }

    /// Streams the `streaming-success-basic-reply-parts.txt` fixture through
    /// `Chat.sendMessageStream` and checks the resulting chat history.
    ///
    /// After the stream has been fully consumed the history must hold exactly two entries: the
    /// user input, followed by a single model entry whose text is `"1 2 3 4 5 6 7 8"`.
    ///
    /// - Throws: `XCTSkip` on watchOS, any error raised while iterating the stream, and an
    ///   unwrap failure if the fixture or the first history part is missing.
    func testMergingText() async throws {
      // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
      // https://developer.apple.com/documentation/foundation/urlprotocol for details.
      // Checked before any other work so the test is skipped rather than failed on watchOS.
      guard #unavailable(watchOS 2) else {
        throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
      }

      let fileURL = try XCTUnwrap(Bundle.module.url(
        forResource: "streaming-success-basic-reply-parts",
        withExtension: "txt"
      ))

      // Answer every request with HTTP 200 and the fixture's lines as the response body.
      MockURLProtocol.requestHandler = { request in
        let response = HTTPURLResponse(
          url: request.url!,
          statusCode: 200,
          httpVersion: nil,
          headerFields: nil
        )!
        return (response, fileURL.lines)
      }

      let model = GenerativeModel(
        name: "my-model",
        projectID: "my-project-id",
        apiKey: "API_KEY",
        tools: nil,
        requestOptions: RequestOptions(),
        appCheck: nil,
        auth: nil,
        urlSession: urlSession
      )
      let chat = Chat(model: model, history: [])
      let input = "Test input"
      let stream = chat.sendMessageStream(input)

      // Ensure the values are parsed correctly
      for try await value in stream {
        XCTAssertNotNil(value.text)
      }

      // Bail out with a recorded failure instead of trapping on an out-of-range subscript when
      // the history does not have the expected shape.
      let history = chat.history
      guard history.count == 2 else {
        XCTFail("Expected 2 history entries (user input + model reply), found \(history.count).")
        return
      }

      // First entry: the user's message, unchanged.
      let userPart = try XCTUnwrap(history[0].parts.first)
      XCTAssertEqual(userPart.text, input)

      // Second entry: one model message carrying the complete reply text.
      let finalText = "1 2 3 4 5 6 7 8"
      let assembledExpectation = ModelContent(role: "model", parts: finalText)
      XCTAssertEqual(history[1], assembledExpectation)
    }
  }