GenerativeModelVertexAITests.swift 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelVertexAITests: XCTestCase {
  /// Prompt sent by every `generateContent` call in this suite.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Sorted safety ratings expected from the `unary-success-basic-reply-short` fixture: every
  /// category is rated `.negligible` and none is blocked.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()

  /// Sorted safety ratings expected from the `unary-success-image-invalid-safety-ratings`
  /// fixture, i.e. only its valid entries (its four empty `{}` entries are expected to be ignored).
  let safetyRatingsInvalidIgnored: [SafetyRating] = [
    .init(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.00039444832,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    .init(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.0010654529,
      severity: .negligible,
      severityScore: 0.0049325973,
      blocked: false
    ),
    .init(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.00026658305,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    .init(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.0013701695,
      severity: .negligible,
      severityScore: 0.07626295,
      blocked: false
    ),
    // Ignored Invalid Safety Ratings: {},{},{},{}
  ].sorted()

  let testModelName = "test-model"
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  let apiConfig = FirebaseAI.defaultVertexAIAPIConfig
  /// Subdirectory passed to `GenerativeModelTestUtil.httpRequestHandler` to locate the Vertex AI
  /// mock responses.
  let vertexSubdirectory = "mock-responses/vertexai"

  // Both are (re)created for every test in `setUp()`.
  var urlSession: URLSession!
  var model: GenerativeModel!
  /// Builds a fresh `URLSession` whose requests are routed through `MockURLProtocol`, plus a
  /// default `GenerativeModel` bound to that session, before every test.
  override func setUp() async throws {
    try await super.setUp()
    let configuration = URLSessionConfiguration.default
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is non-failable, so there is nothing to unwrap.
    urlSession = URLSession(configuration: configuration)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
  }
  /// Clears the shared mock request handler and invalidates the per-test `URLSession` so neither
  /// outlives the test that created it.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    // `setUp()` creates a new session for every test; invalidate it instead of leaving one live
    // session behind per test. Optional chaining tolerates a `setUp()` that failed early.
    urlSession?.finishTasksAndInvalidate()
    urlSession = nil
    super.tearDown()
  }
  115. // MARK: - Generate Content
  /// The long basic-reply fixture decodes to one text candidate that finished with `.stop`.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let text = try XCTUnwrap(firstPart as? TextPart).text
    XCTAssertTrue(text.hasPrefix("1. **Use Freshly Ground Coffee**:"))

    // Convenience accessors mirror the single text part and report no calls or inline data.
    XCTAssertEqual(response.text, text)
    XCTAssertEqual(response.functionCalls, [])
    XCTAssertEqual(response.inlineDataParts, [])
  }
  /// The short basic-reply fixture decodes to "Mountain View, California" with negligible ratings.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    // Compare sorted so the fixture's rating order does not matter.
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertEqual(reply.text, "Mountain View, California")
    XCTAssertEqual(response.text, reply.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// The long usage-metadata fixture reports per-modality token details for prompt and candidates.
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-success-basic-response-long-usage-metadata",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    let finishReason = try XCTUnwrap(candidate.finishReason)
    XCTAssertEqual(finishReason, .stop)
    let usageMetadata = try XCTUnwrap(response.usageMetadata)

    // `XCTAssertEqual` records a failure but keeps running, so stop before indexing: a shorter
    // array would otherwise trap with "Index out of range" and abort the whole test run.
    let promptDetails = usageMetadata.promptTokensDetails
    XCTAssertEqual(promptDetails.count, 2)
    guard promptDetails.count == 2 else { return }
    XCTAssertEqual(promptDetails[0].modality, .image)
    XCTAssertEqual(promptDetails[0].tokenCount, 1806)
    XCTAssertEqual(promptDetails[1].modality, .text)
    XCTAssertEqual(promptDetails[1].tokenCount, 76)

    let candidatesDetails = usageMetadata.candidatesTokensDetails
    XCTAssertEqual(candidatesDetails.count, 1)
    guard candidatesDetails.count == 1 else { return }
    XCTAssertEqual(candidatesDetails[0].modality, .text)
    XCTAssertEqual(candidatesDetails[0].tokenCount, 76)
  }
  /// The citations fixture decodes to three citation sources with the expected optional fields.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-success-citations",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citations = citationMetadata.citations
    // `XCTAssertEqual` does not stop the test, and subscripting a short array traps; bail out so
    // a wrong count is reported as a test failure rather than crashing the test process.
    XCTAssertEqual(citations.count, 3)
    guard citations.count == 3 else { return }

    let citationSource1 = citations[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)
    XCTAssertNil(citationSource1.publicationDate)

    let citationSource2 = citations[1]
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.publicationDate, expectedPublicationDate)
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citations[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
    XCTAssertNil(citationSource3.publicationDate)
  }
  /// The quote-reply fixture yields a "Google…" text reply and unblocked prompt feedback.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(reply.text.hasPrefix("Google"))
    XCTAssertEqual(response.text, reply.text)

    // The prompt itself was not blocked and carries its own four ratings.
    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Unrecognized harm categories/probabilities in the fixture decode via their raw values instead
  /// of failing, for both candidate and prompt-feedback ratings.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    let expectedSafetyRatings: [SafetyRating] = [
      .init(
        category: .harassment,
        probability: .medium,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      .init(
        category: .dangerousContent,
        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      .init(
        category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
        probability: .high,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
    ]
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedSafetyRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedSafetyRatings)
  }
  /// Sends a request through a separately constructed model that shares the suite's session.
  ///
  /// NOTE(review): despite its name, this builds the model with the same unprefixed
  /// `testModelName` used in `setUp()`; confirm whether a prefixed model name was intended.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-success-basic-reply-short",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let response = try await model.generateContent(testPrompt)

    // Verify the decoded reply from the short fixture, not merely that the call did not throw.
    XCTAssertEqual(response.text, "Mountain View, California")
  }
  /// A function call whose arguments object is empty decodes with an empty `args` dictionary.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let call = try XCTUnwrap(firstPart as? FunctionCallPart, "Part is not a FunctionCall.")

    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// A function call with no arguments field at all still decodes with an empty `args`.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let call = try XCTUnwrap(firstPart as? FunctionCallPart, "Part is not a FunctionCall.")

    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// A `sum` function call decodes its numeric `x` and `y` arguments.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let call = try XCTUnwrap(firstPart as? FunctionCallPart, "Part is not a FunctionCall.")

    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let x = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(x, .number(4))
    let y = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(y, .number(5))
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// A candidate with three parallel function-call parts exposes all three via `functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// A candidate mixing text and function-call parts separates them into `functionCalls` and `text`.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(response.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// An image response containing invalid safety-rating entries keeps only the valid ratings and
  /// still exposes its PNG inline data.
  func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-success-image-invalid-safety-ratings",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)

    let images = response.inlineDataParts
    XCTAssertEqual(images.count, 1)
    let png = try XCTUnwrap(images.first)
    XCTAssertEqual(png.mimeType, "image/png")
    XCTAssertGreaterThan(png.data.count, 0)
  }
  /// Uses a fake App Check provider returning a fixed token and hands the same token to the mock
  /// request handler (presumably so it can check the outgoing request carries it).
  func testGenerateContent_appCheck_validToken() async throws {
    let token = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With a private app ID, the mock handler is configured with `dataCollection: false`
  /// (presumably asserting the request reflects data collection being off).
  func testGenerateContent_dataCollectionOff() async throws {
    let token = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token),
      privateAppID: true
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token,
      dataCollection: false
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the App Check provider fails, the request still succeeds and the mock handler expects
  /// `AppCheckInteropFake.placeholderTokenValue` as the App Check token.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(error: AppCheckErrorFake())
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Uses a fake Auth provider returning a fixed token and hands the same token to the mock
  /// request handler (presumably so it can check the outgoing request carries it).
  func testGenerateContent_auth_validAuthToken() async throws {
    let token = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(token: token)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// A fake Auth provider that yields no token still allows the request; the mock handler is
  /// configured with `authToken: nil`.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(auth: AuthInteropFake(token: nil))
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the Auth provider fails, `generateContent` throws `internalError` wrapping that error.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(error: AuthErrorFake())
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected outcome: the Auth failure is surfaced wrapped in `internalError`.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// The short basic-reply fixture reports token totals but no per-modality token details.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    let metadata = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(metadata.promptTokenCount, 6)
    XCTAssertEqual(metadata.candidatesTokenCount, 7)
    XCTAssertEqual(metadata.totalTokenCount, 13)
    XCTAssertTrue(metadata.promptTokensDetails.isEmpty)
    XCTAssertTrue(metadata.candidatesTokensDetails.isEmpty)
  }
  /// A 400 invalid-API-key response surfaces as `GenerateContentError.internalError` wrapping a
  /// `BackendError` whose code, status, and message also appear in its description and `NSError`.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      // The human-readable description should carry the message, status, and HTTP code.
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      // When bridged to `NSError`, the domain is type-specific and the code is the HTTP status.
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// A 403 "API not enabled" response surfaces as `internalError(BackendError)` that is recognized
  /// as a Vertex AI in Firebase service-disabled error.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-firebasevertexai-api-not-enabled",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(error.message
        .starts(with: "Vertex AI in Firebase API has not been used in project"))
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// A candidate with empty content surfaces as `internalError` wrapping
  /// `InvalidCandidateError.emptyContent`, whose underlying error is a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      guard case .emptyContent(let underlyingError) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(
        underlyingError as? DecodingError,
        "Not a DecodingError: \(underlyingError)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a `SAFETY` finish reason throws `responseStoppedEarly`, which carries the
  /// partial response (whose text is `"<redacted>"` in this fixture).
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that a `SAFETY` finish reason with no candidate content throws
  /// `responseStoppedEarly` whose response has no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that a 400 "invalid argument" response for a rejected image is thrown as
  /// `GenerateContentError.internalError` wrapping a `BackendError`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Use the same constant that is asserted below so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a prompt blocked for safety (without a block message) throws
  /// `promptBlocked` with no text and `.safety` prompt feedback.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertNil(promptFeedback.blockReasonMessage)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that a prompt blocked for safety with a block message throws `promptBlocked`,
  /// and that the message `"Reasons"` is surfaced in the prompt feedback.
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety-with-message",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized finish reason still throws `responseStoppedEarly`, with a
  /// reason equal to `FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")`.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized block reason still throws `promptBlocked`, with a reason
  /// equal to `PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that a 404 for an unknown model is thrown as
  /// `GenerateContentError.internalError` wrapping a `.notFound` `BackendError`.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Use the same constant that is asserted below so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(backendError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Ensures that a non-HTTP `URLResponse` makes `generateContent` throw an internal error.
  /// The underlying error must be an `NSURLErrorDomain` / `.badServerResponse` error.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()

    var generated: GenerateContentResponse?
    var caught: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      caught = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(caught)
    let contentError = try XCTUnwrap(caught as? GenerateContentError)
    guard case let .internalError(cause) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }

    XCTAssertEqual(cause.localizedDescription, "Response was not an HTTP response.")
    let bridged = cause as NSError
    XCTAssertEqual(bridged.domain, NSURLErrorDomain)
    XCTAssertEqual(bridged.code, URLError.Code.badServerResponse.rawValue)
  }
  /// Ensures that an undecodable response body is thrown as an internal error. The underlying
  /// error must be a `DecodingError.dataCorrupted` with a descriptive debug message.
  func testGenerateContent_failure_invalidResponse() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    var generated: GenerateContentResponse?
    var caught: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      caught = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(caught)
    let contentError = try XCTUnwrap(caught as? GenerateContentError)
    guard case let .internalError(cause) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }

    let decodingError = try XCTUnwrap(cause as? DecodingError)
    switch decodingError {
    case let .dataCorrupted(context):
      XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
    default:
      XCTFail("Not a data corrupted error: \(decodingError)")
    }
  }
  /// Ensures that malformed candidate content is thrown as
  /// `InvalidCandidateError.malformedContent`, and that it wraps a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    var generated: GenerateContentResponse?
    var caught: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      caught = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(caught)
    let contentError = try XCTUnwrap(caught as? GenerateContentError)
    guard case let .internalError(cause) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }

    let candidateError = try XCTUnwrap(cause as? InvalidCandidateError)
    guard case let .malformedContent(contentCause) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(contentCause as? DecodingError, "Not a decoding error: \(contentCause)")
  }
  /// Ensures that a successful response lacking prompt safety ratings still decodes. The
  /// prompt feedback must have an empty ratings list and the reply must have the expected text.
  func testGenerateContentMissingSafetyRatings() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(reply.text, "This is the generated content.")
  }
  /// Ensures that a custom `RequestOptions.timeout` reaches the outgoing request. The mock
  /// handler is given the same timeout value, and a single-candidate reply is expected.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let timeout: TimeInterval = 150.0
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: timeout
    )
    MockURLProtocol.requestHandler = handler

    let options = RequestOptions(timeout: timeout)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: options,
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)
    XCTAssertEqual(reply.candidates.count, 1)
  }
  892. // MARK: - Generate Content (Streaming)
  /// Ensures that an invalid API key aborts the stream with a `BackendError`. The error's
  /// description and its `NSError` bridge must both carry the HTTP code, status and message.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 400)
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")

      let description = backendError.localizedDescription
      XCTAssertTrue(description.contains(backendError.message))
      XCTAssertTrue(description.contains(backendError.status.rawValue))
      XCTAssertTrue(description.contains("\(backendError.httpResponseCode)"))

      let bridged = backendError as NSError
      XCTAssertEqual(bridged.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridged.code, backendError.httpResponseCode)
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Ensures that a 403 "API not enabled" response aborts the stream with a
  /// `.permissionDenied` `BackendError` that is flagged as a service-disabled error.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let statusCode = 403
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: statusCode
    )

    var sawBackendError = false
    do {
      let stream = try model.generateContentStream(testPrompt)
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, statusCode)
      XCTAssertEqual(backendError.status, .permissionDenied)
      let expectedPrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(backendError.message.starts(with: expectedPrefix))
      XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
      sawBackendError = true
    }
    if !sawBackendError {
      XCTFail("Should have caught an error.")
    }
  }
  /// Ensures that a stream whose chunk has empty content fails with an internal error that
  /// wraps an `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(underlying)
      where underlying is InvalidCandidateError {
      // The wrapped error type is the only thing this test checks.
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Ensures that a `SAFETY` finish reason mid-stream throws `responseStoppedEarly`. The
  /// candidate must report the same reason and include at least one blocked safety rating.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.responseStoppedEarly(let reason, let response) {
      XCTAssertEqual(reason, .safety)
      let firstCandidate = try XCTUnwrap(response.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, reason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains(where: \.blocked))
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Ensures that a stream whose prompt is blocked for safety (with no message) throws
  /// `promptBlocked` with `.safety` feedback.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    var blocked = false
    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.promptBlocked(let response) {
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
      blocked = true
    }
    if !blocked {
      XCTFail("Should have caught an error.")
    }
  }
  /// Ensures that a stream whose prompt is blocked for safety throws `promptBlocked`, and that
  /// the block message `"Reasons"` is surfaced.
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety-with-message",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    var blocked = false
    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.promptBlocked(let response) {
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
      blocked = true
    }
    if !blocked {
      XCTFail("Should have caught an error.")
    }
  }
  /// Ensures that an unrecognized finish reason in the stream throws `responseStoppedEarly`,
  /// with a reason equal to `FinishReason(rawValue: "FAKE_ENUM")`.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedReason = FinishReason(rawValue: "FAKE_ENUM")

    // Created outside the `do` so that a throw from stream creation is not caught below.
    let stream = try model.generateContentStream("Hi")
    var stoppedEarly = false
    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
      }
    } catch GenerateContentError.responseStoppedEarly(let reason, _) {
      XCTAssertEqual(reason, expectedReason)
      stoppedEarly = true
    }
    if !stoppedEarly {
      XCTFail("Should have caught an error.")
    }
  }
  /// Ensures that the long streaming fixture yields exactly four chunks, each with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let stream = try model.generateContentStream("Hi")
    let chunkCount = try await stream.reduce(0) { count, chunk in
      XCTAssertNotNil(chunk.text)
      return count + 1
    }
    XCTAssertEqual(chunkCount, 4)
  }
  /// Ensures that the short streaming fixture yields exactly one chunk, and that it has text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let stream = try model.generateContentStream("Hi")
    let chunkCount = try await stream.reduce(0) { count, chunk in
      XCTAssertNotNil(chunk.text)
      return count + 1
    }
    XCTAssertEqual(chunkCount, 1)
  }
  /// Ensures that safety ratings with unrecognized category/probability raw values still
  /// decode. At least one streamed candidate must contain the expected unknown rating.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    var sawExpectedRating = false
    let stream = try model.generateContentStream("Hi")
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      let ratings = chunk.candidates.first?.safetyRatings ?? []
      if ratings.contains(expectedRating) {
        sawExpectedRating = true
      }
    }
    XCTAssertTrue(sawExpectedRating)
  }
  /// Ensures that citations are collected across all streamed chunks. The last candidate must
  /// finish with `.stop`, the three expected citations must be present, and no optional
  /// citation field may be an empty string.
  func testGenerateContentStream_successWithCitations() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let expectedDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    let stream = try model.generateContentStream("Hi")
    var allCitations: [Citation] = []
    var chunks: [GenerateContentResponse] = []
    for try await chunk in stream {
      chunks.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      allCitations += firstCandidate.citationMetadata?.citations ?? []
    }

    let finalCandidate = try XCTUnwrap(chunks.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(allCitations.count, 6)

    XCTAssertTrue(allCitations.contains(where: { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil && citation.publicationDate == nil
    }))
    XCTAssertTrue(allCitations.contains(where: { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
        && citation.publicationDate == expectedDate
    }))
    XCTAssertTrue(allCitations.contains(where: { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit" && citation.publicationDate == nil
    }))

    // Present-but-empty strings would indicate a decoding bug; `nil` is acceptable.
    XCTAssertFalse(allCitations.contains(where: { $0.uri?.isEmpty == true }))
    XCTAssertFalse(allCitations.contains(where: { $0.title?.isEmpty == true }))
    XCTAssertFalse(allCitations.contains(where: { $0.license?.isEmpty == true }))
  }
  /// Ensures that invalid safety ratings in an image stream are dropped. The first chunk must
  /// still contain one candidate with a single non-empty PNG inline-data part.
  func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-image-invalid-safety-ratings",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let stream = try model.generateContentStream(testPrompt)
    var chunks: [GenerateContentResponse] = []
    for try await chunk in stream {
      chunks.append(chunk)
    }

    let firstChunk = try XCTUnwrap(chunks.first)
    XCTAssertEqual(firstChunk.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(firstChunk.candidates.first)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)

    let dataParts = firstChunk.inlineDataParts
    XCTAssertEqual(dataParts.count, 1)
    let image = try XCTUnwrap(dataParts.first)
    XCTAssertEqual(image.mimeType, "image/png")
    XCTAssertGreaterThan(image.data.count, 0)
  }
  /// Ensures that a valid App Check token from the fake provider is attached to the streaming
  /// request; the mock handler is configured to expect that token.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let token = "test-valid-token"
    let appCheck = AppCheckInteropFake(token: token)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: appCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: token
    )

    // Drain the stream; the handler performs the header check.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Ensures that when the App Check provider fails, the streaming request still proceeds
  /// with `AppCheckInteropFake.placeholderTokenValue` as the token.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream; the handler performs the header check.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Verifies that only the final streamed response carries usage metadata, and that its token
  /// counts match the fixture.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "streaming-success-basic-reply-short",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )

    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Fail explicitly on an empty stream; otherwise no assertion below would run and the test
    // would pass vacuously.
    let lastResponse = try XCTUnwrap(
      responses.last,
      "Expected at least one streamed response."
    )
    for response in responses.dropLast() {
      // Only the last streamed response contains usage metadata.
      XCTAssertNil(response.usageMetadata)
    }
    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Verifies that an error chunk arriving mid-stream ends iteration with a `.cancelled`
  /// `BackendError` (HTTP code 499), after exactly two successful chunks were delivered.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var responseCount = 0
    do {
      let stream = try model.generateContentStream("Hi")
      for try await content in stream {
        XCTAssertNotNil(content.text)
        responseCount += 1
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)
      // Check the content count is correct.
      XCTAssertEqual(responseCount, 2)
      return
    }
    // Name the type actually matched above (`BackendError`), not the stale `RPCError`.
    XCTFail("Expected an internalError with a BackendError.")
  }
  /// Ensures that a non-HTTP `URLResponse` terminates the stream with an internal error whose
  /// description is "Response was not an HTTP response.".
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()

    // Created outside the `do` so that a throw from stream creation is not caught below.
    let stream = try model.generateContentStream("Hi")
    var caughtInternalError = false
    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch GenerateContentError.internalError(let cause) {
      XCTAssertEqual(cause.localizedDescription, "Response was not an HTTP response.")
      caughtInternalError = true
    }
    if !caughtInternalError {
      XCTFail("Expected an internal error.")
    }
  }
  /// Ensures that an invalid JSON chunk terminates the stream with an internal
  /// `DecodingError.dataCorrupted` carrying a descriptive debug message.
  func testGenerateContentStream_invalidResponse() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    // Created outside the `do` so that a throw from stream creation is not caught below.
    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch GenerateContentError.internalError(let decodingError as DecodingError) {
      switch decodingError {
      case let .dataCorrupted(context):
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      default:
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
      return
    }
    XCTFail("Expected an internal error.")
  }
  /// Ensures that malformed candidate content terminates the stream with
  /// `InvalidCandidateError.malformedContent`, and that it wraps a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    // Created outside the `do` so that a throw from stream creation is not caught below.
    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch GenerateContentError.internalError(let candidateError as InvalidCandidateError) {
      if case let .malformedContent(contentCause) = candidateError {
        XCTAssert(contentCause is DecodingError)
      } else {
        XCTFail("Not a malformed content error: \(candidateError)")
      }
      return
    }
    XCTFail("Expected an internal decoding error.")
  }
  /// Ensures that a custom `RequestOptions.timeout` reaches the streaming request. The mock
  /// handler is given the same timeout value, and exactly one chunk with text is expected.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let timeout: TimeInterval = 150.0
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      timeout: timeout
    )
    MockURLProtocol.requestHandler = handler

    let options = RequestOptions(timeout: timeout)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: options,
      urlSession: urlSession
    )

    let stream = try model.generateContentStream(testPrompt)
    let chunkCount = try await stream.reduce(0) { count, chunk in
      XCTAssertNotNil(chunk.text)
      return count + 1
    }
    XCTAssertEqual(chunkCount, 1)
  }
  1343. // MARK: - Count Tokens
  /// Ensures that `countTokens` decodes the total-token fixture (6 tokens, 16 billable
  /// characters).
  func testCountTokens_succeeds() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let counted = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(counted.totalTokens, 6)
    XCTAssertEqual(counted.totalBillableCharacters, 16)
  }
  /// Verifies that `countTokens` decodes a detailed response: totals plus per-modality prompt
  /// token details (image first, then text).
  func testCountTokens_succeeds_detailed() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-detailed-token-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(response.totalTokens, 1837)
    XCTAssertEqual(response.totalBillableCharacters, 117)

    let details = response.promptTokensDetails
    XCTAssertEqual(details.count, 2)
    // XCTAssertEqual does not stop the test; bail out so a short array fails cleanly instead
    // of trapping on the out-of-range subscripts below.
    guard details.count == 2 else { return }
    XCTAssertEqual(details[0].modality, .image)
    XCTAssertEqual(details[0].tokenCount, 1806)
    XCTAssertEqual(details[1].modality, .text)
    XCTAssertEqual(details[1].tokenCount, 31)
  }
  /// Counts tokens on a model configured with every optional setting: a generation config, a
  /// single `sum` function tool, a system instruction and explicit (default) request options.
  /// Checks that the call still decodes the `unary-success-total-tokens` fixture.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let config = GenerationConfig(
      temperature: 0.5,
      topP: 0.9,
      topK: 3,
      candidateCount: 1,
      maxOutputTokens: 1024,
      stopSequences: ["test-stop"],
      responseMIMEType: "text/plain"
    )
    let calculatorTools = [
      Tool(functionDeclarations: [
        FunctionDeclaration(
          name: "sum",
          description: "Add two integers.",
          parameters: ["x": .integer(), "y": .integer()]
        ),
      ]),
    ]
    let instruction = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      generationConfig: config,
      tools: calculatorTools,
      systemInstruction: instruction,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let countResult = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(countResult.totalTokens, 6)
    XCTAssertEqual(countResult.totalBillableCharacters, 16)
  }
  /// Counts tokens for an image-only prompt (an empty JPEG payload) using the
  /// `unary-success-no-billable-characters` fixture. The decoded `totalBillableCharacters`
  /// must be `nil`, while `totalTokens` is still reported.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let emptyImagePart = InlineDataPart(data: Data(), mimeType: "image/jpeg")

    let countResult = try await model.countTokens(emptyImagePart)

    XCTAssertEqual(countResult.totalTokens, 258)
    XCTAssertNil(countResult.totalBillableCharacters)
  }
  /// Verifies that an HTTP 404 from the backend surfaces from `countTokens` as a `BackendError`
  /// with status `.notFound` and a "model not found" message.
  ///
  /// Each outcome records exactly one failure. An unexpected success fails inside the `do`. Any
  /// error that is not a `BackendError` fails in the trailing generic `catch`, and the message
  /// includes the actual error.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-model-not-found", withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: 404
    )
    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch let rpcError as BackendError {
      XCTAssertEqual(rpcError.httpResponseCode, 404)
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertTrue(
        rpcError.message.hasPrefix("models/test-model-name is not found"),
        "Unexpected error message: \(rpcError.message)"
      )
    } catch {
      XCTFail("Expected a BackendError, got \(error).")
    }
  }
  /// Counts tokens on a model built with a custom `RequestOptions.timeout`.
  ///
  /// The mock request handler receives the same timeout value (presumably so it can check the
  /// outgoing request's timeout — see `GenerativeModelTestUtil.httpRequestHandler`). The call
  /// must still decode the `unary-success-total-tokens` fixture.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    let countResult = try await model.countTokens(testPrompt)

    XCTAssertEqual(countResult.totalTokens, 6)
  }
  1457. }
  /// Test-only ordering for `SafetyRating`: ratings are ordered by the raw value of their
  /// `category` (presumably so rating arrays can be sorted before being compared).
  ///
  /// NOTE(review): only `category` takes part in the ordering. If `SafetyRating`'s `==` compares
  /// more than the category, two ratings in the same category can be unequal yet neither `<` nor
  /// `>` each other. That is weaker than the strict total order `Comparable` documents. It is
  /// fine for sorting, but do not treat "not less and not greater" as equality.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
      lhs.category.rawValue < rhs.category.rawValue
    }
  }