Explorar el Código

fix(ai): Add retry mechanism to flakey interrupt test (#15421)

Co-authored-by: Andrew Heard <andrewheard@google.com>
Daymon hace 5 meses
padre
commit
0c50b1b336

+ 22 - 16
FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift

@@ -276,38 +276,39 @@ struct LiveSessionTests {
     await session.close()
   }
 
-  // Getting a limited use token adds too much of an overhead; we can't interrupt the model in time
   @Test(
     arguments: arguments.filter { !$0.0.useLimitedUseAppCheckTokens }
   )
+  // Getting a limited use token adds too much of an overhead; we can't interrupt the model in time
   func realtime_interruption(_ config: InstanceConfig, modelName: String) async throws {
     let model = FirebaseAI.componentInstance(config).liveModel(
       modelName: modelName,
       generationConfig: audioConfig
     )
 
-    let session = try await model.connect()
-
     guard let audioFile = NSDataAsset(name: "hello") else {
       Issue.record("Missing audio file 'hello.wav' in Assets")
       return
     }
-    await session.sendAudioRealtime(audioFile.data)
-    await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
 
-    // wait a second to allow the model to start generating (and cuase a proper interruption)
-    try await Task.sleep(nanoseconds: oneSecondInNanoseconds)
-    await session.sendAudioRealtime(audioFile.data)
-    await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
+    try await retry(times: 3, delayInSeconds: 2.0) {
+      let session = try await model.connect()
+      await session.sendAudioRealtime(audioFile.data)
+      await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
 
-    for try await content in session.responsesOf(LiveServerContent.self) {
-      if content.wasInterrupted {
-        break
-      }
+      // wait a second to allow the model to start generating (and cuase a proper interruption)
+      try await Task.sleep(nanoseconds: oneSecondInNanoseconds)
+      await session.sendAudioRealtime(audioFile.data)
+      await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
 
-      if content.isTurnComplete {
-        Issue.record("The model never sent an interrupted message.")
-        return
+      for try await content in session.responsesOf(LiveServerContent.self) {
+        if content.wasInterrupted {
+          break
+        }
+
+        if content.isTurnComplete {
+          throw NoInterruptionError()
+        }
       }
     }
   }
@@ -472,6 +473,11 @@ private extension LiveSession {
   }
 }
 
+private struct NoInterruptionError: Error,
+  CustomStringConvertible {
+  var description: String { "The model never sent an interrupted message." }
+}
+
 private extension ModelContent {
   /// A collection of text from all parts.
   ///

+ 33 - 0
FirebaseAI/Tests/TestApp/Tests/Utilities/IntegrationTestUtils.swift

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 import Foundation
+import Testing
 import XCTest
 
 enum IntegrationTestUtils {
@@ -43,3 +44,35 @@ extension Numeric where Self: Strideable, Self.Stride.Magnitude: Comparable {
     return distance(to: other).magnitude <= accuracy.magnitude
   }
 }
+
+/// Retry a flakey test N times before failing.
+///
+/// - Parameters:
+///   - times: The amount of attempts to retry before failing. Must be greater than 0.
+///   - delayInSeconds: How long to wait before performing the next attempt.
+@discardableResult
+func retry<T>(times: Int,
+              delayInSeconds: TimeInterval = 0.1,
+              _ test: () async throws -> T) async throws -> T {
+  if times <= 0 {
+    precondition(times <= 0, "Times must be greater than 0.")
+  }
+  let delayNanos = UInt64(delayInSeconds * 1e+9)
+  var lastError: Error?
+  for attempt in 1 ... times {
+    do { return try await test() }
+    catch {
+      lastError = error
+      // only wait if we have more attempts
+      if attempt < times {
+        try? await Task.sleep(nanoseconds: delayNanos)
+      }
+    }
+  }
+  guard let lastError else {
+    // should not happen unless we change the above code in some way
+    fatalError("Internal error: retry loop finished without error")
+  }
+  Issue.record("Flaky test failed after \(times) attempt(s): \(String(describing: lastError))")
+  throw lastError
+}