浏览代码

chore(ai): Add unit tests for Live API (#15411)

Co-authored-by: Andrew Heard <andrewheard@google.com>
Daymon 5 月之前
父节点
当前提交
0e5a4e0424

+ 191 - 0
FirebaseAI/Tests/Unit/Types/Live/BidiGenerateContentServerMessageTests.swift

@@ -0,0 +1,191 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import XCTest
+
+@testable import FirebaseAILogic
+
+@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
+@available(watchOS, unavailable)
+final class BidiGenerateContentServerMessageTests: XCTestCase {
+  let decoder = JSONDecoder()
+
+  func testDecodeBidiGenerateContentServerMessage_setupComplete() throws {
+    let json = """
+    {
+      "setupComplete" : {}
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+
+    let serverMessage = try decoder.decode(BidiGenerateContentServerMessage.self, from: jsonData)
+    guard case .setupComplete = serverMessage.messageType else {
+      XCTFail("Decoded message is not a setupComplete message.")
+      return
+    }
+  }
+
+  func testDecodeBidiGenerateContentServerMessage_serverContent() throws {
+    let json = """
+    {
+      "serverContent" : {
+        "modelTurn" : {
+          "parts" : [
+            {
+              "inlineData" : {
+                "data" : "BQUFBQU=",
+                "mimeType" : "audio/pcm"
+              }
+            }
+          ],
+          "role" : "model"
+        },
+        "turnComplete": true,
+        "groundingMetadata": {
+          "webSearchQueries": ["query1", "query2"],
+          "groundingChunks": [
+            { "web": { "uri": "uri1", "title": "title1" } }
+          ],
+          "groundingSupports": [
+            { "segment": { "endIndex": 10, "text": "text" }, "groundingChunkIndices": [0] }
+          ],
+          "searchEntryPoint": { "renderedContent": "html" }
+        },
+        "inputTranscription": {
+          "text": "What day of the week is it?"
+        },
+        "outputTranscription": {
+          "text": "Today is friday"
+        }
+      }
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+
+    let serverMessage = try decoder.decode(BidiGenerateContentServerMessage.self, from: jsonData)
+    guard case let .serverContent(serverContent) = serverMessage.messageType else {
+      XCTFail("Decoded message is not a serverContent message.")
+      return
+    }
+
+    XCTAssertEqual(serverContent.turnComplete, true)
+    XCTAssertNil(serverContent.interrupted)
+    XCTAssertNil(serverContent.generationComplete)
+
+    let modelTurn = try XCTUnwrap(serverContent.modelTurn)
+    XCTAssertEqual(modelTurn.role, "model")
+    XCTAssertEqual(modelTurn.parts.count, 1)
+    let part = try XCTUnwrap(modelTurn.parts.first as? InlineDataPart)
+    XCTAssertEqual(part.data, Data(repeating: 5, count: 5))
+    XCTAssertEqual(part.mimeType, "audio/pcm")
+
+    let metadata = try XCTUnwrap(serverContent.groundingMetadata)
+    XCTAssertEqual(metadata.webSearchQueries, ["query1", "query2"])
+    XCTAssertEqual(metadata.groundingChunks.count, 1)
+    let groundingChunk = try XCTUnwrap(metadata.groundingChunks.first)
+    let webChunk = try XCTUnwrap(groundingChunk.web)
+    XCTAssertEqual(webChunk.uri, "uri1")
+    XCTAssertEqual(metadata.groundingSupports.count, 1)
+    let groundingSupport = try XCTUnwrap(metadata.groundingSupports.first)
+    XCTAssertEqual(groundingSupport.segment.startIndex, 0)
+    XCTAssertEqual(groundingSupport.segment.partIndex, 0)
+    XCTAssertEqual(groundingSupport.segment.endIndex, 10)
+    XCTAssertEqual(groundingSupport.segment.text, "text")
+    let searchEntryPoint = try XCTUnwrap(metadata.searchEntryPoint)
+    XCTAssertEqual(searchEntryPoint.renderedContent, "html")
+
+    let inputTranscription = try XCTUnwrap(serverContent.inputTranscription)
+    XCTAssertEqual(inputTranscription.text, "What day of the week is it?")
+
+    let outputTranscription = try XCTUnwrap(serverContent.outputTranscription)
+    XCTAssertEqual(outputTranscription.text, "Today is friday")
+  }
+
+  func testDecodeBidiGenerateContentServerMessage_toolCall() throws {
+    let json = """
+    {
+      "toolCall" : {
+        "functionCalls" : [
+          {
+            "name": "changeBackgroundColor",
+            "id": "functionCall-12345-67890",
+            "args" : {
+              "color": "#F54927"
+            }
+          }
+        ]
+      }
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+
+    let serverMessage = try decoder.decode(BidiGenerateContentServerMessage.self, from: jsonData)
+    guard case let .toolCall(toolCall) = serverMessage.messageType else {
+      XCTFail("Decoded message is not a toolCall message.")
+      return
+    }
+
+    let functionCalls = try XCTUnwrap(toolCall.functionCalls)
+    XCTAssertEqual(functionCalls.count, 1)
+    let functionCall = try XCTUnwrap(functionCalls.first)
+    XCTAssertEqual(functionCall.name, "changeBackgroundColor")
+    XCTAssertEqual(functionCall.id, "functionCall-12345-67890")
+    let args = try XCTUnwrap(functionCall.args)
+    guard case let .string(color) = args["color"] else {
+      XCTFail("Missing color argument")
+      return
+    }
+    XCTAssertEqual(color, "#F54927")
+  }
+
+  func testDecodeBidiGenerateContentServerMessage_toolCallCancellation() throws {
+    let json = """
+    {
+      "toolCallCancellation" : {
+        "ids" : ["functionCall-12345-67890"]
+      }
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+
+    let serverMessage = try decoder.decode(BidiGenerateContentServerMessage.self, from: jsonData)
+    guard case let .toolCallCancellation(toolCallCancellation) = serverMessage.messageType else {
+      XCTFail("Decoded message is not a toolCallCancellation message.")
+      return
+    }
+
+    let ids = try XCTUnwrap(toolCallCancellation.ids)
+    XCTAssertEqual(ids, ["functionCall-12345-67890"])
+  }
+
+  func testDecodeBidiGenerateContentServerMessage_goAway() throws {
+    let json = """
+    {
+      "goAway" : {
+        "timeLeft": "1.23456789s"
+      }
+    }
+    """
+    let jsonData = try XCTUnwrap(json.data(using: .utf8))
+
+    let serverMessage = try decoder.decode(BidiGenerateContentServerMessage.self, from: jsonData)
+    guard case let .goAway(goAway) = serverMessage.messageType else {
+      XCTFail("Decoded message is not a goAway message.")
+      return
+    }
+
+    XCTAssertEqual(goAway.timeLeft?.seconds, 1)
+    XCTAssertEqual(goAway.timeLeft?.nanos, 234_567_890)
+  }
+}

+ 62 - 0
FirebaseAI/Tests/Unit/Types/Live/VoiceConfigTests.swift

@@ -0,0 +1,62 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import XCTest
+
+@testable import FirebaseAILogic
+
+@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
+@available(watchOS, unavailable)
+final class VoiceConfigTests: XCTestCase {
+  let encoder = JSONEncoder()
+
+  override func setUp() {
+    super.setUp()
+    encoder.outputFormatting = [.prettyPrinted, .sortedKeys, .withoutEscapingSlashes]
+  }
+
+  func testEncodeVoiceConfig_prebuiltVoice() throws {
+    let voice = VoiceConfig.prebuiltVoiceConfig(
+      PrebuiltVoiceConfig(voiceName: "Zephyr")
+    )
+
+    let jsonData = try encoder.encode(voice)
+
+    let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
+    XCTAssertEqual(json, """
+    {
+      "prebuiltVoiceConfig" : {
+        "voiceName" : "Zephyr"
+      }
+    }
+    """)
+  }
+
+  func testEncodeVoiceConfig_customVoice() throws {
+    let voice = VoiceConfig.customVoiceConfig(
+      CustomVoiceConfig(customVoiceSample: Data(repeating: 5, count: 5))
+    )
+
+    let jsonData = try encoder.encode(voice)
+
+    let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
+    XCTAssertEqual(json, """
+    {
+      "customVoiceConfig" : {
+        "customVoiceSample" : "BQUFBQU="
+      }
+    }
+    """)
+  }
+}