Bladeren bron

Make ContentModality rawValue and init internal

The  property and the initializer of the
struct were made  instead of .
labs-code-app[bot] 1 jaar geleden
bovenliggende
commit
798bd3a24b

+ 25 - 0
FirebaseVertexAI/Sources/ContentModality.swift

@@ -0,0 +1,25 @@
+import Foundation
+
+/// Represents the different content modalities that can be used in a response.
+@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
+public struct ContentModality: Codable, Hashable {
+    let rawValue: String
+
+    public static let text = ContentModality(rawValue: "TEXT")
+    public static let image = ContentModality(rawValue: "IMAGE")
+    public static let audio = ContentModality(rawValue: "AUDIO")
+    
+    public func hash(into hasher: inout Hasher) -> Int {
+        return rawValue.hashValue
+    }
+
+    init(rawValue: String) {
+        self.rawValue = rawValue
+    }
+}
+
+extension ContentModality: Equatable {
+    public static func == (lhs: ContentModality, rhs: ContentModality) -> Bool {
+        return lhs.rawValue == rhs.rawValue
+    }
+}

+ 6 - 1
FirebaseVertexAI/Sources/GenerationConfig.swift

@@ -48,6 +48,9 @@ public struct GenerationConfig {
   /// Output schema of the generated candidate text.
   let responseSchema: Schema?
 
+    /// Specifies the allowed output response modalities.
+    let responseModalities: [ContentModality]?
+
   /// Creates a new `GenerationConfig` value.
   ///
   /// See the
@@ -143,7 +146,7 @@ public struct GenerationConfig {
               candidateCount: Int? = nil, maxOutputTokens: Int? = nil,
               presencePenalty: Float? = nil, frequencyPenalty: Float? = nil,
               stopSequences: [String]? = nil, responseMIMEType: String? = nil,
-              responseSchema: Schema? = nil) {
+              responseSchema: Schema? = nil, responseModalities: [ContentModality]? = nil) {
     // Explicit init because otherwise if we re-arrange the above variables it changes the API
     // surface.
     self.temperature = temperature
@@ -156,6 +159,7 @@ public struct GenerationConfig {
     self.stopSequences = stopSequences
     self.responseMIMEType = responseMIMEType
     self.responseSchema = responseSchema
+    self.responseModalities = responseModalities
   }
 }
 
@@ -174,5 +178,6 @@ extension GenerationConfig: Encodable {
     case stopSequences
     case responseMIMEType = "responseMimeType"
     case responseSchema
+      case responseModalities = "responseModalities"
   }
 }

+ 108 - 38
FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift

@@ -41,27 +41,74 @@ final class GenerationConfigTests: XCTestCase {
     """)
   }
 
-  func testEncodeGenerationConfig_allOptions() throws {
-    let temperature: Float = 0.5
-    let topP: Float = 0.75
-    let topK = 40
-    let candidateCount = 2
-    let maxOutputTokens = 256
-    let presencePenalty: Float = 0.5
-    let frequencyPenalty: Float = 0.75
-    let stopSequences = ["END", "DONE"]
-    let responseMIMEType = "application/json"
+    func testEncodeGenerationConfig_allOptions() throws {
+        let temperature: Float = 0.5
+        let topP: Float = 0.75
+        let topK = 40
+        let candidateCount = 2
+        let maxOutputTokens = 256
+        let presencePenalty: Float = 0.5
+        let frequencyPenalty: Float = 0.75
+        let stopSequences = ["END", "DONE"]
+        let responseMIMEType = "application/json"
+        let responseModalities: [ContentModality] = [.text, .image, .audio]
+        let generationConfig = GenerationConfig(
+          temperature: temperature,
+          topP: topP,
+          topK: topK,
+          candidateCount: candidateCount,
+          maxOutputTokens: maxOutputTokens,
+          presencePenalty: presencePenalty,
+          frequencyPenalty: frequencyPenalty,
+          stopSequences: stopSequences,
+          responseMIMEType: responseMIMEType,
+          responseSchema: .array(items: .string()),
+            responseModalities: responseModalities
+        )
+
+        let jsonData = try encoder.encode(generationConfig)
+
+        let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
+        XCTAssertEqual(json, """
+        {
+          "candidateCount" : \(candidateCount),
+          "frequencyPenalty" : \(frequencyPenalty),
+          "maxOutputTokens" : \(maxOutputTokens),
+          "presencePenalty" : \(presencePenalty),
+          "responseMimeType" : "\(responseMIMEType)",
+          "responseModalities" : [
+            "TEXT",
+            "IMAGE",
+            "AUDIO"
+          ],
+          "responseSchema" : {
+            "items" : {
+              "nullable" : false,
+              "type" : "STRING"
+            },
+            "nullable" : false,
+            "type" : "ARRAY"
+          },
+          "stopSequences" : [
+            "END",
+            "DONE"
+          ],
+          "temperature" : \(temperature),
+          "topK" : \(topK),
+          "topP" : \(topP)
+        }
+        """)
+    }
+
+  func testEncodeGenerationConfig_jsonResponse() throws {
+    let mimeType = "application/json"
     let generationConfig = GenerationConfig(
-      temperature: temperature,
-      topP: topP,
-      topK: topK,
-      candidateCount: candidateCount,
-      maxOutputTokens: maxOutputTokens,
-      presencePenalty: presencePenalty,
-      frequencyPenalty: frequencyPenalty,
-      stopSequences: stopSequences,
-      responseMIMEType: responseMIMEType,
-      responseSchema: .array(items: .string())
+      responseMIMEType: mimeType,
+      responseSchema: .object(properties: [
+        "firstName": .string(),
+        "lastName": .string(),
+        "age": .integer(),
+      ])
     )
 
     let jsonData = try encoder.encode(generationConfig)
@@ -69,29 +116,52 @@ final class GenerationConfigTests: XCTestCase {
     let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
     XCTAssertEqual(json, """
     {
-      "candidateCount" : \(candidateCount),
-      "frequencyPenalty" : \(frequencyPenalty),
-      "maxOutputTokens" : \(maxOutputTokens),
-      "presencePenalty" : \(presencePenalty),
-      "responseMimeType" : "\(responseMIMEType)",
+      "responseMimeType" : "\(mimeType)",
       "responseSchema" : {
-        "items" : {
-          "nullable" : false,
-          "type" : "STRING"
-        },
         "nullable" : false,
-        "type" : "ARRAY"
-      },
-      "stopSequences" : [
-        "END",
-        "DONE"
-      ],
-      "temperature" : \(temperature),
-      "topK" : \(topK),
-      "topP" : \(topP)
+        "properties" : {
+          "age" : {
+            "nullable" : false,
+            "type" : "INTEGER"
+          },
+          "firstName" : {
+            "nullable" : false,
+            "type" : "STRING"
+          },
+          "lastName" : {
+            "nullable" : false,
+            "type" : "STRING"
+          }
+        },
+        "required" : [
+          "age",
+          "firstName",
+          "lastName"
+        ],
+        "type" : "OBJECT"
+      }
     }
     """)
   }
+    
+    func testEncodeGenerationConfig_with_responseModalities() throws {
+      let responseModalities: [ContentModality] = [.text, .image, .audio]
+      let generationConfig = GenerationConfig(responseModalities: responseModalities)
+
+      let jsonData = try encoder.encode(generationConfig)
+
+      let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
+      XCTAssertEqual(json, """
+      {
+        "responseModalities" : [
+          "TEXT",
+          "IMAGE",
+          "AUDIO"
+        ]
+      }
+      """)
+    }
+}
 
   func testEncodeGenerationConfig_jsonResponse() throws {
     let mimeType = "application/json"