| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695 |
- // Copyright 2023 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- import FirebaseAppCheckInterop
- import FirebaseAuthInterop
- import FirebaseCore
- import XCTest
- @testable import FirebaseAI
- @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
- final class GenerativeModelVertexAITests: XCTestCase {
/// Prompt sent by the tests in this class.
let testPrompt = "What sorts of questions can I ask you?"

/// Safety ratings expected for the `unary-success-basic-reply-short` mock response.
/// Stored sorted so tests can compare it against `candidate.safetyRatings.sorted()`.
let safetyRatingsNegligible: [SafetyRating] = [
  SafetyRating(
    category: .sexuallyExplicit,
    probability: .negligible,
    probabilityScore: 0.1431877,
    severity: .negligible,
    severityScore: 0.11027937,
    blocked: false
  ),
  SafetyRating(
    category: .hateSpeech,
    probability: .negligible,
    probabilityScore: 0.029035643,
    severity: .negligible,
    severityScore: 0.05613278,
    blocked: false
  ),
  SafetyRating(
    category: .harassment,
    probability: .negligible,
    probabilityScore: 0.087252244,
    severity: .negligible,
    severityScore: 0.04509957,
    blocked: false
  ),
  SafetyRating(
    category: .dangerousContent,
    probability: .negligible,
    probabilityScore: 0.2641685,
    severity: .negligible,
    severityScore: 0.082253955,
    blocked: false
  ),
].sorted()

/// Valid safety ratings expected for the `unary-success-image-invalid-safety-ratings` mock
/// response, stored sorted. That response also carries four empty (`{}`) ratings, which are
/// expected to be ignored during decoding and therefore do not appear here.
let safetyRatingsInvalidIgnored: [SafetyRating] = [
  .init(
    category: .hateSpeech,
    probability: .negligible,
    probabilityScore: 0.00039444832,
    severity: .negligible,
    severityScore: 0.0,
    blocked: false
  ),
  .init(
    category: .dangerousContent,
    probability: .negligible,
    probabilityScore: 0.0010654529,
    severity: .negligible,
    severityScore: 0.0049325973,
    blocked: false
  ),
  .init(
    category: .harassment,
    probability: .negligible,
    probabilityScore: 0.00026658305,
    severity: .negligible,
    severityScore: 0.0,
    blocked: false
  ),
  .init(
    category: .sexuallyExplicit,
    probability: .negligible,
    probabilityScore: 0.0013701695,
    severity: .negligible,
    severityScore: 0.07626295,
    blocked: false
  ),
].sorted()

/// Model identifiers passed to every `GenerativeModel` these tests construct.
let testModelName = "test-model"
let testModelResourceName =
  "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

/// Backend configuration under test: the default Vertex AI API configuration.
let apiConfig = FirebaseAI.defaultVertexAIAPIConfig

/// Subdirectory (passed to `httpRequestHandler`) that holds the Vertex AI mock JSON responses.
let vertexSubdirectory = "mock-responses/vertexai"

// Both are rebuilt in `setUp()` before every test.
var urlSession: URLSession!
var model: GenerativeModel!
/// Builds a fresh `GenerativeModel` for every test. The model is backed by a `URLSession`
/// whose requests are all routed through `MockURLProtocol`, so no real network call is made.
override func setUp() async throws {
  try await super.setUp()

  let configuration = URLSessionConfiguration.default
  // Each test installs its canned response via `MockURLProtocol.requestHandler`.
  configuration.protocolClasses = [MockURLProtocol.self]
  // `URLSession(configuration:)` is non-failable, so the result needs no unwrapping.
  urlSession = URLSession(configuration: configuration)

  model = GenerativeModel(
    modelName: testModelName,
    modelResourceName: testModelResourceName,
    firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
    apiConfig: apiConfig,
    tools: nil,
    requestOptions: RequestOptions(),
    urlSession: urlSession
  )
}
/// Clears the mock response handler so that a handler installed by one test cannot leak
/// into the next one.
override func tearDown() {
  MockURLProtocol.requestHandler = nil
  super.tearDown()
}
- // MARK: - Generate Content
func testGenerateContent_success_basicReplyLong() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-long",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  // A single candidate that finished normally and carries four safety ratings.
  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  let reason = try XCTUnwrap(onlyCandidate.finishReason)
  XCTAssertEqual(reason, .stop)
  XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)

  // Its content is one text part, which `response.text` mirrors.
  XCTAssertEqual(onlyCandidate.content.parts.count, 1)
  let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
  let textPart = try XCTUnwrap(firstPart as? TextPart)
  let replyText = textPart.text
  XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
  XCTAssertEqual(response.text, replyText)

  // A text-only reply has neither function calls nor inline data.
  XCTAssertEqual(response.functionCalls, [])
  XCTAssertEqual(response.inlineDataParts, [])
}
func testGenerateContent_success_basicReplyShort() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  let reason = try XCTUnwrap(onlyCandidate.finishReason)
  XCTAssertEqual(reason, .stop)
  // `safetyRatingsNegligible` is stored sorted, so sort before comparing.
  XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

  XCTAssertEqual(onlyCandidate.content.parts.count, 1)
  let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
  let textPart = try XCTUnwrap(firstPart as? TextPart)
  XCTAssertEqual(textPart.text, "Mountain View, California")
  XCTAssertEqual(response.text, textPart.text)
  XCTAssertEqual(response.functionCalls, [])
}
/// Verifies that the per-modality token breakdowns in `usageMetadata` are decoded.
func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-response-long-usage-metadata",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let candidate = try XCTUnwrap(response.candidates.first)
  let finishReason = try XCTUnwrap(candidate.finishReason)
  XCTAssertEqual(finishReason, .stop)

  let usageMetadata = try XCTUnwrap(response.usageMetadata)

  // Subscript only after the count is confirmed. An out-of-range index traps and takes
  // down the whole test process instead of failing just this assertion.
  let promptDetails = usageMetadata.promptTokensDetails
  XCTAssertEqual(promptDetails.count, 2)
  if promptDetails.count == 2 {
    XCTAssertEqual(promptDetails[0].modality, .image)
    XCTAssertEqual(promptDetails[0].tokenCount, 1806)
    XCTAssertEqual(promptDetails[1].modality, .text)
    XCTAssertEqual(promptDetails[1].tokenCount, 76)
  }

  let candidatesDetails = usageMetadata.candidatesTokensDetails
  XCTAssertEqual(candidatesDetails.count, 1)
  let candidateDetail = try XCTUnwrap(candidatesDetails.first)
  XCTAssertEqual(candidateDetail.modality, .text)
  XCTAssertEqual(candidateDetail.tokenCount, 76)
}
/// Verifies that citation metadata is decoded from a unary response. This covers URIs,
/// titles, licenses, publication dates and index ranges.
func testGenerateContent_success_citations() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-citations",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )
  let expectedPublicationDate = DateComponents(
    calendar: Calendar(identifier: .gregorian),
    year: 2019,
    month: 5,
    day: 10
  )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let candidate = try XCTUnwrap(response.candidates.first)
  XCTAssertEqual(candidate.content.parts.count, 1)
  XCTAssertEqual(response.text, "Some information cited from an external source")

  let citations = try XCTUnwrap(candidate.citationMetadata).citations
  XCTAssertEqual(citations.count, 3)
  // Wrapping `citations[i]` in `XCTUnwrap` does not guard the subscript; it traps before
  // `XCTUnwrap` runs. Stop here rather than crash the test process on a wrong count.
  guard citations.count == 3 else { return }

  let citationSource1 = citations[0]
  XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
  XCTAssertEqual(citationSource1.startIndex, 0)
  XCTAssertEqual(citationSource1.endIndex, 128)
  XCTAssertNil(citationSource1.title)
  XCTAssertNil(citationSource1.license)
  XCTAssertNil(citationSource1.publicationDate)

  let citationSource2 = citations[1]
  XCTAssertEqual(citationSource2.title, "some-citation-2")
  XCTAssertEqual(citationSource2.publicationDate, expectedPublicationDate)
  XCTAssertEqual(citationSource2.startIndex, 130)
  XCTAssertEqual(citationSource2.endIndex, 265)
  XCTAssertNil(citationSource2.uri)
  XCTAssertNil(citationSource2.license)

  let citationSource3 = citations[2]
  XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
  XCTAssertEqual(citationSource3.startIndex, 272)
  XCTAssertEqual(citationSource3.endIndex, 431)
  XCTAssertEqual(citationSource3.license, "mit")
  XCTAssertNil(citationSource3.title)
  XCTAssertNil(citationSource3.publicationDate)
}
func testGenerateContent_success_quoteReply() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-quote-reply",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  let reason = try XCTUnwrap(onlyCandidate.finishReason)
  XCTAssertEqual(reason, .stop)
  XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)
  XCTAssertEqual(onlyCandidate.content.parts.count, 1)

  let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
  let quotedText = try XCTUnwrap(firstPart as? TextPart).text
  XCTAssertTrue(quotedText.hasPrefix("Google"))
  XCTAssertEqual(response.text, quotedText)

  // Prompt feedback is present, reports no block reason and has four ratings of its own.
  let feedback = try XCTUnwrap(response.promptFeedback)
  XCTAssertNil(feedback.blockReason)
  XCTAssertEqual(feedback.safetyRatings.count, 4)
}
func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-unknown-enum-safety-ratings",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  // The mock response uses raw values this SDK has no named case for, such as
  // "FAKE_NEW_HARM_CATEGORY". They are expected to be preserved verbatim rather than rejected.
  let expectedSafetyRatings = [
    SafetyRating(
      category: .harassment,
      probability: .medium,
      probabilityScore: 0.0,
      severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
      probabilityScore: 0.0,
      severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    ),
    SafetyRating(
      category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
      probability: .high,
      probabilityScore: 0.0,
      severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    ),
  ]

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.text, "Some text")
  // Both the candidate and the prompt feedback carry the same ratings.
  XCTAssertEqual(response.candidates.first?.safetyRatings, expectedSafetyRatings)
  XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedSafetyRatings)
}
// NOTE(review): despite its name, this test builds its model from the same unprefixed
// `testModelName` / `testModelResourceName` that `setUp()` uses. It therefore exercises no
// model-name prefix handling and only checks that the request succeeds. Confirm the intent.
func testGenerateContent_success_prefixedModelName() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )
  let localModel = GenerativeModel(
    modelName: testModelName,
    modelResourceName: testModelResourceName,
    firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
    apiConfig: apiConfig,
    tools: nil,
    requestOptions: RequestOptions(),
    urlSession: urlSession
  )

  _ = try await localModel.generateContent(testPrompt)
}
func testGenerateContent_success_functionCall_emptyArguments() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-function-call-empty-arguments",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  XCTAssertEqual(onlyCandidate.content.parts.count, 1)
  let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
  guard let call = firstPart as? FunctionCallPart else {
    XCTFail("Part is not a FunctionCall.")
    return
  }

  // The call must decode with an empty argument map and be surfaced via `functionCalls`.
  XCTAssertEqual(call.name, "current_time")
  XCTAssertTrue(call.args.isEmpty)
  XCTAssertEqual(response.functionCalls, [call])
}
func testGenerateContent_success_functionCall_noArguments() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-function-call-no-arguments",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  XCTAssertEqual(onlyCandidate.content.parts.count, 1)
  let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
  guard let call = firstPart as? FunctionCallPart else {
    XCTFail("Part is not a FunctionCall.")
    return
  }

  // Even with no arguments in the response, `args` decodes to an empty map.
  XCTAssertEqual(call.name, "current_time")
  XCTAssertTrue(call.args.isEmpty)
  XCTAssertEqual(response.functionCalls, [call])
}
func testGenerateContent_success_functionCall_withArguments() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-function-call-with-arguments",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  XCTAssertEqual(onlyCandidate.content.parts.count, 1)
  let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
  guard let call = firstPart as? FunctionCallPart else {
    XCTFail("Part is not a FunctionCall.")
    return
  }

  // `sum(x: 4, y: 5)`: both numeric arguments must be decoded.
  XCTAssertEqual(call.name, "sum")
  XCTAssertEqual(call.args.count, 2)
  let xValue = try XCTUnwrap(call.args["x"])
  XCTAssertEqual(xValue, .number(4))
  let yValue = try XCTUnwrap(call.args["y"])
  XCTAssertEqual(yValue, .number(5))
  XCTAssertEqual(response.functionCalls, [call])
}
func testGenerateContent_success_functionCall_parallelCalls() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-function-call-parallel-calls",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  // One candidate whose three parts all surface as function calls.
  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  XCTAssertEqual(onlyCandidate.content.parts.count, 3)
  XCTAssertEqual(response.functionCalls.count, 3)
}
func testGenerateContent_success_functionCall_mixedContent() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-function-call-mixed-content",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  // Four parts: two function calls plus text, which `response.text` exposes.
  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  XCTAssertEqual(onlyCandidate.content.parts.count, 4)
  XCTAssertEqual(response.functionCalls.count, 2)
  let replyText = try XCTUnwrap(response.text)
  XCTAssertEqual(replyText, "The sum of [1, 2, 3] is")
}
func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-image-invalid-safety-ratings",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  XCTAssertEqual(onlyCandidate.content.parts.count, 1)
  // Only the valid ratings survive decoding; compare sorted against the sorted fixture.
  XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)

  // The single part is a non-empty PNG.
  let dataParts = response.inlineDataParts
  XCTAssertEqual(dataParts.count, 1)
  let pngPart = try XCTUnwrap(dataParts.first)
  XCTAssertEqual(pngPart.mimeType, "image/png")
  XCTAssertGreaterThan(pngPart.data.count, 0)
}
func testGenerateContent_appCheck_validToken() async throws {
  let token = "test-valid-token"
  model = GenerativeModel(
    modelName: testModelName,
    modelResourceName: testModelResourceName,
    firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token)
    ),
    apiConfig: apiConfig,
    tools: nil,
    requestOptions: RequestOptions(),
    urlSession: urlSession
  )
  // The handler receives the same token. Presumably it verifies that the outgoing request
  // carries it; confirm in `GenerativeModelTestUtil.httpRequestHandler`.
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory,
    appCheckToken: token
  )

  _ = try await model.generateContent(testPrompt)
}
func testGenerateContent_dataCollectionOff() async throws {
  let token = "test-valid-token"
  // `privateAppID: true` is paired with `dataCollection: false` on the handler below;
  // presumably it models an app with data collection disabled (TODO confirm).
  model = GenerativeModel(
    modelName: testModelName,
    modelResourceName: testModelResourceName,
    firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token),
      privateAppID: true
    ),
    apiConfig: apiConfig,
    tools: nil,
    requestOptions: RequestOptions(),
    urlSession: urlSession
  )
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory,
    appCheckToken: token,
    dataCollection: false
  )

  _ = try await model.generateContent(testPrompt)
}
func testGenerateContent_appCheck_tokenRefreshError() async throws {
  // The App Check fake fails to produce a token. The request must still succeed, and the
  // handler is told to expect the placeholder token value instead.
  let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
  model = GenerativeModel(
    modelName: testModelName,
    modelResourceName: testModelResourceName,
    firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: failingAppCheck),
    apiConfig: apiConfig,
    tools: nil,
    requestOptions: RequestOptions(),
    urlSession: urlSession
  )
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory,
    appCheckToken: AppCheckInteropFake.placeholderTokenValue
  )

  _ = try await model.generateContent(testPrompt)
}
func testGenerateContent_auth_validAuthToken() async throws {
  let token = "test-valid-token"
  model = GenerativeModel(
    modelName: testModelName,
    modelResourceName: testModelResourceName,
    firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(token: token)
    ),
    apiConfig: apiConfig,
    tools: nil,
    requestOptions: RequestOptions(),
    urlSession: urlSession
  )
  // The handler is told which auth token to expect on the outgoing request.
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory,
    authToken: token
  )

  _ = try await model.generateContent(testPrompt)
}
func testGenerateContent_auth_nilAuthToken() async throws {
  // With no signed-in user the fake yields no token; the request must still succeed.
  let signedOutAuth = AuthInteropFake(token: nil)
  model = GenerativeModel(
    modelName: testModelName,
    modelResourceName: testModelResourceName,
    firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(auth: signedOutAuth),
    apiConfig: apiConfig,
    tools: nil,
    requestOptions: RequestOptions(),
    urlSession: urlSession
  )
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory,
    authToken: nil
  )

  _ = try await model.generateContent(testPrompt)
}
func testGenerateContent_auth_authTokenRefreshError() async throws {
  let failingAuth = AuthInteropFake(error: AuthErrorFake())
  model = GenerativeModel(
    modelName: testModelName,
    modelResourceName: testModelResourceName,
    firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(auth: failingAuth),
    apiConfig: apiConfig,
    tools: nil,
    requestOptions: RequestOptions(),
    urlSession: urlSession
  )
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory,
    authToken: nil
  )

  do {
    _ = try await model.generateContent(testPrompt)
    XCTFail("Should throw internalError(AuthErrorFake); no error.")
  } catch GenerateContentError.internalError(_ as AuthErrorFake) {
    // Expected: the auth provider's error is surfaced as the underlying internal error.
  } catch {
    XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
  }
}
/// Verifies the aggregate token counts and confirms that this response has no
/// per-modality breakdown.
func testGenerateContent_usageMetadata() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  let usageMetadata = try XCTUnwrap(response.usageMetadata)
  XCTAssertEqual(usageMetadata.promptTokenCount, 6)
  XCTAssertEqual(usageMetadata.candidatesTokenCount, 7)
  XCTAssertEqual(usageMetadata.totalTokenCount, 13)
  // Use the dedicated boolean assertion rather than comparing `isEmpty` against `true`.
  XCTAssertTrue(usageMetadata.promptTokensDetails.isEmpty)
  XCTAssertTrue(usageMetadata.candidatesTokensDetails.isEmpty)
}
func testGenerateContent_groundingMetadata() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-google-search-grounding",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  let response = try await model.generateContent(testPrompt)

  XCTAssertEqual(response.candidates.count, 1)
  let onlyCandidate = try XCTUnwrap(response.candidates.first)
  let grounding = try XCTUnwrap(onlyCandidate.groundingMetadata)

  // Search queries and the rendered search entry point.
  XCTAssertEqual(grounding.webSearchQueries, ["current weather in London"])
  XCTAssertNotNil(grounding.searchEntryPoint)
  XCTAssertNotNil(grounding.searchEntryPoint?.renderedContent)

  // Web chunks.
  XCTAssertEqual(grounding.groundingChunks.count, 2)
  let webChunk = try XCTUnwrap(grounding.groundingChunks.first?.web)
  XCTAssertEqual(webChunk.title, "accuweather.com")
  XCTAssertNotNil(webChunk.uri)
  // NOTE(review): `domain` is expected to be nil for this Vertex AI mock response. Confirm
  // whether that reflects the fixture or a backend limitation.
  XCTAssertNil(webChunk.domain)

  // Supports tying response segments back to chunks.
  XCTAssertEqual(grounding.groundingSupports.count, 3)
  let leadingSupport = try XCTUnwrap(grounding.groundingSupports.first)
  let supportSegment = try XCTUnwrap(leadingSupport.segment)
  XCTAssertEqual(supportSegment.text, "The current weather in London, United Kingdom is cloudy.")
  XCTAssertEqual(supportSegment.startIndex, 0)
  XCTAssertEqual(supportSegment.partIndex, 0)
  XCTAssertEqual(supportSegment.endIndex, 56)
  XCTAssertEqual(leadingSupport.groundingChunkIndices, [0])
}
// Checks only that a model configured with the Google Search tool can issue a request.
// The request body is not inspected here.
func testGenerateContent_withGoogleSearchTool() async throws {
  let searchModel = GenerativeModel(
    modelName: testModelName,
    modelResourceName: testModelResourceName,
    firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
    apiConfig: apiConfig,
    tools: [.googleSearch()],
    requestOptions: RequestOptions(),
    urlSession: urlSession
  )
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  _ = try await searchModel.generateContent(testPrompt)
}
/// Verifies that an HTTP 400 "invalid API key" response surfaces as
/// `GenerateContentError.internalError` wrapping a fully populated `BackendError`.
func testGenerateContent_failure_invalidAPIKey() async throws {
  let expectedStatusCode = 400
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-failure-api-key",
    withExtension: "json",
    subdirectory: vertexSubdirectory,
    statusCode: expectedStatusCode
  )

  do {
    _ = try await model.generateContent(testPrompt)
    XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
  } catch let GenerateContentError.internalError(error as BackendError) {
    // Compare against the same constant the mock was configured with.
    XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
    XCTAssertEqual(error.status, .invalidArgument)
    XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    XCTAssertTrue(error.localizedDescription.contains(error.message))
    XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
    XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
    // The NSError bridge must expose the error domain and the HTTP status as its code.
    let nsError = error as NSError
    XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
    XCTAssertEqual(nsError.code, error.httpResponseCode)
  } catch {
    XCTFail("Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)")
  }
}
/// Verifies that an HTTP 403 "API not enabled" response surfaces as a `BackendError` that
/// `isVertexAIInFirebaseServiceDisabledError()` recognizes.
func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
  let expectedStatusCode = 403
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-failure-firebasevertexai-api-not-enabled",
    withExtension: "json",
    subdirectory: vertexSubdirectory,
    statusCode: expectedStatusCode
  )

  do {
    _ = try await model.generateContent(testPrompt)
    XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
  } catch let GenerateContentError.internalError(error as BackendError) {
    XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
    XCTAssertEqual(error.status, .permissionDenied)
    XCTAssertTrue(error.message
      .starts(with: "Vertex AI in Firebase API has not been used in project"))
    XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
  } catch {
    XCTFail("Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)")
  }
}
func testGenerateContent_failure_emptyContent() async throws {
  MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
    forResource: "unary-failure-empty-content",
    withExtension: "json",
    subdirectory: vertexSubdirectory
  )

  do {
    _ = try await model.generateContent(testPrompt)
    XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
  } catch let GenerateContentError.internalError(underlying: candidateError as InvalidCandidateError) {
    // The candidate must be rejected as `.emptyContent`, wrapping a `DecodingError`.
    guard case let .emptyContent(decodingError) = candidateError else {
      XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(decodingError as? DecodingError, "Not a DecodingError: \(decodingError)")
  } catch {
    XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
  }
}
- // A response stopped for SAFETY must throw `responseStoppedEarly` carrying the redacted text.
- func testGenerateContent_failure_finishReasonSafety() async throws {
- MockURLProtocol
- .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-finish-reason-safety",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- do {
- _ = try await model.generateContent(testPrompt)
- XCTFail("Should throw")
- } catch let GenerateContentError.responseStoppedEarly(reason, response) {
- XCTAssertEqual(reason, .safety)
- XCTAssertEqual(response.text, "<redacted>")
- } catch {
- // Include the unexpected error so the failure is diagnosable from the test log.
- XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
- }
- }
- // A SAFETY stop with no candidate content must still throw `responseStoppedEarly`,
- // and the attached response must expose no text.
- func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
- MockURLProtocol
- .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-finish-reason-safety-no-content",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- do {
- _ = try await model.generateContent(testPrompt)
- XCTFail("Should throw")
- } catch let GenerateContentError.responseStoppedEarly(reason, response) {
- XCTAssertEqual(reason, .safety)
- XCTAssertNil(response.text)
- } catch {
- // Include the unexpected error so the failure is diagnosable from the test log.
- XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
- }
- }
- // An image rejected by the backend (HTTP 400) must surface as a wrapped `BackendError`.
- func testGenerateContent_failure_imageRejected() async throws {
- let expectedStatusCode = 400
- MockURLProtocol
- .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-image-rejected",
- withExtension: "json",
- subdirectory: vertexSubdirectory,
- // Feed the mock the same constant the assertions check so the two cannot drift apart.
- statusCode: expectedStatusCode
- )
- do {
- _ = try await model.generateContent(testPrompt)
- XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
- } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
- XCTAssertEqual(backendError.status, .invalidArgument)
- XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
- XCTAssertEqual(backendError.message, "Request contains an invalid argument.")
- } catch {
- XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
- }
- }
- // A prompt blocked for SAFETY without a message must throw `promptBlocked` with
- // feedback whose `blockReasonMessage` is nil.
- func testGenerateContent_failure_promptBlockedSafety() async throws {
- MockURLProtocol
- .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-prompt-blocked-safety",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- do {
- _ = try await model.generateContent(testPrompt)
- XCTFail("Should throw")
- } catch let GenerateContentError.promptBlocked(response) {
- XCTAssertNil(response.text)
- let promptFeedback = try XCTUnwrap(response.promptFeedback)
- XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
- XCTAssertNil(promptFeedback.blockReasonMessage)
- } catch {
- // Include the unexpected error so the failure is diagnosable from the test log.
- XCTFail("Should throw a promptBlocked; error thrown: \(error)")
- }
- }
- // A prompt blocked for SAFETY with a reason message must throw `promptBlocked`
- // and expose that message on the prompt feedback.
- func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
- MockURLProtocol
- .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-prompt-blocked-safety-with-message",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- do {
- _ = try await model.generateContent(testPrompt)
- XCTFail("Should throw")
- } catch let GenerateContentError.promptBlocked(response) {
- XCTAssertNil(response.text)
- let promptFeedback = try XCTUnwrap(response.promptFeedback)
- XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
- XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
- } catch {
- // Include the unexpected error so the failure is diagnosable from the test log.
- XCTFail("Should throw a promptBlocked; error thrown: \(error)")
- }
- }
- // An unrecognized finish reason must still throw `responseStoppedEarly`, preserving the
- // raw value and the response text.
- func testGenerateContent_failure_unknownEnum_finishReason() async throws {
- MockURLProtocol
- .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-unknown-enum-finish-reason",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")
- do {
- _ = try await model.generateContent(testPrompt)
- XCTFail("Should throw")
- } catch let GenerateContentError.responseStoppedEarly(reason, response) {
- XCTAssertEqual(reason, unknownFinishReason)
- XCTAssertEqual(response.text, "Some text")
- } catch {
- // Include the unexpected error so the failure is diagnosable from the test log.
- XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
- }
- }
- // An unrecognized block reason must still throw `promptBlocked`, preserving the raw value.
- func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
- MockURLProtocol
- .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-unknown-enum-prompt-blocked",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")
- do {
- _ = try await model.generateContent(testPrompt)
- XCTFail("Should throw")
- } catch let GenerateContentError.promptBlocked(response) {
- let promptFeedback = try XCTUnwrap(response.promptFeedback)
- XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
- } catch {
- // Include the unexpected error so the failure is diagnosable from the test log.
- XCTFail("Should throw a promptBlocked; error thrown: \(error)")
- }
- }
- // Requesting an unknown model (HTTP 404) must surface as a wrapped `BackendError`.
- func testGenerateContent_failure_unknownModel() async throws {
- let expectedStatusCode = 404
- MockURLProtocol
- .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-unknown-model",
- withExtension: "json",
- subdirectory: vertexSubdirectory,
- // Feed the mock the same constant the assertions check so the two cannot drift apart.
- statusCode: expectedStatusCode
- )
- do {
- _ = try await model.generateContent(testPrompt)
- XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
- } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
- XCTAssertEqual(backendError.status, .notFound)
- XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
- XCTAssertTrue(backendError.message.hasPrefix("models/unknown is not found"))
- } catch {
- XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
- }
- }
- // A transport that returns a non-HTTP response must yield an internal error bridged to
- // `NSURLErrorDomain` / `badServerResponse`.
- func testGenerateContent_failure_nonHTTPResponse() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()
- var thrownError: Error?
- var generated: GenerateContentResponse?
- do {
- generated = try await model.generateContent(testPrompt)
- } catch {
- thrownError = error
- }
- XCTAssertNil(generated)
- XCTAssertNotNil(thrownError)
- let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
- if case let .internalError(underlying) = contentError {
- XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
- let bridged = underlying as NSError
- XCTAssertEqual(bridged.domain, NSURLErrorDomain)
- XCTAssertEqual(bridged.code, URLError.Code.badServerResponse.rawValue)
- } else {
- XCTFail("Not an internal error: \(contentError)")
- }
- }
- // An undecodable payload must yield an internal `DecodingError.dataCorrupted` whose
- // description names `GenerateContentResponse`.
- func testGenerateContent_failure_invalidResponse() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-invalid-response",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- var thrownError: Error?
- var generated: GenerateContentResponse?
- do {
- generated = try await model.generateContent(testPrompt)
- } catch {
- thrownError = error
- }
- XCTAssertNil(generated)
- XCTAssertNotNil(thrownError)
- let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
- guard case let .internalError(underlying) = contentError else {
- XCTFail("Not an internal error: \(contentError)")
- return
- }
- let decoding = try XCTUnwrap(underlying as? DecodingError)
- if case let .dataCorrupted(context) = decoding {
- XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
- } else {
- XCTFail("Not a data corrupted error: \(decoding)")
- }
- }
- func testGenerateContent_failure_malformedContent() async throws {
- // The fixture omits `parts` from `content` and carries an extra `invalid-field`. Adding such a
- // field would be a non-breaking change to the proto API, so the payload is not truly malformed;
- // the expected outcome is therefore an `emptyContent` error.
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-malformed-content",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- var thrownError: Error?
- var generated: GenerateContentResponse?
- do {
- generated = try await model.generateContent(testPrompt)
- } catch {
- thrownError = error
- }
- XCTAssertNil(generated)
- XCTAssertNotNil(thrownError)
- let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
- guard case let .internalError(underlying) = contentError else {
- XCTFail("Not an internal error: \(contentError)")
- return
- }
- let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
- if case let .emptyContent(cause) = candidateError {
- _ = try XCTUnwrap(cause as? DecodingError, "Not a decoding error: \(cause)")
- } else {
- XCTFail("Not an empty content error: \(candidateError)")
- }
- }
- // The "missing safety ratings" fixture must still decode: prompt feedback is present with
- // zero ratings and the text is intact.
- func testGenerateContentMissingSafetyRatings() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-success-missing-safety-ratings",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- let generated = try await model.generateContent(testPrompt)
- let feedback = try XCTUnwrap(generated.promptFeedback)
- XCTAssertEqual(feedback.safetyRatings.count, 0)
- XCTAssertEqual(generated.text, "This is the generated content.")
- }
- // A custom `RequestOptions.timeout` must be passed through to the outgoing request; the mock
- // handler is configured with the same expected timeout.
- func testGenerateContent_requestOptions_customTimeout() async throws {
- let expectedTimeout = 150.0
- model = GenerativeModel(
- modelName: testModelName,
- modelResourceName: testModelResourceName,
- firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
- apiConfig: apiConfig,
- tools: nil,
- requestOptions: RequestOptions(timeout: expectedTimeout),
- urlSession: urlSession
- )
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-success-basic-reply-short",
- withExtension: "json",
- subdirectory: vertexSubdirectory,
- timeout: expectedTimeout
- )
- let generated = try await model.generateContent(testPrompt)
- XCTAssertEqual(generated.candidates.count, 1)
- }
- // MARK: - Generate Content (Streaming)
- // An invalid API key must end the stream with a wrapped `BackendError`. The test also checks
- // its localized description and its `NSError` bridging (domain and code).
- func testGenerateContentStream_failureInvalidAPIKey() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-api-key",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- do {
- let responseStream = try model.generateContentStream("Hi")
- for try await _ in responseStream {
- XCTFail("No content is there, this shouldn't happen.")
- }
- } catch let GenerateContentError.internalError(backendError as BackendError) {
- let description = backendError.localizedDescription
- XCTAssertEqual(backendError.httpResponseCode, 400)
- XCTAssertEqual(backendError.status, .invalidArgument)
- XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")
- XCTAssertTrue(description.contains(backendError.message))
- XCTAssertTrue(description.contains(backendError.status.rawValue))
- XCTAssertTrue(description.contains("\(backendError.httpResponseCode)"))
- let bridged = backendError as NSError
- XCTAssertEqual(bridged.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
- XCTAssertEqual(bridged.code, backendError.httpResponseCode)
- return
- }
- XCTFail("Should have caught an error.")
- }
- // A 403 "API not enabled" response must end the stream with a wrapped `BackendError`
- // that reports the service-disabled condition.
- func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
- let expectedStatusCode = 403
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-firebasevertexai-api-not-enabled",
- withExtension: "json",
- subdirectory: vertexSubdirectory,
- statusCode: expectedStatusCode
- )
- do {
- let responseStream = try model.generateContentStream(testPrompt)
- for try await _ in responseStream {
- XCTFail("No content is there, this shouldn't happen.")
- }
- } catch let GenerateContentError.internalError(backendError as BackendError) {
- XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
- XCTAssertEqual(backendError.status, .permissionDenied)
- XCTAssertTrue(backendError.message.hasPrefix("Vertex AI in Firebase API has not been used in project"))
- XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
- return
- }
- XCTFail("Should have caught an error.")
- }
- // A streamed candidate without content must end the stream with an `InvalidCandidateError`.
- func testGenerateContentStream_failureEmptyContent() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-failure-empty-content",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- do {
- let responseStream = try model.generateContentStream("Hi")
- for try await _ in responseStream {
- XCTFail("No content is there, this shouldn't happen.")
- }
- } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
- // Only the wrapped error's type is checked here; its case is not inspected.
- return
- }
- XCTFail("Should have caught an error.")
- }
- // A stream stopped for SAFETY must throw `responseStoppedEarly`; the first candidate
- // must carry the same reason and at least one blocked rating.
- func testGenerateContentStream_failureFinishReasonSafety() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-failure-finish-reason-safety",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- do {
- let responseStream = try model.generateContentStream("Hi")
- for try await _ in responseStream {
- XCTFail("Content shouldn't be shown, this shouldn't happen.")
- }
- } catch let GenerateContentError.responseStoppedEarly(stopReason, stoppedResponse) {
- XCTAssertEqual(stopReason, .safety)
- let firstCandidate = try XCTUnwrap(stoppedResponse.candidates.first)
- XCTAssertEqual(firstCandidate.finishReason, stopReason)
- XCTAssertTrue(firstCandidate.safetyRatings.contains(where: \.blocked))
- return
- }
- XCTFail("Should have caught an error.")
- }
- // A streamed prompt blocked for SAFETY without a message must throw `promptBlocked`.
- func testGenerateContentStream_failurePromptBlockedSafety() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-failure-prompt-blocked-safety",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- do {
- let responseStream = try model.generateContentStream("Hi")
- for try await _ in responseStream {
- XCTFail("Content shouldn't be shown, this shouldn't happen.")
- }
- } catch let GenerateContentError.promptBlocked(blockedResponse) {
- let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
- XCTAssertEqual(feedback.blockReason, .safety)
- XCTAssertNil(feedback.blockReasonMessage)
- return
- }
- XCTFail("Should have caught an error.")
- }
- // A streamed prompt blocked for SAFETY with a reason message must throw `promptBlocked`
- // and expose that message.
- func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-failure-prompt-blocked-safety-with-message",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- do {
- let responseStream = try model.generateContentStream("Hi")
- for try await _ in responseStream {
- XCTFail("Content shouldn't be shown, this shouldn't happen.")
- }
- } catch let GenerateContentError.promptBlocked(blockedResponse) {
- let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
- XCTAssertEqual(feedback.blockReason, .safety)
- XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
- return
- }
- XCTFail("Should have caught an error.")
- }
- // Chunks before an unrecognized finish reason must carry text, and the stream must then
- // throw `responseStoppedEarly` with the raw unknown value.
- func testGenerateContentStream_failureUnknownFinishEnum() async throws {
- let expectedReason = FinishReason(rawValue: "FAKE_ENUM")
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-failure-unknown-finish-enum",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- let responseStream = try model.generateContentStream("Hi")
- do {
- for try await chunk in responseStream {
- XCTAssertNotNil(chunk.text)
- }
- } catch let GenerateContentError.responseStoppedEarly(stopReason, _) {
- XCTAssertEqual(stopReason, expectedReason)
- return
- }
- XCTFail("Should have caught an error.")
- }
- // The long-reply fixture must stream exactly four chunks, each with text.
- func testGenerateContentStream_successBasicReplyLong() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-success-basic-reply-long",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- let responseStream = try model.generateContentStream("Hi")
- var chunkCount = 0
- for try await chunk in responseStream {
- XCTAssertNotNil(chunk.text)
- chunkCount += 1
- }
- XCTAssertEqual(chunkCount, 4)
- }
- // The short-reply fixture must stream exactly one chunk with text.
- func testGenerateContentStream_successBasicReplyShort() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-success-basic-reply-short",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- let responseStream = try model.generateContentStream("Hi")
- var chunkCount = 0
- for try await chunk in responseStream {
- XCTAssertNotNil(chunk.text)
- chunkCount += 1
- }
- XCTAssertEqual(chunkCount, 1)
- }
- // Safety ratings with unrecognized category/probability values must decode, keep their raw
- // values, and not break streaming.
- func testGenerateContentStream_successUnknownSafetyEnum() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-success-unknown-safety-enum",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- let expectedRating = SafetyRating(
- category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
- probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
- probabilityScore: 0.0,
- severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
- severityScore: 0.0,
- blocked: false
- )
- var sawExpectedRating = false
- let responseStream = try model.generateContentStream("Hi")
- for try await chunk in responseStream {
- XCTAssertNotNil(chunk.text)
- if chunk.candidates.first?.safetyRatings.contains(expectedRating) == true {
- sawExpectedRating = true
- }
- }
- XCTAssertTrue(sawExpectedRating)
- }
- // Citations spread across streamed chunks must be collected and decoded: optional fields
- // stay nil rather than empty, and publication dates are parsed.
- func testGenerateContentStream_successWithCitations() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-success-citations",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- let expectedPublicationDate = DateComponents(
- calendar: Calendar(identifier: .gregorian),
- year: 2014,
- month: 3,
- day: 30
- )
- let responseStream = try model.generateContentStream("Hi")
- var collectedCitations: [Citation] = []
- var chunks: [GenerateContentResponse] = []
- for try await chunk in responseStream {
- chunks.append(chunk)
- XCTAssertNotNil(chunk.text)
- let firstCandidate = try XCTUnwrap(chunk.candidates.first)
- collectedCitations += firstCandidate.citationMetadata?.citations ?? []
- }
- let finalCandidate = try XCTUnwrap(chunks.last?.candidates.first)
- XCTAssertEqual(finalCandidate.finishReason, .stop)
- XCTAssertEqual(collectedCitations.count, 6)
- XCTAssertTrue(collectedCitations.contains { citation in
- citation.startIndex == 0 && citation.endIndex == 128
- && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
- && citation.license == nil && citation.publicationDate == nil
- })
- XCTAssertTrue(collectedCitations.contains { citation in
- citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
- && citation.title == "some-citation-2" && citation.license == nil
- && citation.publicationDate == expectedPublicationDate
- })
- XCTAssertTrue(collectedCitations.contains { citation in
- citation.startIndex == 272 && citation.endIndex == 431
- && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
- && citation.license == "mit" && citation.publicationDate == nil
- })
- // Optional string fields must be nil when absent, never the empty string.
- XCTAssertFalse(collectedCitations.contains { $0.uri?.isEmpty ?? false })
- XCTAssertFalse(collectedCitations.contains { $0.title?.isEmpty ?? false })
- XCTAssertFalse(collectedCitations.contains { $0.license?.isEmpty ?? false })
- }
- // Invalid safety ratings in a streamed image response must be ignored, while the image
- // part itself is still delivered.
- func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-success-image-invalid-safety-ratings",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- let responseStream = try model.generateContentStream(testPrompt)
- var chunks: [GenerateContentResponse] = []
- for try await chunk in responseStream {
- chunks.append(chunk)
- }
- let firstResponse = try XCTUnwrap(chunks.first)
- XCTAssertEqual(firstResponse.candidates.count, 1)
- let candidate = try XCTUnwrap(firstResponse.candidates.first)
- // Sort so the comparison does not depend on the order ratings arrive in.
- XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
- XCTAssertEqual(candidate.content.parts.count, 1)
- let imageParts = firstResponse.inlineDataParts
- XCTAssertEqual(imageParts.count, 1)
- let imagePart = try XCTUnwrap(imageParts.first)
- XCTAssertEqual(imagePart.mimeType, "image/png")
- XCTAssertGreaterThan(imagePart.data.count, 0)
- }
- // A model configured with a working App Check provider must send that token; the mock
- // handler is set up with the same token value.
- func testGenerateContentStream_appCheck_validToken() async throws {
- let appCheckToken = "test-valid-token"
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-success-basic-reply-short",
- withExtension: "txt",
- subdirectory: vertexSubdirectory,
- appCheckToken: appCheckToken
- )
- let appCheck = AppCheckInteropFake(token: appCheckToken)
- model = GenerativeModel(
- modelName: testModelName,
- modelResourceName: testModelResourceName,
- firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: appCheck),
- apiConfig: apiConfig,
- tools: nil,
- requestOptions: RequestOptions(),
- urlSession: urlSession
- )
- // Drain the stream so the request actually reaches the mock handler.
- let responseStream = try model.generateContentStream(testPrompt)
- for try await _ in responseStream {}
- }
- // When App Check token refresh fails, the request must still go out; the mock handler is set
- // up to expect `AppCheckInteropFake.placeholderTokenValue`.
- func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-success-basic-reply-short",
- withExtension: "txt",
- subdirectory: vertexSubdirectory,
- appCheckToken: AppCheckInteropFake.placeholderTokenValue
- )
- let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
- model = GenerativeModel(
- modelName: testModelName,
- modelResourceName: testModelResourceName,
- firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: failingAppCheck),
- apiConfig: apiConfig,
- tools: nil,
- requestOptions: RequestOptions(),
- urlSession: urlSession
- )
- // Drain the stream so the request actually reaches the mock handler.
- let responseStream = try model.generateContentStream(testPrompt)
- for try await _ in responseStream {}
- }
- // Only the final streamed chunk may carry usage metadata, and its token counts must match
- // the fixture.
- func testGenerateContentStream_usageMetadata() async throws {
- MockURLProtocol
- .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-success-basic-reply-short",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- var responses = [GenerateContentResponse]()
- let stream = try model.generateContentStream(testPrompt)
- for try await response in stream {
- responses.append(response)
- }
- // Unwrapping `last` fails the test when the stream produced nothing; an index-based loop
- // over an empty array would have passed without checking anything.
- let lastResponse = try XCTUnwrap(responses.last)
- for response in responses.dropLast() {
- // Only the last streamed response contains usage metadata.
- XCTAssertNil(response.usageMetadata)
- }
- let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
- XCTAssertEqual(usageMetadata.promptTokenCount, 6)
- XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
- XCTAssertEqual(usageMetadata.totalTokenCount, 10)
- }
- // A backend error arriving after some chunks must surface as a wrapped `BackendError`,
- // after the earlier chunks have been delivered.
- func testGenerateContentStream_errorMidStream() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-failure-error-mid-stream",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- var responseCount = 0
- do {
- let stream = try model.generateContentStream("Hi")
- for try await content in stream {
- XCTAssertNotNil(content.text)
- responseCount += 1
- }
- } catch let GenerateContentError.internalError(backendError as BackendError) {
- XCTAssertEqual(backendError.httpResponseCode, 499)
- XCTAssertEqual(backendError.status, .cancelled)
- // The chunks that preceded the error must all have been yielded.
- XCTAssertEqual(responseCount, 2)
- return
- }
- // The wrapped type is `BackendError`; the old message referred to the former `RPCError` name.
- XCTFail("Expected an internalError with a BackendError.")
- }
- // A non-HTTP transport response must end the stream with an internal error.
- func testGenerateContentStream_nonHTTPResponse() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()
- let responseStream = try model.generateContentStream("Hi")
- do {
- for try await unexpectedChunk in responseStream {
- XCTFail("Unexpected content in stream: \(unexpectedChunk)")
- }
- } catch let GenerateContentError.internalError(underlyingError) {
- XCTAssertEqual(underlyingError.localizedDescription, "Response was not an HTTP response.")
- return
- }
- XCTFail("Expected an internal error.")
- }
- // Invalid JSON in the stream must end it with an internal `DecodingError.dataCorrupted`.
- func testGenerateContentStream_invalidResponse() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-failure-invalid-json",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- let responseStream = try model.generateContentStream(testPrompt)
- do {
- for try await unexpectedChunk in responseStream {
- XCTFail("Unexpected content in stream: \(unexpectedChunk)")
- }
- } catch let GenerateContentError.internalError(decoding as DecodingError) {
- if case let .dataCorrupted(context) = decoding {
- XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
- } else {
- XCTFail("Not a data corrupted error: \(decoding)")
- }
- return
- }
- XCTFail("Expected an internal error.")
- }
- func testGenerateContentStream_malformedContent() async throws {
- // The fixture omits `parts` from `content` and carries an extra `invalid-field`. Adding such a
- // field would be a non-breaking change to the proto API, so the payload is not truly malformed;
- // the expected outcome is therefore an `emptyContent` error.
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-failure-malformed-content",
- withExtension: "txt",
- subdirectory: vertexSubdirectory
- )
- let responseStream = try model.generateContentStream(testPrompt)
- do {
- for try await unexpectedChunk in responseStream {
- XCTFail("Unexpected content in stream: \(unexpectedChunk)")
- }
- } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
- if case let .emptyContent(cause) = candidateError {
- XCTAssert(cause is DecodingError)
- } else {
- XCTFail("Not an empty content error: \(candidateError)")
- }
- return
- }
- XCTFail("Expected an internal decoding error.")
- }
- // A custom `RequestOptions.timeout` must be passed through to streaming requests; the mock
- // handler is configured with the same expected timeout.
- func testGenerateContentStream_requestOptions_customTimeout() async throws {
- let expectedTimeout = 150.0
- model = GenerativeModel(
- modelName: testModelName,
- modelResourceName: testModelResourceName,
- firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
- apiConfig: apiConfig,
- tools: nil,
- requestOptions: RequestOptions(timeout: expectedTimeout),
- urlSession: urlSession
- )
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "streaming-success-basic-reply-short",
- withExtension: "txt",
- subdirectory: vertexSubdirectory,
- timeout: expectedTimeout
- )
- let responseStream = try model.generateContentStream(testPrompt)
- var chunkCount = 0
- for try await chunk in responseStream {
- XCTAssertNotNil(chunk.text)
- chunkCount += 1
- }
- XCTAssertEqual(chunkCount, 1)
- }
- // MARK: - Count Tokens
- // A basic `countTokens` call must decode the fixture's total token count.
- func testCountTokens_succeeds() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-success-total-tokens",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- let tokenCount = try await model.countTokens("Why is the sky blue?")
- XCTAssertEqual(tokenCount.totalTokens, 6)
- }
- // A detailed `countTokens` response must decode the per-modality token breakdown.
- func testCountTokens_succeeds_detailed() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-success-detailed-token-response",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- let response = try await model.countTokens("Why is the sky blue?")
- XCTAssertEqual(response.totalTokens, 1837)
- let details = response.promptTokensDetails
- XCTAssertEqual(details.count, 2)
- // Subscripting past the end traps and aborts the entire test run. If the count assertion
- // above failed, stop here so only that failure is reported.
- guard details.count == 2 else { return }
- XCTAssertEqual(details[0].modality, .image)
- XCTAssertEqual(details[0].tokenCount, 1806)
- XCTAssertEqual(details[1].modality, .text)
- XCTAssertEqual(details[1].tokenCount, 31)
- }
- // `countTokens` must succeed on a model configured with a generation config, tools, and a
- // system instruction.
- func testCountTokens_succeeds_allOptions() async throws {
- let addFunction = FunctionDeclaration(
- name: "sum",
- description: "Add two integers.",
- parameters: ["x": .integer(), "y": .integer()]
- )
- let calculatorInstruction = ModelContent(
- role: "system",
- parts: "You are a calculator. Use the provided tools."
- )
- let config = GenerationConfig(
- temperature: 0.5,
- topP: 0.9,
- topK: 3,
- candidateCount: 1,
- maxOutputTokens: 1024,
- stopSequences: ["test-stop"],
- responseMIMEType: "text/plain"
- )
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-success-total-tokens",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- model = GenerativeModel(
- modelName: testModelName,
- modelResourceName: testModelResourceName,
- firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
- apiConfig: apiConfig,
- generationConfig: config,
- tools: [Tool(functionDeclarations: [addFunction])],
- systemInstruction: calculatorInstruction,
- requestOptions: RequestOptions(),
- urlSession: urlSession
- )
- let tokenCount = try await model.countTokens("Why is the sky blue?")
- XCTAssertEqual(tokenCount.totalTokens, 6)
- }
- // Counting tokens for an image-only input must decode the fixture's total token count.
- func testCountTokens_succeeds_noBillableCharacters() async throws {
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-success-no-billable-characters",
- withExtension: "json",
- subdirectory: vertexSubdirectory
- )
- let emptyImage = InlineDataPart(data: Data(), mimeType: "image/jpeg")
- let tokenCount = try await model.countTokens(emptyImage)
- XCTAssertEqual(tokenCount.totalTokens, 258)
- }
- // `countTokens` against a missing model (HTTP 404) must throw a `BackendError` directly.
- func testCountTokens_modelNotFound() async throws {
- let expectedStatusCode = 404
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-failure-model-not-found",
- withExtension: "json",
- subdirectory: vertexSubdirectory,
- statusCode: expectedStatusCode
- )
- do {
- _ = try await model.countTokens("Why is the sky blue?")
- // A single failure on success; previously control fell through to a second XCTFail.
- XCTFail("Request should not have succeeded.")
- } catch let backendError as BackendError {
- XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
- XCTAssertEqual(backendError.status, .notFound)
- XCTAssert(backendError.message.hasPrefix("models/test-model-name is not found"))
- } catch {
- XCTFail("Expected BackendError; error thrown: \(error)")
- }
- }
- // A custom `RequestOptions.timeout` must be passed through to `countTokens` requests; the
- // mock handler is configured with the same expected timeout.
- func testCountTokens_requestOptions_customTimeout() async throws {
- let expectedTimeout = 150.0
- model = GenerativeModel(
- modelName: testModelName,
- modelResourceName: testModelResourceName,
- firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
- apiConfig: apiConfig,
- tools: nil,
- requestOptions: RequestOptions(timeout: expectedTimeout),
- urlSession: urlSession
- )
- MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
- forResource: "unary-success-total-tokens",
- withExtension: "json",
- subdirectory: vertexSubdirectory,
- timeout: expectedTimeout
- )
- let tokenCount = try await model.countTokens(testPrompt)
- XCTAssertEqual(tokenCount.totalTokens, 6)
- }
- }
- @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
- extension SafetyRating: Swift.Comparable {
- /// Orders ratings by the raw value of their harm category only, so tests can sort
- /// ratings before comparing them.
- ///
- /// Ratings that share a category compare as neither less nor greater, even when their
- /// other fields differ.
- public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
- let (leftCategory, rightCategory) = (lhs.category.rawValue, rhs.category.rawValue)
- return leftCategory < rightCategory
- }
- }
|