[Firebase AI] Add workaround for invalid SafetyRatings in response (#14817)

Andrew Heard 11 months ago
Commit
08d0f44a94

+ 3 - 0
FirebaseAI/CHANGELOG.md

@@ -1,6 +1,9 @@
 # Unreleased
 - [fixed] Fixed `ModalityTokenCount` decoding when the `tokenCount` field is
   omitted; this occurs when the count is 0. (#14745)
+- [fixed] Fixed `Candidate` decoding when `SafetyRating` values are missing a
+  category or probability; this may occur when using `gemini-2.0-flash-exp` for
+  image generation. (#14817)
 
 # 11.12.0
 - [added] **Public Preview**: Added support for specifying response modalities
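
For context, the failure mode behind this fix: the backend may return empty `{}` objects in the `safetyRatings` array, and the previous strict decoding threw on the first missing key, failing the whole `Candidate`. A minimal sketch of that pre-fix behavior, using stand-in types rather than the SDK's:

import Foundation

struct StrictRating: Decodable {
  // Non-optional fields use synthesized strict decoding, so an empty `{}`
  // object throws `DecodingError.keyNotFound` for `category`.
  let category: String
  let probability: String
}

let payload = Data(
  #"[{"category": "HARM_CATEGORY_HARASSMENT", "probability": "NEGLIGIBLE"}, {}, {}]"#.utf8
)
do {
  _ = try JSONDecoder().decode([StrictRating].self, from: payload)
} catch {
  print(error) // keyNotFound("category") — the pre-fix failure mode
}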

+ 7 - 3
FirebaseAI/Sources/GenerateContentResponse.swift

@@ -381,10 +381,14 @@ extension Candidate: Decodable {
     }
 
     if let safetyRatings = try container.decodeIfPresent(
-      [SafetyRating].self,
-      forKey: .safetyRatings
+      [SafetyRating].self, forKey: .safetyRatings
     ) {
-      self.safetyRatings = safetyRatings
+      self.safetyRatings = safetyRatings.filter {
+        // Due to a bug in the backend, the SDK may receive invalid `SafetyRating` values that do
+        // not include a category or probability; these are filtered out of the safety ratings.
+        $0.category != HarmCategory.unspecified
+          && $0.probability != SafetyRating.HarmProbability.unspecified
+      }
     } else {
       safetyRatings = []
     }
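
Taken in isolation, the filter added above has simple semantics: a rating is dropped if either its category or its probability decoded to the internal `unspecified` sentinel. A minimal sketch with stand-in types (not the SDK's actual `SafetyRating`):

enum Category { case unspecified, harassment }
enum Probability { case unspecified, negligible }

struct Rating {
  let category: Category
  let probability: Probability
}

let decoded = [
  Rating(category: .harassment, probability: .negligible),
  Rating(category: .unspecified, probability: .unspecified), // decoded from `{}`
]

// Mirrors the filter in `Candidate.init(from:)` above.
let surfaced = decoded.filter {
  $0.category != .unspecified && $0.probability != .unspecified
}
print(surfaced.count) // 1 — the empty rating is ignored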

+ 18 - 5
FirebaseAI/Sources/Safety.swift

@@ -78,12 +78,16 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
   @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
   public struct HarmProbability: DecodableProtoEnum, Hashable, Sendable {
     enum Kind: String {
+      case unspecified = "HARM_PROBABILITY_UNSPECIFIED"
       case negligible = "NEGLIGIBLE"
       case low = "LOW"
       case medium = "MEDIUM"
       case high = "HIGH"
     }
 
+    /// Internal-only; harm probability is unknown or unspecified by the backend.
+    static let unspecified = HarmProbability(kind: .unspecified)
+
     /// The probability is zero or close to zero.
     ///
     /// For benign content, the probability across all categories will be this value.
@@ -114,12 +118,16 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
   @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
   public struct HarmSeverity: DecodableProtoEnum, Hashable, Sendable {
     enum Kind: String {
+      case unspecified = "HARM_SEVERITY_UNSPECIFIED"
       case negligible = "HARM_SEVERITY_NEGLIGIBLE"
       case low = "HARM_SEVERITY_LOW"
       case medium = "HARM_SEVERITY_MEDIUM"
       case high = "HARM_SEVERITY_HIGH"
     }
 
+    /// Internal-only; harm severity is unknown or unspecified by the backend.
+    static let unspecified: HarmSeverity = .init(kind: .unspecified)
+
     /// Negligible level of harm severity.
     public static let negligible = HarmSeverity(kind: .negligible)
 
@@ -234,6 +242,7 @@ public struct SafetySetting: Sendable {
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 public struct HarmCategory: CodableProtoEnum, Hashable, Sendable {
   enum Kind: String {
+    case unspecified = "HARM_CATEGORY_UNSPECIFIED"
     case harassment = "HARM_CATEGORY_HARASSMENT"
     case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
     case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
@@ -241,6 +250,9 @@ public struct HarmCategory: CodableProtoEnum, Hashable, Sendable {
     case civicIntegrity = "HARM_CATEGORY_CIVIC_INTEGRITY"
   }
 
+  /// Internal-only; harm category is unknown or unspecified by the backend.
+  static let unspecified = HarmCategory(kind: .unspecified)
+
   /// Harassment content.
   public static let harassment = HarmCategory(kind: .harassment)
 
@@ -281,13 +293,14 @@ extension SafetyRating: Decodable {
 
   public init(from decoder: any Decoder) throws {
     let container = try decoder.container(keyedBy: CodingKeys.self)
-    category = try container.decode(HarmCategory.self, forKey: .category)
-    probability = try container.decode(HarmProbability.self, forKey: .probability)
+    category = try container.decodeIfPresent(HarmCategory.self, forKey: .category) ?? .unspecified
+    probability = try container.decodeIfPresent(
+      HarmProbability.self, forKey: .probability
+    ) ?? .unspecified
 
-    // The following 3 fields are only omitted in our test data.
+    // The following 3 fields are only provided when using the Vertex AI backend (not Google AI).
     probabilityScore = try container.decodeIfPresent(Float.self, forKey: .probabilityScore) ?? 0.0
-    severity = try container.decodeIfPresent(HarmSeverity.self, forKey: .severity) ??
-      HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED")
+    severity = try container.decodeIfPresent(HarmSeverity.self, forKey: .severity) ?? .unspecified
     severityScore = try container.decodeIfPresent(Float.self, forKey: .severityScore) ?? 0.0
 
     // The blocked field is only included when true.
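
The decoder changes above make `SafetyRating` tolerant of the missing keys instead of throwing. The same pattern in a self-contained sketch (simplified to raw strings; the SDK uses its proto-enum types):

import Foundation

struct LenientRating: Decodable {
  let category: String
  let probability: String

  enum CodingKeys: String, CodingKey {
    case category, probability
  }

  init(from decoder: any Decoder) throws {
    let container = try decoder.container(keyedBy: CodingKeys.self)
    // Missing keys fall back to the `*_UNSPECIFIED` sentinels rather than
    // throwing, so an empty `{}` rating still decodes and can be filtered
    // out downstream.
    category = try container.decodeIfPresent(String.self, forKey: .category)
      ?? "HARM_CATEGORY_UNSPECIFIED"
    probability = try container.decodeIfPresent(String.self, forKey: .probability)
      ?? "HARM_PROBABILITY_UNSPECIFIED"
  }
}

let ratings = try JSONDecoder().decode([LenientRating].self, from: Data("[{}]".utf8))
print(ratings[0].category) // HARM_CATEGORY_UNSPECIFIED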

+ 2 - 3
FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift

@@ -115,9 +115,8 @@ struct GenerateContentIntegrationTests {
   }
 
   @Test(arguments: [
-    // TODO(andrewheard): Vertex AI configs temporarily disabled to due empty SafetyRatings bug.
-    // InstanceConfig.vertexV1,
-    // InstanceConfig.vertexV1Beta,
+    InstanceConfig.vertexAI_v1,
+    InstanceConfig.vertexAI_v1beta,
     InstanceConfig.googleAI_v1beta,
     InstanceConfig.googleAI_v1beta_staging,
     InstanceConfig.googleAI_v1beta_freeTier_bypassProxy,
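
With decoding no longer throwing on the malformed ratings, the Vertex AI configurations are re-enabled above. From an app's perspective the net effect is that invalid ratings never surface; a usage sketch (assuming a configured `GenerativeModel` named `model` and an async context):

let response = try await model.generateContent("Generate an image of a sunset")
if let candidate = response.candidates.first {
  // After this fix, no rating here has an unspecified category or probability;
  // empty `{}` entries from the backend bug are dropped during decoding.
  for rating in candidate.safetyRatings {
    print(rating.category, rating.probability)
  }
}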

+ 81 - 1
FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift

@@ -56,6 +56,41 @@ final class GenerativeModelVertexAITests: XCTestCase {
       blocked: false
     ),
   ].sorted()
+  let safetyRatingsInvalidIgnored = [
+    SafetyRating(
+      category: .hateSpeech,
+      probability: .negligible,
+      probabilityScore: 0.00039444832,
+      severity: .negligible,
+      severityScore: 0.0,
+      blocked: false
+    ),
+    SafetyRating(
+      category: .dangerousContent,
+      probability: .negligible,
+      probabilityScore: 0.0010654529,
+      severity: .negligible,
+      severityScore: 0.0049325973,
+      blocked: false
+    ),
+    SafetyRating(
+      category: .harassment,
+      probability: .negligible,
+      probabilityScore: 0.00026658305,
+      severity: .negligible,
+      severityScore: 0.0,
+      blocked: false
+    ),
+    SafetyRating(
+      category: .sexuallyExplicit,
+      probability: .negligible,
+      probabilityScore: 0.0013701695,
+      severity: .negligible,
+      severityScore: 0.07626295,
+      blocked: false
+    ),
+    // Ignored Invalid Safety Ratings: {},{},{},{}
+  ].sorted()
   let testModelName = "test-model"
   let testModelResourceName =
     "projects/test-project-id/locations/test-location/publishers/google/models/test-model"
@@ -399,6 +434,26 @@ final class GenerativeModelVertexAITests: XCTestCase {
     XCTAssertEqual(text, "The sum of [1, 2, 3] is")
   }
 
+  func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
+    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+      forResource: "unary-success-image-invalid-safety-ratings",
+      withExtension: "json",
+      subdirectory: vertexSubdirectory
+    )
+
+    let response = try await model.generateContent(testPrompt)
+
+    XCTAssertEqual(response.candidates.count, 1)
+    let candidate = try XCTUnwrap(response.candidates.first)
+    XCTAssertEqual(candidate.content.parts.count, 1)
+    XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
+    let inlineDataParts = response.inlineDataParts
+    XCTAssertEqual(inlineDataParts.count, 1)
+    let imagePart = try XCTUnwrap(inlineDataParts.first)
+    XCTAssertEqual(imagePart.mimeType, "image/png")
+    XCTAssertGreaterThan(imagePart.data.count, 0)
+  }
+
   func testGenerateContent_appCheck_validToken() async throws {
     let appCheckToken = "test-valid-token"
     model = GenerativeModel(
@@ -1118,7 +1173,7 @@ final class GenerativeModelVertexAITests: XCTestCase {
       responses += 1
     }
 
-    XCTAssertEqual(responses, 6)
+    XCTAssertEqual(responses, 4)
   }
 
   func testGenerateContentStream_successBasicReplyShort() async throws {
@@ -1220,6 +1275,31 @@ final class GenerativeModelVertexAITests: XCTestCase {
     XCTAssertFalse(citations.contains { $0.license?.isEmpty ?? false })
   }
 
+  func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
+    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+      forResource: "streaming-success-image-invalid-safety-ratings",
+      withExtension: "txt",
+      subdirectory: vertexSubdirectory
+    )
+
+    let stream = try model.generateContentStream(testPrompt)
+    var responses = [GenerateContentResponse]()
+    for try await content in stream {
+      responses.append(content)
+    }
+
+    let response = try XCTUnwrap(responses.first)
+    XCTAssertEqual(response.candidates.count, 1)
+    let candidate = try XCTUnwrap(response.candidates.first)
+    XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
+    XCTAssertEqual(candidate.content.parts.count, 1)
+    let inlineDataParts = response.inlineDataParts
+    XCTAssertEqual(inlineDataParts.count, 1)
+    let imagePart = try XCTUnwrap(inlineDataParts.first)
+    XCTAssertEqual(imagePart.mimeType, "image/png")
+    XCTAssertGreaterThan(imagePart.data.count, 0)
+  }
+
   func testGenerateContentStream_appCheck_validToken() async throws {
     let appCheckToken = "test-valid-token"
     model = GenerativeModel(