Sfoglia il codice sorgente

[Firebase AI] Add internal JSON Schema support in `GenerationConfig` (#15404)

Andrew Heard 5 mesi fa
parent
commit
f9754ba5fb

+ 26 - 0
FirebaseAI/Sources/GenerationConfig.swift

@@ -48,6 +48,11 @@ public struct GenerationConfig: Sendable {
   /// Output schema of the generated candidate text.
   let responseSchema: Schema?
 
+  /// Output schema of the generated response in [JSON Schema](https://json-schema.org/) format.
+  ///
+  /// If set, `responseSchema` must be omitted and `responseMIMEType` is required.
+  let responseJSONSchema: JSONObject?
+
   /// Supported modalities of the response.
   let responseModalities: [ResponseModality]?
 
@@ -175,6 +180,26 @@ public struct GenerationConfig: Sendable {
     self.stopSequences = stopSequences
     self.responseMIMEType = responseMIMEType
     self.responseSchema = responseSchema
+    responseJSONSchema = nil
+    self.responseModalities = responseModalities
+    self.thinkingConfig = thinkingConfig
+  }
+
+  init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil, candidateCount: Int? = nil,
+       maxOutputTokens: Int? = nil, presencePenalty: Float? = nil, frequencyPenalty: Float? = nil,
+       stopSequences: [String]? = nil, responseMIMEType: String, responseJSONSchema: JSONObject,
+       responseModalities: [ResponseModality]? = nil, thinkingConfig: ThinkingConfig? = nil) { // Internal (no access modifier) JSON Schema variant of the initializer.
+    self.temperature = temperature
+    self.topP = topP
+    self.topK = topK
+    self.candidateCount = candidateCount
+    self.maxOutputTokens = maxOutputTokens
+    self.presencePenalty = presencePenalty
+    self.frequencyPenalty = frequencyPenalty
+    self.stopSequences = stopSequences
+    self.responseMIMEType = responseMIMEType // Non-optional here: required whenever a JSON Schema is set.
+    responseSchema = nil // Must be omitted when `responseJSONSchema` is set.
+    self.responseJSONSchema = responseJSONSchema // Encoded under the "responseJsonSchema" key.
     self.responseModalities = responseModalities
     self.thinkingConfig = thinkingConfig
   }
@@ -195,6 +220,7 @@ extension GenerationConfig: Encodable {
     case stopSequences
     case responseMIMEType = "responseMimeType"
     case responseSchema
+    case responseJSONSchema = "responseJsonSchema"
     case responseModalities
     case thinkingConfig
   }

+ 311 - 116
FirebaseAI/Tests/TestApp/Tests/Integration/SchemaTests.swift

@@ -28,8 +28,6 @@ import Testing
 /// Test the schema fields.
 @Suite(.serialized)
 struct SchemaTests {
-  // Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
-  let generationConfig = GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1)
   let safetySettings = [
     SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
     SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
@@ -37,31 +35,32 @@ struct SchemaTests {
     SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
     SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
   ]
-  // Candidates and total token counts may differ slightly between runs due to whitespace tokens.
-  let tokenCountAccuracy = 1
 
-  let storage: Storage
-  let userID1: String
-
-  init() async throws {
-    userID1 = try await TestHelpers.getUserID()
-    storage = Storage.storage()
-  }
-
-  @Test(arguments: InstanceConfig.allConfigs)
-  func generateContentSchemaItems(_ config: InstanceConfig) async throws {
-    let model = FirebaseAI.componentInstance(config).generativeModel(
-      modelName: ModelNames.gemini2FlashLite,
-      generationConfig: GenerationConfig(
-        responseMIMEType: "application/json",
-        responseSchema:
-        .array(
-          items: .string(description: "The name of the city"),
-          description: "A list of city names",
-          minItems: 3,
-          maxItems: 5
-        )
+  @Test(
+    arguments: testConfigs(
+      instanceConfigs: InstanceConfig.allConfigs,
+      openAPISchema: .array(
+        items: .string(description: "The name of the city"),
+        description: "A list of city names",
+        minItems: 3,
+        maxItems: 5
       ),
+      jsonSchema: [
+        "type": .string("array"),
+        "description": .string("A list of city names"),
+        "items": .object([
+          "type": .string("string"),
+          "description": .string("The name of the city"),
+        ]),
+        "minItems": .number(3),
+        "maxItems": .number(5),
+      ]
+    )
+  )
+  func generateContentItemsSchema(_ config: InstanceConfig, _ schema: SchemaType) async throws {
+    let model = FirebaseAI.componentInstance(config).generativeModel(
+      modelName: ModelNames.gemini2_5_FlashLite,
+      generationConfig: SchemaTests.generationConfig(schema: schema),
       safetySettings: safetySettings
     )
     let prompt = "What are the biggest cities in Canada?"
@@ -73,18 +72,25 @@ struct SchemaTests {
     #expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)")
   }
 
-  @Test(arguments: InstanceConfig.allConfigs)
-  func generateContentSchemaNumberRange(_ config: InstanceConfig) async throws {
+  @Test(arguments: testConfigs(
+    instanceConfigs: InstanceConfig.allConfigs,
+    openAPISchema: .integer(
+      description: "A number",
+      minimum: 110,
+      maximum: 120
+    ),
+    jsonSchema: [
+      "type": .string("integer"),
+      "description": .string("A number"),
+      "minimum": .number(110),
+      "maximum": .number(120),
+    ]
+  ))
+  func generateContentSchemaNumberRange(_ config: InstanceConfig,
+                                        _ schema: SchemaType) async throws {
     let model = FirebaseAI.componentInstance(config).generativeModel(
-      modelName: ModelNames.gemini2FlashLite,
-      generationConfig: GenerationConfig(
-        responseMIMEType: "application/json",
-        responseSchema: .integer(
-          description: "A number",
-          minimum: 110,
-          maximum: 120
-        )
-      ),
+      modelName: ModelNames.gemini2_5_FlashLite,
+      generationConfig: SchemaTests.generationConfig(schema: schema),
       safetySettings: safetySettings
     )
     let prompt = "Give me a number"
@@ -96,41 +102,83 @@ struct SchemaTests {
     #expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)")
   }
 
-  @Test(arguments: InstanceConfig.allConfigs)
-  func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws {
+  @Test(arguments: testConfigs(
+    instanceConfigs: InstanceConfig.allConfigs,
+    openAPISchema: .object(
+      properties: [
+        "productName": .string(description: "The name of the product"),
+        "price": .double(
+          description: "A price",
+          minimum: 10.00,
+          maximum: 120.00
+        ),
+        "salePrice": .float(
+          description: "A sale price",
+          minimum: 5.00,
+          maximum: 90.00
+        ),
+        "rating": .integer(
+          description: "A rating",
+          minimum: 1,
+          maximum: 5
+        ),
+      ],
+      propertyOrdering: ["salePrice", "rating", "price", "productName"],
+      title: "ProductInfo"
+    ),
+    jsonSchema: [
+      "type": .string("object"),
+      "title": .string("ProductInfo"),
+      "properties": .object([
+        "productName": .object([
+          "type": .string("string"),
+          "description": .string("The name of the product"),
+        ]),
+        "price": .object([
+          "type": .string("number"),
+          "description": .string("A price"),
+          "minimum": .number(10.00),
+          "maximum": .number(120.00),
+        ]),
+        "salePrice": .object([
+          "type": .string("number"),
+          "description": .string("A sale price"),
+          "minimum": .number(5.00),
+          "maximum": .number(90.00),
+        ]),
+        "rating": .object([
+          "type": .string("integer"),
+          "description": .string("A rating"),
+          "minimum": .number(1),
+          "maximum": .number(5),
+        ]),
+      ]),
+      "required": .array([
+        .string("productName"),
+        .string("price"),
+        .string("salePrice"),
+        .string("rating"),
+      ]),
+      "propertyOrdering": .array([
+        .string("salePrice"),
+        .string("rating"),
+        .string("price"),
+        .string("productName"),
+      ]),
+    ]
+  ))
+  func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig,
+                                                 _ schema: SchemaType) async throws {
     struct ProductInfo: Codable {
       let productName: String
-      let rating: Int // Will correspond to .integer in schema
-      let price: Double // Will correspond to .double in schema
-      let salePrice: Float // Will correspond to .float in schema
+      let rating: Int
+      let price: Double
+      let salePrice: Float
     }
+
     let model = FirebaseAI.componentInstance(config).generativeModel(
-      modelName: ModelNames.gemini2FlashLite,
-      generationConfig: GenerationConfig(
-        responseMIMEType: "application/json",
-        responseSchema: .object(
-          properties: [
-            "productName": .string(description: "The name of the product"),
-            "price": .double(
-              description: "A price",
-              minimum: 10.00,
-              maximum: 120.00
-            ),
-            "salePrice": .float(
-              description: "A sale price",
-              minimum: 5.00,
-              maximum: 90.00
-            ),
-            "rating": .integer(
-              description: "A rating",
-              minimum: 1,
-              maximum: 5
-            ),
-          ],
-          propertyOrdering: ["salePrice", "rating", "price", "productName"],
-          title: "ProductInfo"
-        ),
-      ),
+      modelName: ModelNames.gemini2_5_FlashLite,
+      generationConfig: SchemaTests.generationConfig(schema: schema),
       safetySettings: safetySettings
     )
     let prompt = "Describe a premium wireless headphone, including a user rating and price."
@@ -149,63 +197,127 @@ struct SchemaTests {
     #expect(rating <= 5, "Expected a rating <= 5, but got \(rating)")
   }
 
-  @Test(arguments: InstanceConfig.allConfigs)
-  func generateContentAnyOfSchema(_ config: InstanceConfig) async throws {
-    struct MailingAddress: Decodable {
-      let streetAddress: String
-      let city: String
-
-      // Canadian-specific
-      let province: String?
-      let postalCode: String?
-
-      // U.S.-specific
-      let state: String?
-      let zipCode: String?
-
-      var isCanadian: Bool {
-        return province != nil && postalCode != nil && state == nil && zipCode == nil
+  fileprivate struct MailingAddress { // One address decoded from the model's JSON response.
+    enum PostalInfo { // Country-specific part of an address.
+      struct Canada: Decodable { // Decoding helper for the `.canada` case.
+        let province: String
+        let postalCode: String
       }
 
-      var isAmerican: Bool {
-        return province == nil && postalCode == nil && state != nil && zipCode != nil
+      struct UnitedStates: Decodable { // Decoding helper for the `.unitedStates` case.
+        let state: String
+        let zipCode: String
       }
+
+      case canada(province: String, postalCode: String)
+      case unitedStates(state: String, zipCode: String)
     }
 
+    let streetAddress: String
+    let city: String
+    let postalInfo: PostalInfo // Populated by the custom `init(from:)` in the Decodable extension.
+  }
+
+  private static let generateContentAnyOfOpenAPISchema = { // `Schema`-typed variant of the anyOf test schema.
     let streetSchema = Schema.string(description:
       "The civic number and street name, for example, '123 Main Street'.")
     let citySchema = Schema.string(description: "The name of the city.")
-    let canadianAddressSchema = Schema.object(
+    let canadaPostalInfoSchema = Schema.object( // Only the Canada-specific fields.
       properties: [
-        "streetAddress": streetSchema,
-        "city": citySchema,
         "province": .string(description:
           "The 2-letter province or territory code, for example, 'ON', 'QC', or 'NU'."),
         "postalCode": .string(description: "The postal code, for example, 'A1A 1A1'."),
-      ],
-      description: "A Canadian mailing address"
+      ]
     )
-    let americanAddressSchema = Schema.object(
+    let unitedStatesPostalInfoSchema = Schema.object( // Only the U.S.-specific fields.
       properties: [
-        "streetAddress": streetSchema,
-        "city": citySchema,
         "state": .string(description:
           "The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'."),
         "zipCode": .string(description: "The 5-digit ZIP code, for example, '12345'."),
-      ],
-      description: "A U.S. mailing address"
+      ]
     )
+    let mailingAddressSchema = Schema.object(properties: [ // Shared fields plus one of the two postal shapes.
+      "streetAddress": streetSchema,
+      "city": citySchema,
+      "postalInfo": .anyOf(schemas: [canadaPostalInfoSchema, unitedStatesPostalInfoSchema]),
+    ])
+    return Schema.array(items: mailingAddressSchema) // The response is an array of addresses.
+  }()
+
+  private static let generateContentAnyOfJSONSchema = { // JSON Schema variant of the anyOf test schema.
+    let streetSchema: JSONValue = .object([
+      "type": .string("string"),
+      "description": .string("The civic number and street name, for example, '123 Main Street'."),
+    ])
+    let citySchema: JSONValue = .object([
+      "type": .string("string"),
+      "description": .string("The name of the city."),
+    ])
+    let postalInfoSchema: JSONValue = .object([
+      "anyOf": .array([ // Two alternatives: Canadian or U.S. postal fields.
+        .object([ // Canadian postal info.
+          "type": .string("object"),
+          "properties": .object([
+            "province": .object([
+              "type": .string("string"),
+              "description": .string(
+                "The 2-letter Canadian province or territory code, for example, 'ON', 'QC', or 'NU'."
+              ),
+            ]),
+            "postalCode": .object([
+              "type": .string("string"),
+              "description": .string("The Canadian postal code, for example, 'A1A 1A1'."),
+            ]),
+          ]),
+          "required": .array([.string("province"), .string("postalCode")]),
+        ]),
+        .object([ // U.S. postal info.
+          "type": .string("object"),
+          "properties": .object([
+            "state": .object([
+              "type": .string("string"),
+              "description": .string(
+                "The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'."
+              ),
+            ]),
+            "zipCode": .object([
+              "type": .string("string"),
+              "description": .string("The 5-digit U.S. ZIP code, for example, '12345'."),
+            ]),
+          ]),
+          "required": .array([.string("state"), .string("zipCode")]),
+        ]),
+      ]),
+    ])
+    let mailingAddressSchema: JSONObject = [
+      "type": .string("object"),
+      "description": .string("A mailing address"),
+      "properties": .object([
+        "streetAddress": streetSchema,
+        "city": citySchema,
+        "postalInfo": postalInfoSchema,
+      ]),
+      "required": .array([
+        .string("streetAddress"),
+        .string("city"),
+        .string("postalInfo"),
+      ]),
+    ]
+    return [ // Top level: an array whose items are mailing addresses.
+      "type": .string("array"),
+      "items": .object(mailingAddressSchema),
+    ] as JSONObject
+  }()
+
+  @Test(arguments: testConfigs(
+    instanceConfigs: InstanceConfig.allConfigs,
+    openAPISchema: generateContentAnyOfOpenAPISchema,
+    jsonSchema: generateContentAnyOfJSONSchema
+  ))
+  func generateContentAnyOfSchema(_ config: InstanceConfig, _ schema: SchemaType) async throws {
     let model = FirebaseAI.componentInstance(config).generativeModel(
-      modelName: ModelNames.gemini2Flash,
-      generationConfig: GenerationConfig(
-        temperature: 0.0,
-        topP: 0.0,
-        topK: 1,
-        responseMIMEType: "application/json",
-        responseSchema: .array(items: .anyOf(
-          schemas: [canadianAddressSchema, americanAddressSchema]
-        ))
-      ),
+      modelName: ModelNames.gemini2_5_Flash,
+      generationConfig: SchemaTests.generationConfig(schema: schema),
       safetySettings: safetySettings
     )
     let prompt = """
@@ -217,19 +329,102 @@ struct SchemaTests {
     let decodedAddresses = try JSONDecoder().decode([MailingAddress].self, from: jsonData)
     try #require(decodedAddresses.count == 3, "Expected 3 JSON addresses, got \(text).")
     let waterlooAddress = decodedAddresses[0]
-    #expect(
-      waterlooAddress.isCanadian,
-      "Expected Canadian University of Waterloo address, got \(waterlooAddress)."
-    )
+    #expect(waterlooAddress.city == "Waterloo")
+    if case let .canada(province, postalCode) = waterlooAddress.postalInfo {
+      #expect(province == "ON")
+      #expect(postalCode == "N2L 3G1")
+    } else {
+      Issue.record("Expected Canadian University of Waterloo address, got \(waterlooAddress).")
+    }
     let berkeleyAddress = decodedAddresses[1]
-    #expect(
-      berkeleyAddress.isAmerican,
-      "Expected American UC Berkeley address, got \(berkeleyAddress)."
-    )
+    #expect(berkeleyAddress.city == "Berkeley")
+    if case let .unitedStates(state, zipCode) = berkeleyAddress.postalInfo {
+      #expect(state == "CA")
+      #expect(zipCode == "94720")
+    } else {
+      Issue.record("Expected American UC Berkeley address, got \(berkeleyAddress).")
+    }
     let queensAddress = decodedAddresses[2]
-    #expect(
-      queensAddress.isCanadian,
-      "Expected Canadian Queen's University address, got \(queensAddress)."
+    #expect(queensAddress.city == "Kingston")
+    if case let .canada(province, postalCode) = queensAddress.postalInfo {
+      #expect(province == "ON")
+      #expect(postalCode == "K7L 3N6")
+    } else {
+      Issue.record("Expected Canadian Queen's University address, got \(queensAddress).")
+    }
+  }
+
+  enum SchemaType: CustomTestStringConvertible { // Schema flavor passed to each parameterized test.
+    case openAPI(Schema) // Supplied to GenerationConfig as `responseSchema`.
+    case json(JSONObject) // Supplied to GenerationConfig as `responseJSONSchema`.
+
+    var testDescription: String { // Text returned for the CustomTestStringConvertible conformance.
+      switch self {
+      case .openAPI:
+        return "OpenAPI Schema"
+      case .json:
+        return "JSON Schema"
+      }
+    }
+  }
+
+  private static func generationConfig(schema: SchemaType) -> GenerationConfig { // Builds the config for one schema flavor.
+    let mimeType = "application/json" // Both flavors request JSON output.
+    switch schema { // temperature/topP/topK are kept minimal to make responses more deterministic.
+    case let .openAPI(openAPISchema):
+      return GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1, responseMIMEType: mimeType,
+                              responseSchema: openAPISchema)
+    case let .json(jsonSchema):
+      return GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1, responseMIMEType: mimeType,
+                              responseJSONSchema: jsonSchema)
+    }
+  }
+
+  private static func testConfigs(instanceConfigs: [InstanceConfig], openAPISchema: Schema,
+                                  jsonSchema: JSONObject) -> [(InstanceConfig, SchemaType)] { // Test arguments.
+    return instanceConfigs.flatMap { [($0, .openAPI(openAPISchema)), ($0, .json(jsonSchema))] } // Two pairs per config: OpenAPI first, then JSON.
+  }
+}
+
+extension SchemaTests.MailingAddress: Decodable { // Manual decoding: `postalInfo` can hold either country's fields.
+  enum CodingKeys: CodingKey {
+    case streetAddress
+    case city
+    case postalInfo
+  }
+
+  init(from decoder: any Decoder) throws { // Throws if `postalInfo` matches neither or both formats.
+    let container = try decoder.container(keyedBy: CodingKeys.self)
+    streetAddress = try container.decode(String.self, forKey: .streetAddress)
+    city = try container.decode(String.self, forKey: .city)
+    let canadaPostalInfo = try? container.decode(PostalInfo.Canada.self, forKey: .postalInfo)
+    let unitedStatesPostalInfo = try? container.decode( // `try?` on both attempts: a failed decode becomes nil.
+      PostalInfo.UnitedStates.self, forKey: .postalInfo
     )
+
+    if canadaPostalInfo != nil, unitedStatesPostalInfo != nil { // Both shapes decoded: ambiguous, so reject.
+      throw DecodingError.dataCorruptedError(
+        forKey: .postalInfo,
+        in: container,
+        debugDescription: "Ambiguous postal info: matches both Canadian and U.S. formats."
+      )
+    }
+
+    if let canadaPostalInfo {
+      postalInfo = .canada(
+        province: canadaPostalInfo.province, postalCode: canadaPostalInfo.postalCode
+      )
+    } else if let unitedStatesPostalInfo {
+      postalInfo = .unitedStates(
+        state: unitedStatesPostalInfo.state, zipCode: unitedStatesPostalInfo.zipCode
+      )
+    } else { // Neither shape decoded.
+      throw DecodingError.typeMismatch(
+        PostalInfo.self, .init(
+          codingPath: container.codingPath,
+          debugDescription: "Expected Canadian or U.S. postal info."
+        )
+      )
+    }
   }
 }

+ 82 - 1
FirebaseAI/Tests/Unit/GenerationConfigTests.swift

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-import FirebaseAILogic
+@testable import FirebaseAILogic
 import Foundation
 import XCTest
 
@@ -153,4 +153,85 @@ final class GenerationConfigTests: XCTestCase {
     }
     """)
   }
+
+  func testEncodeGenerationConfig_responseJSONSchema() throws { // Checks the JSON emitted for a JSON Schema config.
+    let mimeType = "application/json"
+    let responseJSONSchema: JSONObject = [
+      "type": .string("object"),
+      "title": .string("Person"),
+      "properties": .object([
+        "firstName": .object(["type": .string("string")]),
+        "middleNames": .object([
+          "type": .string("array"),
+          "items": .object(["type": .string("string")]),
+          "minItems": .number(0),
+          "maxItems": .number(3),
+        ]),
+        "lastName": .object(["type": .string("string")]),
+        "age": .object(["type": .string("integer")]),
+      ]),
+      "required": .array([
+        .string("firstName"),
+        .string("middleNames"),
+        .string("lastName"),
+        .string("age"),
+      ]),
+      "propertyOrdering": .array([
+        .string("firstName"),
+        .string("middleNames"),
+        .string("lastName"),
+        .string("age"),
+      ]),
+      "additionalProperties": .bool(false),
+    ]
+    let generationConfig = GenerationConfig( // Initializer has no `public` modifier; see `@testable import` above.
+      responseMIMEType: mimeType,
+      responseJSONSchema: responseJSONSchema
+    )
+
+    let jsonData = try encoder.encode(generationConfig) // NOTE(review): expected text implies `encoder` pretty-prints with sorted keys — confirm.
+
+    let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
+    XCTAssertEqual(json, """
+    {
+      "responseJsonSchema" : {
+        "additionalProperties" : false,
+        "properties" : {
+          "age" : {
+            "type" : "integer"
+          },
+          "firstName" : {
+            "type" : "string"
+          },
+          "lastName" : {
+            "type" : "string"
+          },
+          "middleNames" : {
+            "items" : {
+              "type" : "string"
+            },
+            "maxItems" : 3,
+            "minItems" : 0,
+            "type" : "array"
+          }
+        },
+        "propertyOrdering" : [
+          "firstName",
+          "middleNames",
+          "lastName",
+          "age"
+        ],
+        "required" : [
+          "firstName",
+          "middleNames",
+          "lastName",
+          "age"
+        ],
+        "title" : "Person",
+        "type" : "object"
+      },
+      "responseMimeType" : "\(mimeType)"
+    }
+    """)
+  }
 }