Просмотр исходного кода

Increase Firebase AI Logic unit test coverage (#15126)

Co-authored-by: Andrew Heard <andrewheard@google.com>
Paul Beusterien 8 месяцев назад
Родитель
Сommit
e0c6d91ce1

+ 100 - 0
FirebaseAI/Tests/Unit/ChatTests.swift

@@ -94,4 +94,104 @@ final class ChatTests: XCTestCase {
       XCTAssertEqual(chat.history[1], assembledExpectation)
     #endif // os(watchOS)
   }
+
+  func testSendMessage_unary_appendsHistory() async throws {
+    let expectedInput = "Test input"
+    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+      forResource: "unary-success-basic-reply-short",
+      withExtension: "json",
+      subdirectory: "mock-responses/googleai"
+    )
+    let model = GenerativeModel(
+      modelName: modelName,
+      modelResourceName: modelResourceName,
+      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
+      apiConfig: FirebaseAI.defaultVertexAIAPIConfig,
+      tools: nil,
+      requestOptions: RequestOptions(),
+      urlSession: urlSession
+    )
+    let chat = model.startChat()
+
+    // Pre-condition: History should be empty.
+    XCTAssertTrue(chat.history.isEmpty)
+
+    let response = try await chat.sendMessage(expectedInput)
+
+    XCTAssertNotNil(response.text)
+    let text = try XCTUnwrap(response.text)
+    XCTAssertFalse(text.isEmpty)
+
+    // Post-condition: History should have the user's message and the model's response.
+    XCTAssertEqual(chat.history.count, 2)
+    let userInput = try XCTUnwrap(chat.history.first)
+    XCTAssertEqual(userInput.role, "user")
+    XCTAssertEqual(userInput.parts.count, 1)
+    let userInputText = try XCTUnwrap(userInput.parts.first as? TextPart)
+    XCTAssertEqual(userInputText.text, expectedInput)
+
+    let modelResponse = try XCTUnwrap(chat.history.last)
+    XCTAssertEqual(modelResponse.role, "model")
+    XCTAssertEqual(modelResponse.parts.count, 1)
+    let modelResponseText = try XCTUnwrap(modelResponse.parts.first as? TextPart)
+    XCTAssertFalse(modelResponseText.text.isEmpty)
+  }
+
+  func testSendMessageStream_error_doesNotAppendHistory() async throws {
+    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+      forResource: "streaming-failure-finish-reason-safety",
+      withExtension: "txt",
+      subdirectory: "mock-responses/vertexai"
+    )
+    let model = GenerativeModel(
+      modelName: modelName,
+      modelResourceName: modelResourceName,
+      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
+      apiConfig: FirebaseAI.defaultVertexAIAPIConfig,
+      tools: nil,
+      requestOptions: RequestOptions(),
+      urlSession: urlSession
+    )
+    let chat = model.startChat()
+    let input = "Test input"
+
+    // Pre-condition: History should be empty.
+    XCTAssertTrue(chat.history.isEmpty)
+
+    do {
+      let stream = try chat.sendMessageStream(input)
+      for try await _ in stream {
+        // Consume the stream.
+      }
+      XCTFail("Should have thrown a responseStoppedEarly error.")
+    } catch let GenerateContentError.responseStoppedEarly(reason, _) {
+      XCTAssertEqual(reason, .safety)
+    } catch {
+      XCTFail("Unexpected error thrown: \(error)")
+    }
+
+    // Post-condition: History should still be empty.
+    XCTAssertEqual(chat.history.count, 0)
+  }
+
+  func testStartChat_withHistory_initializesCorrectly() async throws {
+    let history = [
+      ModelContent(role: "user", parts: "Question 1"),
+      ModelContent(role: "model", parts: "Answer 1"),
+    ]
+    let model = GenerativeModel(
+      modelName: modelName,
+      modelResourceName: modelResourceName,
+      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
+      apiConfig: FirebaseAI.defaultVertexAIAPIConfig,
+      tools: nil,
+      requestOptions: RequestOptions(),
+      urlSession: urlSession
+    )
+
+    let chat = model.startChat(history: history)
+
+    XCTAssertEqual(chat.history.count, 2)
+    XCTAssertEqual(chat.history, history)
+  }
 }

+ 68 - 0
FirebaseAI/Tests/Unit/JSONValueTests.swift

@@ -97,6 +97,48 @@ final class JSONValueTests: XCTestCase {
     XCTAssertEqual(json, "null")
   }
 
+  func testDecodeNestedObject() throws {
+    let nestedObject: JSONObject = [
+      "nestedKey": .string("nestedValue"),
+    ]
+    let expectedObject: JSONObject = [
+      "numberKey": .number(numberValue),
+      "objectKey": .object(nestedObject),
+    ]
+    let json = """
+    {
+      "numberKey": \(numberValue),
+      "objectKey": {
+        "nestedKey": "nestedValue"
+      }
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+
+    let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData))
+
+    XCTAssertEqual(jsonObject, .object(expectedObject))
+  }
+
+  func testDecodeNestedArray() throws {
+    let nestedArray: [JSONValue] = [.string("a"), .string("b")]
+    let expectedObject: JSONObject = [
+      "numberKey": .number(numberValue),
+      "arrayKey": .array(nestedArray),
+    ]
+    let json = """
+    {
+      "numberKey": \(numberValue),
+      "arrayKey": ["a", "b"]
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+
+    let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData))
+
+    XCTAssertEqual(jsonObject, .object(expectedObject))
+  }
+
   func testEncodeNumber() throws {
     let jsonData = try encoder.encode(JSONValue.number(numberValue))
 
@@ -143,4 +185,30 @@ final class JSONValueTests: XCTestCase {
     let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
     XCTAssertEqual(json, "[null,\(numberValueEncoded)]")
   }
+
+  func testEncodeNestedObject() throws {
+    let nestedObject: JSONObject = [
+      "nestedKey": .string("nestedValue"),
+    ]
+    let objectValue: JSONObject = [
+      "numberKey": .number(numberValue),
+      "objectKey": .object(nestedObject),
+    ]
+
+    let jsonData = try encoder.encode(JSONValue.object(objectValue))
+    let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData))
+    XCTAssertEqual(jsonObject, .object(objectValue))
+  }
+
+  func testEncodeNestedArray() throws {
+    let nestedArray: [JSONValue] = [.string("a"), .string("b")]
+    let objectValue: JSONObject = [
+      "numberKey": .number(numberValue),
+      "arrayKey": .array(nestedArray),
+    ]
+
+    let jsonData = try encoder.encode(JSONValue.object(objectValue))
+    let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData))
+    XCTAssertEqual(jsonObject, .object(objectValue))
+  }
 }

+ 51 - 0
FirebaseAI/Tests/Unit/PartsRepresentableTests.swift

@@ -121,4 +121,55 @@ final class PartsRepresentableTests: XCTestCase {
       }
     }
   #endif
+
+  func testMixedParts() throws {
+    let text = "This is a test"
+    let data = try XCTUnwrap("This is some data".data(using: .utf8))
+    let inlineData = InlineDataPart(data: data, mimeType: "text/plain")
+
+    let parts: [any PartsRepresentable] = [text, inlineData]
+    let modelContent = ModelContent(parts: parts)
+
+    XCTAssertEqual(modelContent.parts.count, 2)
+    let textPart = try XCTUnwrap(modelContent.parts[0] as? TextPart)
+    XCTAssertEqual(textPart.text, text)
+    let dataPart = try XCTUnwrap(modelContent.parts[1] as? InlineDataPart)
+    XCTAssertEqual(dataPart, inlineData)
+  }
+
+  #if canImport(UIKit)
+    func testMixedParts_withImage() throws {
+      let text = "This is a test"
+      let image = try XCTUnwrap(UIImage(systemName: "star"))
+      let parts: [any PartsRepresentable] = [text, image]
+      let modelContent = ModelContent(parts: parts)
+
+      XCTAssertEqual(modelContent.parts.count, 2)
+      let textPart = try XCTUnwrap(modelContent.parts[0] as? TextPart)
+      XCTAssertEqual(textPart.text, text)
+      let imagePart = try XCTUnwrap(modelContent.parts[1] as? InlineDataPart)
+      XCTAssertEqual(imagePart.mimeType, "image/jpeg")
+      XCTAssertFalse(imagePart.data.isEmpty)
+    }
+
+  #elseif canImport(AppKit)
+    func testMixedParts_withImage() throws {
+      let text = "This is a test"
+      let coreImage = CIImage(color: CIColor.blue)
+        .cropped(to: CGRect(origin: CGPoint.zero, size: CGSize(width: 16, height: 16)))
+      let rep = NSCIImageRep(ciImage: coreImage)
+      let image = NSImage(size: rep.size)
+      image.addRepresentation(rep)
+
+      let parts: [any PartsRepresentable] = [text, image]
+      let modelContent = ModelContent(parts: parts)
+
+      XCTAssertEqual(modelContent.parts.count, 2)
+      let textPart = try XCTUnwrap(modelContent.parts[0] as? TextPart)
+      XCTAssertEqual(textPart.text, text)
+      let imagePart = try XCTUnwrap(modelContent.parts[1] as? InlineDataPart)
+      XCTAssertEqual(imagePart.mimeType, "image/jpeg")
+      XCTAssertFalse(imagePart.data.isEmpty)
+    }
+  #endif
 }

+ 123 - 0
FirebaseAI/Tests/Unit/SafetyTests.swift

@@ -0,0 +1,123 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import XCTest
+
+@testable import FirebaseAI
+
+@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
+final class SafetyTests: XCTestCase {
+  let decoder = JSONDecoder()
+  let encoder = JSONEncoder()
+
+  override func setUp() {
+    encoder.outputFormatting = .init(
+      arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
+    )
+  }
+
+  // MARK: - SafetyRating Decoding
+
+  func testDecodeSafetyRating_allFieldsPresent() throws {
+    let json = """
+    {
+      "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+      "probability": "NEGLIGIBLE",
+      "probabilityScore": 0.1,
+      "severity": "HARM_SEVERITY_LOW",
+      "severityScore": 0.2,
+      "blocked": true
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+    let rating = try decoder.decode(SafetyRating.self, from: jsonData)
+
+    XCTAssertEqual(rating.category, .dangerousContent)
+    XCTAssertEqual(rating.probability, .negligible)
+    XCTAssertEqual(rating.probabilityScore, 0.1)
+    XCTAssertEqual(rating.severity, .low)
+    XCTAssertEqual(rating.severityScore, 0.2)
+    XCTAssertTrue(rating.blocked)
+  }
+
+  func testDecodeSafetyRating_missingOptionalFields() throws {
+    let json = """
+    {
+      "category": "HARM_CATEGORY_HARASSMENT",
+      "probability": "LOW"
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+    let rating = try decoder.decode(SafetyRating.self, from: jsonData)
+
+    XCTAssertEqual(rating.category, .harassment)
+    XCTAssertEqual(rating.probability, .low)
+    XCTAssertEqual(rating.probabilityScore, 0.0)
+    XCTAssertEqual(rating.severity, .unspecified)
+    XCTAssertEqual(rating.severityScore, 0.0)
+    XCTAssertFalse(rating.blocked)
+  }
+
+  func testDecodeSafetyRating_unknownEnums() throws {
+    let json = """
+    {
+      "category": "HARM_CATEGORY_UNKNOWN",
+      "probability": "UNKNOWN_PROBABILITY",
+      "severity": "UNKNOWN_SEVERITY"
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+    let rating = try decoder.decode(SafetyRating.self, from: jsonData)
+
+    XCTAssertEqual(rating.category.rawValue, "HARM_CATEGORY_UNKNOWN")
+    XCTAssertEqual(rating.probability.rawValue, "UNKNOWN_PROBABILITY")
+    XCTAssertEqual(rating.severity.rawValue, "UNKNOWN_SEVERITY")
+  }
+
+  // MARK: - SafetySetting Encoding
+
+  func testEncodeSafetySetting_allFields() throws {
+    let setting = SafetySetting(
+      harmCategory: .hateSpeech,
+      threshold: .blockMediumAndAbove,
+      method: .severity
+    )
+    let jsonData = try encoder.encode(setting)
+    let jsonString = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
+
+    XCTAssertEqual(jsonString, """
+    {
+      "category" : "HARM_CATEGORY_HATE_SPEECH",
+      "method" : "SEVERITY",
+      "threshold" : "BLOCK_MEDIUM_AND_ABOVE"
+    }
+    """)
+  }
+
+  func testEncodeSafetySetting_nilMethod() throws {
+    let setting = SafetySetting(
+      harmCategory: .sexuallyExplicit,
+      threshold: .blockOnlyHigh
+    )
+    let jsonData = try encoder.encode(setting)
+    let jsonString = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
+
+    XCTAssertEqual(jsonString, """
+    {
+      "category" : "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+      "threshold" : "BLOCK_ONLY_HIGH"
+    }
+    """)
+  }
+}