Explorar el Código

[Firebase AI] Add support for Gemma models with Developer API (#14823)

Andrew Heard hace 11 meses
padre
commit
0be8ceca17

+ 2 - 1
FirebaseAI/Sources/FirebaseAI.swift

@@ -72,7 +72,8 @@ public final class FirebaseAI: Sendable {
                               systemInstruction: ModelContent? = nil,
                               requestOptions: RequestOptions = RequestOptions())
     -> GenerativeModel {
-    if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix) {
+    if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix)
+      && !modelName.starts(with: GenerativeModel.gemmaModelNamePrefix) {
       AILog.warning(code: .unsupportedGeminiModel, """
       Unsupported Gemini model "\(modelName)"; see \
       https://firebase.google.com/docs/vertex-ai/models for a list supported Gemini model names.

+ 10 - 7
FirebaseAI/Sources/GenerateContentResponse.swift

@@ -371,13 +371,7 @@ extension Candidate: Decodable {
         content = ModelContent(parts: [])
       }
     } catch {
-      // Check if `content` can be decoded as an empty dictionary to detect the `"content": {}` bug.
-      if let content = try? container.decode([String: String].self, forKey: .content),
-         content.isEmpty {
-        throw InvalidCandidateError.emptyContent(underlyingError: error)
-      } else {
-        throw InvalidCandidateError.malformedContent(underlyingError: error)
-      }
+      throw InvalidCandidateError.malformedContent(underlyingError: error)
     }
 
     if let safetyRatings = try container.decodeIfPresent(
@@ -395,6 +389,15 @@ extension Candidate: Decodable {
 
     finishReason = try container.decodeIfPresent(FinishReason.self, forKey: .finishReason)
 
+    // The `content` may only be empty if a `finishReason` is included; if neither are included in
+    // the response then this is likely the `"content": {}` bug.
+    guard !content.parts.isEmpty || finishReason != nil else {
+      throw InvalidCandidateError.emptyContent(underlyingError: DecodingError.dataCorrupted(.init(
+        codingPath: [CodingKeys.content, CodingKeys.finishReason],
+        debugDescription: "Invalid Candidate: empty content and no finish reason"
+      )))
+    }
+
     citationMetadata = try container.decodeIfPresent(
       CitationMetadata.self,
       forKey: .citationMetadata

+ 3 - 0
FirebaseAI/Sources/GenerativeModel.swift

@@ -23,6 +23,9 @@ public final class GenerativeModel: Sendable {
   /// Model name prefix to identify Gemini models.
   static let geminiModelNamePrefix = "gemini-"
 
+  /// Model name prefix to identify Gemma models.
+  static let gemmaModelNamePrefix = "gemma-"
+
   /// The name of the model, for example "gemini-2.0-flash".
   let modelName: String
 

+ 6 - 0
FirebaseAI/Sources/ModelContent.swift

@@ -112,6 +112,12 @@ extension ModelContent: Codable {
     case role
     case internalParts = "parts"
   }
+
+  public init(from decoder: any Decoder) throws {
+    let container = try decoder.container(keyedBy: CodingKeys.self)
+    role = try container.decodeIfPresent(String.self, forKey: .role)
+    internalParts = try container.decodeIfPresent([InternalPart].self, forKey: .internalParts) ?? []
+  }
 }
 
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)

+ 1 - 0
FirebaseAI/Tests/TestApp/Sources/Constants.swift

@@ -24,4 +24,5 @@ public enum ModelNames {
   public static let gemini2Flash = "gemini-2.0-flash-001"
   public static let gemini2FlashLite = "gemini-2.0-flash-lite-001"
   public static let gemini2FlashExperimental = "gemini-2.0-flash-exp"
+  public static let gemma3_27B = "gemma-3-27b-it"
 }

+ 62 - 30
FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift

@@ -47,12 +47,24 @@ struct GenerateContentIntegrationTests {
     storage = Storage.storage()
   }
 
-  @Test(arguments: InstanceConfig.allConfigs)
-  func generateContent(_ config: InstanceConfig) async throws {
+  @Test(arguments: [
+    (InstanceConfig.vertexAI_v1, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1beta, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1beta_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta, ModelNames.gemma3_27B),
+    (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemma3_27B),
+    (InstanceConfig.googleAI_v1_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemma3_27B),
+  ])
+  func generateContent(_ config: InstanceConfig, modelName: String) async throws {
     let model = FirebaseAI.componentInstance(config).generativeModel(
-      modelName: ModelNames.gemini2FlashLite,
+      modelName: modelName,
       generationConfig: generationConfig,
-      safetySettings: safetySettings
+      safetySettings: safetySettings,
     )
     let prompt = "Where is Google headquarters located? Answer with the city name only."
 
@@ -62,17 +74,22 @@ struct GenerateContentIntegrationTests {
     #expect(text == "Mountain View")
 
     let usageMetadata = try #require(response.usageMetadata)
-    #expect(usageMetadata.promptTokenCount == 13)
+    #expect(usageMetadata.promptTokenCount.isEqual(to: 13, accuracy: tokenCountAccuracy))
     #expect(usageMetadata.candidatesTokenCount.isEqual(to: 3, accuracy: tokenCountAccuracy))
     #expect(usageMetadata.totalTokenCount.isEqual(to: 16, accuracy: tokenCountAccuracy))
     #expect(usageMetadata.promptTokensDetails.count == 1)
     let promptTokensDetails = try #require(usageMetadata.promptTokensDetails.first)
     #expect(promptTokensDetails.modality == .text)
     #expect(promptTokensDetails.tokenCount == usageMetadata.promptTokenCount)
-    #expect(usageMetadata.candidatesTokensDetails.count == 1)
-    let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first)
-    #expect(candidatesTokensDetails.modality == .text)
-    #expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount)
+    // The field `candidatesTokensDetails` is not included when using Gemma models.
+    if modelName == ModelNames.gemma3_27B {
+      #expect(usageMetadata.candidatesTokensDetails.isEmpty)
+    } else {
+      #expect(usageMetadata.candidatesTokensDetails.count == 1)
+      let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first)
+      #expect(candidatesTokensDetails.modality == .text)
+      #expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount)
+    }
   }
 
   @Test(
@@ -168,24 +185,35 @@ struct GenerateContentIntegrationTests {
 
   // MARK: Streaming Tests
 
-  @Test(arguments: InstanceConfig.allConfigs)
-  func generateContentStream(_ config: InstanceConfig) async throws {
-    let expectedText = """
-    1. Mercury
-    2. Venus
-    3. Earth
-    4. Mars
-    5. Jupiter
-    6. Saturn
-    7. Uranus
-    8. Neptune
-    """
+  @Test(arguments: [
+    (InstanceConfig.vertexAI_v1, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1beta, ModelNames.gemini2FlashLite),
+    (InstanceConfig.vertexAI_v1beta_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta, ModelNames.gemma3_27B),
+    (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_staging, ModelNames.gemma3_27B),
+    (InstanceConfig.googleAI_v1_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2FlashLite),
+    (InstanceConfig.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemma3_27B),
+  ])
+  func generateContentStream(_ config: InstanceConfig, modelName: String) async throws {
+    let expectedResponse = [
+      "Mercury", "Venus", "Earth", "Mars", "Jupiter", "Saturn", "Uranus", "Neptune",
+    ]
     let prompt = """
-    What are the names of the planets in the solar system, ordered from closest to furthest from
-    the sun? Answer with a Markdown numbered list of the names and no other text.
+    Generate a JSON array of strings. The array must contain the names of the planets in Earth's \
+    solar system, ordered from closest to furthest from the Sun.
+
+    Constraints:
+    - Output MUST be only the JSON array.
+    - Do NOT include any introductory or explanatory text.
+    - Do NOT wrap the JSON in Markdown code blocks (e.g., ```json ... ``` or ``` ... ```).
+    - The response must start with '[' and end with ']'.
     """
     let model = FirebaseAI.componentInstance(config).generativeModel(
-      modelName: ModelNames.gemini2FlashLite,
+      modelName: modelName,
       generationConfig: generationConfig,
       safetySettings: safetySettings
     )
@@ -194,7 +222,13 @@ struct GenerateContentIntegrationTests {
     let stream = try chat.sendMessageStream(prompt)
     var textValues = [String]()
     for try await value in stream {
-      try textValues.append(#require(value.text))
+      if let text = value.text {
+        textValues.append(text)
+      } else if let finishReason = value.candidates.first?.finishReason {
+        #expect(finishReason == .stop)
+      } else {
+        Issue.record("Expected a candidate with a `TextPart` or a `finishReason`; got \(value).")
+      }
     }
 
     let userHistory = try #require(chat.history.first)
@@ -206,11 +240,9 @@ struct GenerateContentIntegrationTests {
     #expect(modelHistory.role == "model")
     #expect(modelHistory.parts.count == 1)
     let modelTextPart = try #require(modelHistory.parts.first as? TextPart)
-    let modelText = modelTextPart.text.trimmingCharacters(in: .whitespacesAndNewlines)
-    #expect(modelText == expectedText)
-    #expect(textValues.count > 1)
-    let text = textValues.joined().trimmingCharacters(in: .whitespacesAndNewlines)
-    #expect(text == expectedText)
+    let modelJSONData = try #require(modelTextPart.text.data(using: .utf8))
+    let response = try JSONDecoder().decode([String].self, from: modelJSONData)
+    #expect(response == expectedResponse)
   }
 
   // MARK: - App Check Tests

+ 12 - 6
FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift

@@ -918,6 +918,9 @@ final class GenerativeModelVertexAITests: XCTestCase {
   func testGenerateContent_failure_malformedContent() async throws {
     MockURLProtocol
       .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+        // Note: Although this file does not contain `parts` in `content`, it is not actually
+        // malformed. The `invalid-field` in the payload could be added, as a non-breaking change to
+        // the proto API. Therefore, this test checks for the `emptyContent` error instead.
         forResource: "unary-failure-malformed-content",
         withExtension: "json",
         subdirectory: vertexSubdirectory
@@ -939,13 +942,13 @@ final class GenerativeModelVertexAITests: XCTestCase {
       return
     }
     let invalidCandidateError = try XCTUnwrap(underlyingError as? InvalidCandidateError)
-    guard case let .malformedContent(malformedContentUnderlyingError) = invalidCandidateError else {
-      XCTFail("Not a malformed content error: \(invalidCandidateError)")
+    guard case let .emptyContent(emptyContentUnderlyingError) = invalidCandidateError else {
+      XCTFail("Not an empty content error: \(invalidCandidateError)")
       return
     }
     _ = try XCTUnwrap(
-      malformedContentUnderlyingError as? DecodingError,
-      "Not a decoding error: \(malformedContentUnderlyingError)"
+      emptyContentUnderlyingError as? DecodingError,
+      "Not a decoding error: \(emptyContentUnderlyingError)"
     )
   }
 
@@ -1446,6 +1449,9 @@ final class GenerativeModelVertexAITests: XCTestCase {
   func testGenerateContentStream_malformedContent() async throws {
     MockURLProtocol
       .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+        // Note: Although this file does not contain `parts` in `content`, it is not actually
+        // malformed. The `invalid-field` in the payload could be added, as a non-breaking change to
+        // the proto API. Therefore, this test checks for the `emptyContent` error instead.
         forResource: "streaming-failure-malformed-content",
         withExtension: "txt",
         subdirectory: vertexSubdirectory
@@ -1457,8 +1463,8 @@ final class GenerativeModelVertexAITests: XCTestCase {
         XCTFail("Unexpected content in stream: \(content)")
       }
     } catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) {
-      guard case let .malformedContent(contentError) = underlyingError else {
-        XCTFail("Not a malformed content error: \(underlyingError)")
+      guard case let .emptyContent(contentError) = underlyingError else {
+        XCTFail("Not an empty content error: \(underlyingError)")
         return
       }