|
|
@@ -28,8 +28,6 @@ import Testing
|
|
|
/// Test the schema fields.
|
|
|
@Suite(.serialized)
|
|
|
struct SchemaTests {
|
|
|
- // Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
|
|
|
- let generationConfig = GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1)
|
|
|
let safetySettings = [
|
|
|
SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
|
|
|
SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
|
|
|
@@ -37,31 +35,32 @@ struct SchemaTests {
|
|
|
SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
|
|
|
SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
|
|
|
]
|
|
|
- // Candidates and total token counts may differ slightly between runs due to whitespace tokens.
|
|
|
- let tokenCountAccuracy = 1
|
|
|
|
|
|
- let storage: Storage
|
|
|
- let userID1: String
|
|
|
-
|
|
|
- init() async throws {
|
|
|
- userID1 = try await TestHelpers.getUserID()
|
|
|
- storage = Storage.storage()
|
|
|
- }
|
|
|
-
|
|
|
- @Test(arguments: InstanceConfig.allConfigs)
|
|
|
- func generateContentSchemaItems(_ config: InstanceConfig) async throws {
|
|
|
- let model = FirebaseAI.componentInstance(config).generativeModel(
|
|
|
- modelName: ModelNames.gemini2FlashLite,
|
|
|
- generationConfig: GenerationConfig(
|
|
|
- responseMIMEType: "application/json",
|
|
|
- responseSchema:
|
|
|
- .array(
|
|
|
- items: .string(description: "The name of the city"),
|
|
|
- description: "A list of city names",
|
|
|
- minItems: 3,
|
|
|
- maxItems: 5
|
|
|
- )
|
|
|
+ @Test(
|
|
|
+ arguments: testConfigs(
|
|
|
+ instanceConfigs: InstanceConfig.allConfigs,
|
|
|
+ openAPISchema: .array(
|
|
|
+ items: .string(description: "The name of the city"),
|
|
|
+ description: "A list of city names",
|
|
|
+ minItems: 3,
|
|
|
+ maxItems: 5
|
|
|
),
|
|
|
+ jsonSchema: [
|
|
|
+ "type": .string("array"),
|
|
|
+ "description": .string("A list of city names"),
|
|
|
+ "items": .object([
|
|
|
+ "type": .string("string"),
|
|
|
+ "description": .string("The name of the city"),
|
|
|
+ ]),
|
|
|
+ "minItems": .number(3),
|
|
|
+ "maxItems": .number(5),
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ )
|
|
|
+ func generateContentItemsSchema(_ config: InstanceConfig, _ schema: SchemaType) async throws {
|
|
|
+ let model = FirebaseAI.componentInstance(config).generativeModel(
|
|
|
+ modelName: ModelNames.gemini2_5_FlashLite,
|
|
|
+ generationConfig: SchemaTests.generationConfig(schema: schema),
|
|
|
safetySettings: safetySettings
|
|
|
)
|
|
|
let prompt = "What are the biggest cities in Canada?"
|
|
|
@@ -73,18 +72,25 @@ struct SchemaTests {
|
|
|
#expect(decodedJSON.count <= 5, "Expected at most 5 cities, but got \(decodedJSON.count)")
|
|
|
}
|
|
|
|
|
|
- @Test(arguments: InstanceConfig.allConfigs)
|
|
|
- func generateContentSchemaNumberRange(_ config: InstanceConfig) async throws {
|
|
|
+ @Test(arguments: testConfigs(
|
|
|
+ instanceConfigs: InstanceConfig.allConfigs,
|
|
|
+ openAPISchema: .integer(
|
|
|
+ description: "A number",
|
|
|
+ minimum: 110,
|
|
|
+ maximum: 120
|
|
|
+ ),
|
|
|
+ jsonSchema: [
|
|
|
+ "type": .string("integer"),
|
|
|
+ "description": .string("A number"),
|
|
|
+ "minimum": .number(110),
|
|
|
+ "maximum": .number(120),
|
|
|
+ ]
|
|
|
+ ))
|
|
|
+ func generateContentSchemaNumberRange(_ config: InstanceConfig,
|
|
|
+ _ schema: SchemaType) async throws {
|
|
|
let model = FirebaseAI.componentInstance(config).generativeModel(
|
|
|
- modelName: ModelNames.gemini2FlashLite,
|
|
|
- generationConfig: GenerationConfig(
|
|
|
- responseMIMEType: "application/json",
|
|
|
- responseSchema: .integer(
|
|
|
- description: "A number",
|
|
|
- minimum: 110,
|
|
|
- maximum: 120
|
|
|
- )
|
|
|
- ),
|
|
|
+ modelName: ModelNames.gemini2_5_FlashLite,
|
|
|
+ generationConfig: SchemaTests.generationConfig(schema: schema),
|
|
|
safetySettings: safetySettings
|
|
|
)
|
|
|
let prompt = "Give me a number"
|
|
|
@@ -96,41 +102,83 @@ struct SchemaTests {
|
|
|
#expect(decodedNumber <= 120.0, "Expected a number <= 120, but got \(decodedNumber)")
|
|
|
}
|
|
|
|
|
|
- @Test(arguments: InstanceConfig.allConfigs)
|
|
|
- func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig) async throws {
|
|
|
+ @Test(arguments: testConfigs(
|
|
|
+ instanceConfigs: InstanceConfig.allConfigs,
|
|
|
+ openAPISchema: .object(
|
|
|
+ properties: [
|
|
|
+ "productName": .string(description: "The name of the product"),
|
|
|
+ "price": .double(
|
|
|
+ description: "A price",
|
|
|
+ minimum: 10.00,
|
|
|
+ maximum: 120.00
|
|
|
+ ),
|
|
|
+ "salePrice": .float(
|
|
|
+ description: "A sale price",
|
|
|
+ minimum: 5.00,
|
|
|
+ maximum: 90.00
|
|
|
+ ),
|
|
|
+ "rating": .integer(
|
|
|
+ description: "A rating",
|
|
|
+ minimum: 1,
|
|
|
+ maximum: 5
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+ propertyOrdering: ["salePrice", "rating", "price", "productName"],
|
|
|
+ title: "ProductInfo"
|
|
|
+ ),
|
|
|
+ jsonSchema: [
|
|
|
+ "type": .string("object"),
|
|
|
+ "title": .string("ProductInfo"),
|
|
|
+ "properties": .object([
|
|
|
+ "productName": .object([
|
|
|
+ "type": .string("string"),
|
|
|
+ "description": .string("The name of the product"),
|
|
|
+ ]),
|
|
|
+ "price": .object([
|
|
|
+ "type": .string("number"),
|
|
|
+ "description": .string("A price"),
|
|
|
+ "minimum": .number(10.00),
|
|
|
+ "maximum": .number(120.00),
|
|
|
+ ]),
|
|
|
+ "salePrice": .object([
|
|
|
+ "type": .string("number"),
|
|
|
+ "description": .string("A sale price"),
|
|
|
+ "minimum": .number(5.00),
|
|
|
+ "maximum": .number(90.00),
|
|
|
+ ]),
|
|
|
+ "rating": .object([
|
|
|
+ "type": .string("integer"),
|
|
|
+ "description": .string("A rating"),
|
|
|
+ "minimum": .number(1),
|
|
|
+ "maximum": .number(5),
|
|
|
+ ]),
|
|
|
+ ]),
|
|
|
+ "required": .array([
|
|
|
+ .string("productName"),
|
|
|
+ .string("price"),
|
|
|
+ .string("salePrice"),
|
|
|
+ .string("rating"),
|
|
|
+ ]),
|
|
|
+ "propertyOrdering": .array([
|
|
|
+ .string("salePrice"),
|
|
|
+ .string("rating"),
|
|
|
+ .string("price"),
|
|
|
+ .string("productName"),
|
|
|
+ ]),
|
|
|
+ ]
|
|
|
+ ))
|
|
|
+ func generateContentSchemaNumberRangeMultiType(_ config: InstanceConfig,
|
|
|
+ _ schema: SchemaType) async throws {
|
|
|
struct ProductInfo: Codable {
|
|
|
let productName: String
|
|
|
- let rating: Int // Will correspond to .integer in schema
|
|
|
- let price: Double // Will correspond to .double in schema
|
|
|
- let salePrice: Float // Will correspond to .float in schema
|
|
|
+ let rating: Int
|
|
|
+ let price: Double
|
|
|
+ let salePrice: Float
|
|
|
}
|
|
|
+
|
|
|
let model = FirebaseAI.componentInstance(config).generativeModel(
|
|
|
- modelName: ModelNames.gemini2FlashLite,
|
|
|
- generationConfig: GenerationConfig(
|
|
|
- responseMIMEType: "application/json",
|
|
|
- responseSchema: .object(
|
|
|
- properties: [
|
|
|
- "productName": .string(description: "The name of the product"),
|
|
|
- "price": .double(
|
|
|
- description: "A price",
|
|
|
- minimum: 10.00,
|
|
|
- maximum: 120.00
|
|
|
- ),
|
|
|
- "salePrice": .float(
|
|
|
- description: "A sale price",
|
|
|
- minimum: 5.00,
|
|
|
- maximum: 90.00
|
|
|
- ),
|
|
|
- "rating": .integer(
|
|
|
- description: "A rating",
|
|
|
- minimum: 1,
|
|
|
- maximum: 5
|
|
|
- ),
|
|
|
- ],
|
|
|
- propertyOrdering: ["salePrice", "rating", "price", "productName"],
|
|
|
- title: "ProductInfo"
|
|
|
- ),
|
|
|
- ),
|
|
|
+ modelName: ModelNames.gemini2_5_FlashLite,
|
|
|
+ generationConfig: SchemaTests.generationConfig(schema: schema),
|
|
|
safetySettings: safetySettings
|
|
|
)
|
|
|
let prompt = "Describe a premium wireless headphone, including a user rating and price."
|
|
|
@@ -149,63 +197,127 @@ struct SchemaTests {
|
|
|
#expect(rating <= 5, "Expected a rating <= 5, but got \(rating)")
|
|
|
}
|
|
|
|
|
|
- @Test(arguments: InstanceConfig.allConfigs)
|
|
|
- func generateContentAnyOfSchema(_ config: InstanceConfig) async throws {
|
|
|
- struct MailingAddress: Decodable {
|
|
|
- let streetAddress: String
|
|
|
- let city: String
|
|
|
-
|
|
|
- // Canadian-specific
|
|
|
- let province: String?
|
|
|
- let postalCode: String?
|
|
|
-
|
|
|
- // U.S.-specific
|
|
|
- let state: String?
|
|
|
- let zipCode: String?
|
|
|
-
|
|
|
- var isCanadian: Bool {
|
|
|
- return province != nil && postalCode != nil && state == nil && zipCode == nil
|
|
|
+ fileprivate struct MailingAddress {
|
|
|
+ enum PostalInfo {
|
|
|
+ struct Canada: Decodable {
|
|
|
+ let province: String
|
|
|
+ let postalCode: String
|
|
|
}
|
|
|
|
|
|
- var isAmerican: Bool {
|
|
|
- return province == nil && postalCode == nil && state != nil && zipCode != nil
|
|
|
+ struct UnitedStates: Decodable {
|
|
|
+ let state: String
|
|
|
+ let zipCode: String
|
|
|
}
|
|
|
+
|
|
|
+ case canada(province: String, postalCode: String)
|
|
|
+ case unitedStates(state: String, zipCode: String)
|
|
|
}
|
|
|
|
|
|
+ let streetAddress: String
|
|
|
+ let city: String
|
|
|
+ let postalInfo: PostalInfo
|
|
|
+ }
|
|
|
+
|
|
|
+ private static let generateContentAnyOfOpenAPISchema = {
|
|
|
let streetSchema = Schema.string(description:
|
|
|
"The civic number and street name, for example, '123 Main Street'.")
|
|
|
let citySchema = Schema.string(description: "The name of the city.")
|
|
|
- let canadianAddressSchema = Schema.object(
|
|
|
+ let canadaPostalInfoSchema = Schema.object(
|
|
|
properties: [
|
|
|
- "streetAddress": streetSchema,
|
|
|
- "city": citySchema,
|
|
|
"province": .string(description:
|
|
|
"The 2-letter province or territory code, for example, 'ON', 'QC', or 'NU'."),
|
|
|
"postalCode": .string(description: "The postal code, for example, 'A1A 1A1'."),
|
|
|
- ],
|
|
|
- description: "A Canadian mailing address"
|
|
|
+ ]
|
|
|
)
|
|
|
- let americanAddressSchema = Schema.object(
|
|
|
+ let unitedStatesPostalInfoSchema = Schema.object(
|
|
|
properties: [
|
|
|
- "streetAddress": streetSchema,
|
|
|
- "city": citySchema,
|
|
|
"state": .string(description:
|
|
|
"The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'."),
|
|
|
"zipCode": .string(description: "The 5-digit ZIP code, for example, '12345'."),
|
|
|
- ],
|
|
|
- description: "A U.S. mailing address"
|
|
|
+ ]
|
|
|
)
|
|
|
+ let mailingAddressSchema = Schema.object(properties: [
|
|
|
+ "streetAddress": streetSchema,
|
|
|
+ "city": citySchema,
|
|
|
+ "postalInfo": .anyOf(schemas: [canadaPostalInfoSchema, unitedStatesPostalInfoSchema]),
|
|
|
+ ])
|
|
|
+ return Schema.array(items: mailingAddressSchema)
|
|
|
+ }()
|
|
|
+
|
|
|
+ private static let generateContentAnyOfJSONSchema = {
|
|
|
+ let streetSchema: JSONValue = .object([
|
|
|
+ "type": .string("string"),
|
|
|
+ "description": .string("The civic number and street name, for example, '123 Main Street'."),
|
|
|
+ ])
|
|
|
+ let citySchema: JSONValue = .object([
|
|
|
+ "type": .string("string"),
|
|
|
+ "description": .string("The name of the city."),
|
|
|
+ ])
|
|
|
+ let postalInfoSchema: JSONValue = .object([
|
|
|
+ "anyOf": .array([
|
|
|
+ .object([
|
|
|
+ "type": .string("object"),
|
|
|
+ "properties": .object([
|
|
|
+ "province": .object([
|
|
|
+ "type": .string("string"),
|
|
|
+ "description": .string(
|
|
|
+ "The 2-letter Canadian province or territory code, for example, 'ON', 'QC', or 'NU'."
|
|
|
+ ),
|
|
|
+ ]),
|
|
|
+ "postalCode": .object([
|
|
|
+ "type": .string("string"),
|
|
|
+ "description": .string("The Canadian postal code, for example, 'A1A 1A1'."),
|
|
|
+ ]),
|
|
|
+ ]),
|
|
|
+ "required": .array([.string("province"), .string("postalCode")]),
|
|
|
+ ]),
|
|
|
+ .object([
|
|
|
+ "type": .string("object"),
|
|
|
+ "properties": .object([
|
|
|
+ "state": .object([
|
|
|
+ "type": .string("string"),
|
|
|
+ "description": .string(
|
|
|
+ "The 2-letter U.S. state or territory code, for example, 'CA', 'NY', or 'TX'."
|
|
|
+ ),
|
|
|
+ ]),
|
|
|
+ "zipCode": .object([
|
|
|
+ "type": .string("string"),
|
|
|
+ "description": .string("The 5-digit U.S. ZIP code, for example, '12345'."),
|
|
|
+ ]),
|
|
|
+ ]),
|
|
|
+ "required": .array([.string("state"), .string("zipCode")]),
|
|
|
+ ]),
|
|
|
+ ]),
|
|
|
+ ])
|
|
|
+ let mailingAddressSchema: JSONObject = [
|
|
|
+ "type": .string("object"),
|
|
|
+ "description": .string("A mailing address"),
|
|
|
+ "properties": .object([
|
|
|
+ "streetAddress": streetSchema,
|
|
|
+ "city": citySchema,
|
|
|
+ "postalInfo": postalInfoSchema,
|
|
|
+ ]),
|
|
|
+ "required": .array([
|
|
|
+ .string("streetAddress"),
|
|
|
+ .string("city"),
|
|
|
+ .string("postalInfo"),
|
|
|
+ ]),
|
|
|
+ ]
|
|
|
+ return [
|
|
|
+ "type": .string("array"),
|
|
|
+ "items": .object(mailingAddressSchema),
|
|
|
+ ] as JSONObject
|
|
|
+ }()
|
|
|
+
|
|
|
+ @Test(arguments: testConfigs(
|
|
|
+ instanceConfigs: InstanceConfig.allConfigs,
|
|
|
+ openAPISchema: generateContentAnyOfOpenAPISchema,
|
|
|
+ jsonSchema: generateContentAnyOfJSONSchema
|
|
|
+ ))
|
|
|
+ func generateContentAnyOfSchema(_ config: InstanceConfig, _ schema: SchemaType) async throws {
|
|
|
let model = FirebaseAI.componentInstance(config).generativeModel(
|
|
|
- modelName: ModelNames.gemini2Flash,
|
|
|
- generationConfig: GenerationConfig(
|
|
|
- temperature: 0.0,
|
|
|
- topP: 0.0,
|
|
|
- topK: 1,
|
|
|
- responseMIMEType: "application/json",
|
|
|
- responseSchema: .array(items: .anyOf(
|
|
|
- schemas: [canadianAddressSchema, americanAddressSchema]
|
|
|
- ))
|
|
|
- ),
|
|
|
+ modelName: ModelNames.gemini2_5_Flash,
|
|
|
+ generationConfig: SchemaTests.generationConfig(schema: schema),
|
|
|
safetySettings: safetySettings
|
|
|
)
|
|
|
let prompt = """
|
|
|
@@ -217,19 +329,102 @@ struct SchemaTests {
|
|
|
let decodedAddresses = try JSONDecoder().decode([MailingAddress].self, from: jsonData)
|
|
|
try #require(decodedAddresses.count == 3, "Expected 3 JSON addresses, got \(text).")
|
|
|
let waterlooAddress = decodedAddresses[0]
|
|
|
- #expect(
|
|
|
- waterlooAddress.isCanadian,
|
|
|
- "Expected Canadian University of Waterloo address, got \(waterlooAddress)."
|
|
|
- )
|
|
|
+ #expect(waterlooAddress.city == "Waterloo")
|
|
|
+ if case let .canada(province, postalCode) = waterlooAddress.postalInfo {
|
|
|
+ #expect(province == "ON")
|
|
|
+ #expect(postalCode == "N2L 3G1")
|
|
|
+ } else {
|
|
|
+ Issue.record("Expected Canadian University of Waterloo address, got \(waterlooAddress).")
|
|
|
+ }
|
|
|
let berkeleyAddress = decodedAddresses[1]
|
|
|
- #expect(
|
|
|
- berkeleyAddress.isAmerican,
|
|
|
- "Expected American UC Berkeley address, got \(berkeleyAddress)."
|
|
|
- )
|
|
|
+ #expect(berkeleyAddress.city == "Berkeley")
|
|
|
+ if case let .unitedStates(state, zipCode) = berkeleyAddress.postalInfo {
|
|
|
+ #expect(state == "CA")
|
|
|
+ #expect(zipCode == "94720")
|
|
|
+ } else {
|
|
|
+ Issue.record("Expected American UC Berkeley address, got \(berkeleyAddress).")
|
|
|
+ }
|
|
|
let queensAddress = decodedAddresses[2]
|
|
|
- #expect(
|
|
|
- queensAddress.isCanadian,
|
|
|
- "Expected Canadian Queen's University address, got \(queensAddress)."
|
|
|
+ #expect(queensAddress.city == "Kingston")
|
|
|
+ if case let .canada(province, postalCode) = queensAddress.postalInfo {
|
|
|
+ #expect(province == "ON")
|
|
|
+ #expect(postalCode == "K7L 3N6")
|
|
|
+ } else {
|
|
|
+ Issue.record("Expected Canadian Queen's University address, got \(queensAddress).")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ enum SchemaType: CustomTestStringConvertible {
|
|
|
+ case openAPI(Schema)
|
|
|
+ case json(JSONObject)
|
|
|
+
|
|
|
+ var testDescription: String {
|
|
|
+ switch self {
|
|
|
+ case .openAPI:
|
|
|
+ return "OpenAPI Schema"
|
|
|
+ case .json:
|
|
|
+ return "JSON Schema"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private static func generationConfig(schema: SchemaType) -> GenerationConfig {
|
|
|
+ let mimeType = "application/json"
|
|
|
+ switch schema {
|
|
|
+ case let .openAPI(openAPISchema):
|
|
|
+ return GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1, responseMIMEType: mimeType,
|
|
|
+ responseSchema: openAPISchema)
|
|
|
+ case let .json(jsonSchema):
|
|
|
+ return GenerationConfig(temperature: 0.0, topP: 0.0, topK: 1, responseMIMEType: mimeType,
|
|
|
+ responseJSONSchema: jsonSchema)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private static func testConfigs(instanceConfigs: [InstanceConfig], openAPISchema: Schema,
|
|
|
+ jsonSchema: JSONObject) -> [(InstanceConfig, SchemaType)] {
|
|
|
+ return instanceConfigs.flatMap { [($0, .openAPI(openAPISchema)), ($0, .json(jsonSchema))] }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+extension SchemaTests.MailingAddress: Decodable {
|
|
|
+ enum CodingKeys: CodingKey {
|
|
|
+ case streetAddress
|
|
|
+ case city
|
|
|
+ case postalInfo
|
|
|
+ }
|
|
|
+
|
|
|
+ init(from decoder: any Decoder) throws {
|
|
|
+ let container = try decoder.container(keyedBy: CodingKeys.self)
|
|
|
+ streetAddress = try container.decode(String.self, forKey: .streetAddress)
|
|
|
+ city = try container.decode(String.self, forKey: .city)
|
|
|
+ let canadaPostalInfo = try? container.decode(PostalInfo.Canada.self, forKey: .postalInfo)
|
|
|
+ let unitedStatesPostalInfo = try? container.decode(
|
|
|
+ PostalInfo.UnitedStates.self, forKey: .postalInfo
|
|
|
)
|
|
|
+
|
|
|
+ if canadaPostalInfo != nil, unitedStatesPostalInfo != nil {
|
|
|
+ throw DecodingError.dataCorruptedError(
|
|
|
+ forKey: .postalInfo,
|
|
|
+ in: container,
|
|
|
+ debugDescription: "Ambiguous postal info: matches both Canadian and U.S. formats."
|
|
|
+ )
|
|
|
+ }
|
|
|
+
|
|
|
+ if let canadaPostalInfo {
|
|
|
+ postalInfo = .canada(
|
|
|
+ province: canadaPostalInfo.province, postalCode: canadaPostalInfo.postalCode
|
|
|
+ )
|
|
|
+ } else if let unitedStatesPostalInfo {
|
|
|
+ postalInfo = .unitedStates(
|
|
|
+ state: unitedStatesPostalInfo.state, zipCode: unitedStatesPostalInfo.zipCode
|
|
|
+ )
|
|
|
+ } else {
|
|
|
+ throw DecodingError.typeMismatch(
|
|
|
+ PostalInfo.self, .init(
|
|
|
+ codingPath: container.codingPath,
|
|
|
+ debugDescription: "Expected Canadian or U.S. postal info."
|
|
|
+ )
|
|
|
+ )
|
|
|
+ }
|
|
|
}
|
|
|
}
|