GenerativeModelVertexAITests.swift 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelVertexAITests: XCTestCase {
  /// Prompt passed to `generateContent` by the tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Four non-blocked `.negligible` ratings, pre-sorted so they can be compared directly with
  /// `candidate.safetyRatings.sorted()`.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()

  /// Pre-sorted ratings expected from the `unary-success-image-invalid-safety-ratings` response.
  let safetyRatingsInvalidIgnored: [SafetyRating] = [
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.00039444832,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.0010654529,
      severity: .negligible,
      severityScore: 0.0049325973,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.00026658305,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.0013701695,
      severity: .negligible,
      severityScore: 0.07626295,
      blocked: false
    ),
    // Ignored Invalid Safety Ratings: {},{},{},{}
  ].sorted()

  /// Short model name handed to `GenerativeModel`.
  let testModelName = "test-model"

  /// Fully qualified resource name handed to `GenerativeModel`.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  /// API configuration shared by every model built in this class (the Vertex AI default).
  let apiConfig = FirebaseAI.defaultVertexAIAPIConfig

  /// Subdirectory passed to `GenerativeModelTestUtil.httpRequestHandler` when loading fixtures.
  let vertexSubdirectory = "mock-responses/vertexai"

  /// Session assigned in `setUp`; its configuration lists `MockURLProtocol` as a protocol class.
  var urlSession: URLSession!

  /// Model under test, assigned in `setUp` (some tests replace it with a custom one).
  var model: GenerativeModel!
  /// Creates the per-test `URLSession` (with `MockURLProtocol` registered as a protocol class)
  /// and a default `GenerativeModel` that uses it.
  override func setUp() async throws {
    let sessionConfiguration = URLSessionConfiguration.default
    sessionConfiguration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is not failable, so wrapping it in `XCTUnwrap` added nothing.
    urlSession = URLSession(configuration: sessionConfiguration)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
  }
  /// Clears the shared mock request handler and releases the objects created in `setUp`.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    // Invalidate and drop the per-test session and model so they do not outlive the test that
    // created them; `setUp` builds fresh ones for the next test.
    urlSession?.finishTasksAndInvalidate()
    urlSession = nil
    model = nil
    super.tearDown()
  }
  115. // MARK: - Generate Content
  /// Loads `unary-success-basic-reply-long` and checks the single candidate's finish reason,
  /// safety-rating count and text part, plus the response-level convenience accessors.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let replyText = try XCTUnwrap(firstPart as? TextPart).text
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))

    // Convenience accessors: the text mirrors the part; no function calls or inline data.
    XCTAssertEqual(response.text, replyText)
    XCTAssertEqual(response.functionCalls, [])
    XCTAssertEqual(response.inlineDataParts, [])
  }
  /// Loads `unary-success-basic-reply-short` and checks the candidate's finish reason, its
  /// sorted safety ratings against `safetyRatingsNegligible`, and the exact reply text.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertEqual(reply.text, "Mountain View, California")
    XCTAssertEqual(response.text, reply.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Loads `unary-success-basic-response-long-usage-metadata` and checks the per-modality token
  /// details reported for the prompt and for the candidates.
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-response-long-usage-metadata",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)

    let usage = try XCTUnwrap(response.usageMetadata)

    // Prompt side: an image entry followed by a text entry.
    let promptDetails = usage.promptTokensDetails
    XCTAssertEqual(promptDetails.count, 2)
    XCTAssertEqual(promptDetails[0].modality, .image)
    XCTAssertEqual(promptDetails[0].tokenCount, 1806)
    XCTAssertEqual(promptDetails[1].modality, .text)
    XCTAssertEqual(promptDetails[1].tokenCount, 76)

    // Candidate side: a single text entry.
    let candidateDetails = usage.candidatesTokensDetails
    XCTAssertEqual(candidateDetails.count, 1)
    XCTAssertEqual(candidateDetails[0].modality, .text)
    XCTAssertEqual(candidateDetails[0].tokenCount, 76)
  }
  /// Loads `unary-success-citations` and checks each of the three citations field by field,
  /// including which optional fields are absent.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")

    let metadata = try XCTUnwrap(onlyCandidate.citationMetadata)
    XCTAssertEqual(metadata.citations.count, 3)

    // First citation: URI and range only.
    let first = try XCTUnwrap(metadata.citations[0])
    XCTAssertEqual(first.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(first.startIndex, 0)
    XCTAssertEqual(first.endIndex, 128)
    XCTAssertNil(first.title)
    XCTAssertNil(first.license)
    XCTAssertNil(first.publicationDate)

    // Second citation: title, publication date and range; no URI or license.
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )
    let second = try XCTUnwrap(metadata.citations[1])
    XCTAssertEqual(second.title, "some-citation-2")
    XCTAssertEqual(second.publicationDate, expectedPublicationDate)
    XCTAssertEqual(second.startIndex, 130)
    XCTAssertEqual(second.endIndex, 265)
    XCTAssertNil(second.uri)
    XCTAssertNil(second.license)

    // Third citation: URI, range and license; no title or publication date.
    let third = try XCTUnwrap(metadata.citations[2])
    XCTAssertEqual(third.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(third.startIndex, 272)
    XCTAssertEqual(third.endIndex, 431)
    XCTAssertEqual(third.license, "mit")
    XCTAssertNil(third.title)
    XCTAssertNil(third.publicationDate)
  }
  /// Loads `unary-success-quote-reply` and checks the candidate's text part together with the
  /// prompt feedback (no block reason, four safety ratings).
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(reply.text.hasPrefix("Google"))
    XCTAssertEqual(response.text, reply.text)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Loads `unary-success-unknown-enum-safety-ratings` and checks that both the candidate and
  /// the prompt feedback carry ratings equal to ones built from the raw enum strings below.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let expectedRatings: [SafetyRating] = [
      SafetyRating(
        category: .harassment,
        probability: .medium,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      // Probability built from a raw value rather than a named case.
      SafetyRating(
        category: .dangerousContent,
        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      // Category built from a raw value rather than a named case.
      SafetyRating(
        category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
        probability: .high,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
    ]

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedRatings)
  }
  /// Builds a local model and checks that `generateContent` completes without throwing against
  /// the `unary-success-basic-reply-short` response.
  ///
  /// NOTE(review): despite the test name, the local model is built from the same unprefixed
  /// `testModelName` / `testModelResourceName` that `setUp` uses. Confirm whether a prefixed
  /// name was intended here.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo()
    let localModel = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    // Only successful completion matters; the response content is not inspected.
    _ = try await localModel.generateContent(testPrompt)
  }
  /// Loads `unary-success-function-call-empty-arguments` and checks that the single part is a
  /// `current_time` function call with empty arguments.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Downcast via XCTUnwrap, the same way the text-part tests in this file do, instead of a
    // hand-written guard/XCTFail/return.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// Loads `unary-success-function-call-no-arguments` and checks that the single part is a
  /// `current_time` function call whose arguments are empty.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Downcast via XCTUnwrap, the same way the text-part tests in this file do, instead of a
    // hand-written guard/XCTFail/return.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// Loads `unary-success-function-call-with-arguments` and checks that the single part is a
  /// `sum` function call with numeric arguments `x == 4` and `y == 5`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Downcast via XCTUnwrap, the same way the text-part tests in this file do, instead of a
    // hand-written guard/XCTFail/return.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "sum")
    XCTAssertEqual(functionCall.args.count, 2)
    let argX = try XCTUnwrap(functionCall.args["x"])
    XCTAssertEqual(argX, .number(4))
    let argY = try XCTUnwrap(functionCall.args["y"])
    XCTAssertEqual(argY, .number(5))
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// Loads `unary-success-function-call-parallel-calls` and checks that the candidate has three
  /// parts and the response exposes three function calls.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// Loads `unary-success-function-call-mixed-content` and checks that the candidate has four
  /// parts, of which two surface as function calls, alongside the expected response text.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)

    let responseText = try XCTUnwrap(response.text)
    XCTAssertEqual(responseText, "The sum of [1, 2, 3] is")
  }
  /// Loads `unary-success-thinking-reply-thought-summary` and checks that the first part is a
  /// thought (surfaced via `thoughtSummary`) and the last part is the non-thought answer text.
  func testGenerateContent_success_thinking_thoughtSummary() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-thinking-reply-thought-summary",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.finishReason, .stop)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 2)

    // Leading part: the thought summary.
    let thought = try XCTUnwrap(parts.first as? TextPart)
    XCTAssertTrue(thought.isThought)
    XCTAssertTrue(thought.text.hasPrefix("Right, someone needs the city where Google"))
    XCTAssertEqual(response.thoughtSummary, thought.text)

    // Trailing part: the answer itself.
    let answer = try XCTUnwrap(parts.last as? TextPart)
    XCTAssertFalse(answer.isThought)
    XCTAssertEqual(answer.text, "Mountain View")
    XCTAssertEqual(response.text, answer.text)
  }
  /// Loads `unary-success-image-invalid-safety-ratings` and checks that the candidate's sorted
  /// ratings equal `safetyRatingsInvalidIgnored` and that one non-empty PNG part is returned.
  func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-image-invalid-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)

    let images = response.inlineDataParts
    XCTAssertEqual(images.count, 1)
    let image = try XCTUnwrap(images.first)
    XCTAssertEqual(image.mimeType, "image/png")
    XCTAssertGreaterThan(image.data.count, 0)
  }
  /// Loads `unary-success-empty-part` and checks that the candidate has two parts: one
  /// non-empty PNG inline-data part plus text with the expected opening.
  func testGenerateContent_success_image_emptyPartIgnored() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-empty-part",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 2)

    let images = response.inlineDataParts
    XCTAssertEqual(images.count, 1)
    let image = try XCTUnwrap(images.first)
    XCTAssertEqual(image.mimeType, "image/png")
    XCTAssertGreaterThan(image.data.count, 0)

    let responseText = try XCTUnwrap(response.text)
    XCTAssertTrue(responseText.starts(with: "I can certainly help you with that"))
  }
  /// Builds a model whose App Check fake returns a token and passes that same token to the
  /// request handler as the expected value. Passes if `generateContent` does not throw.
  func testGenerateContent_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: appCheckToken)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: appCheckToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Same as the valid-token case but with `useLimitedUseAppCheckTokens: true`; the request
  /// handler is given the token with a `limited_use_` prefix as the expected value.
  func testGenerateContent_appCheck_validToken_limitedUse() async throws {
    let appCheckToken = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: appCheckToken),
      useLimitedUseAppCheckTokens: true
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: "limited_use_\(appCheckToken)"
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Builds a model with `privateAppID: true` and a request handler configured with
  /// `dataCollection: false`. Passes if `generateContent` does not throw.
  func testGenerateContent_dataCollectionOff() async throws {
    let appCheckToken = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: appCheckToken),
      privateAppID: true
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: appCheckToken,
      dataCollection: false
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Builds a model whose App Check fake is created with an error, and a request handler that
  /// expects `AppCheckInteropFake.placeholderTokenValue`. Passes if `generateContent` does not
  /// throw.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(error: AppCheckErrorFake())
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Builds a model whose auth fake returns a token and passes that same token to the request
  /// handler as the expected value. Passes if `generateContent` does not throw.
  func testGenerateContent_auth_validAuthToken() async throws {
    let authToken = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(token: authToken)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: authToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Builds a model whose auth fake is created with a `nil` token, with the request handler
  /// also given `authToken: nil`. Passes if `generateContent` does not throw.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(auth: AuthInteropFake(token: nil))
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Builds a model whose auth fake is created with an `AuthErrorFake` and expects
  /// `generateContent` to throw `GenerateContentError.internalError` wrapping that error type.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(error: AuthErrorFake())
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected error; nothing further to assert.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Verifies the usage metadata decoded from the `unary-success-basic-reply-short` fixture:
  /// token counts of 6 (prompt), 7 (candidates) and 13 (total), with empty per-modality details.
  func testGenerateContent_usageMetadata() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)
    let usage = try XCTUnwrap(reply.usageMetadata)

    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
    XCTAssertTrue(usage.promptTokensDetails.isEmpty)
    XCTAssertTrue(usage.candidatesTokensDetails.isEmpty)
  }
  /// Verifies that grounding metadata is decoded from the
  /// `unary-success-google-search-grounding` fixture: the search queries, the search entry
  /// point, the grounding chunks, and the first grounding support with its text segment.
  func testGenerateContent_groundingMetadata() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-google-search-grounding",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let grounding = try XCTUnwrap(onlyCandidate.groundingMetadata)

    // Search queries and entry point.
    XCTAssertEqual(grounding.webSearchQueries, ["current weather in London"])
    XCTAssertNotNil(grounding.searchEntryPoint)
    XCTAssertNotNil(grounding.searchEntryPoint?.renderedContent)

    // Grounding chunks.
    XCTAssertEqual(grounding.groundingChunks.count, 2)
    let leadingWebChunk = try XCTUnwrap(grounding.groundingChunks.first?.web)
    XCTAssertEqual(leadingWebChunk.title, "accuweather.com")
    XCTAssertNotNil(leadingWebChunk.uri)
    XCTAssertNil(leadingWebChunk.domain) // Domain is not supported by Google AI backend

    // Grounding supports.
    XCTAssertEqual(grounding.groundingSupports.count, 3)
    let leadingSupport = try XCTUnwrap(grounding.groundingSupports.first)
    let span = try XCTUnwrap(leadingSupport.segment)
    XCTAssertEqual(span.text, "The current weather in London, United Kingdom is cloudy.")
    XCTAssertEqual(span.startIndex, 0)
    XCTAssertEqual(span.partIndex, 0)
    XCTAssertEqual(span.endIndex, 56)
    XCTAssertEqual(leadingSupport.groundingChunkIndices, [0])
  }
  /// Verifies that `generateContent` completes without throwing on a model configured with the
  /// Google Search tool.
  ///
  /// A dedicated local model is used; the shared `model` property is left untouched.
  func testGenerateContent_withGoogleSearchTool() async throws {
    let searchEnabledModel = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: [.googleSearch()],
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    // Only successful completion matters here; the response itself is not inspected.
    _ = try await searchEnabledModel.generateContent(testPrompt)
  }
  /// Verifies that an invalid-API-key backend failure is reported as
  /// `GenerateContentError.internalError` wrapping a `BackendError`.
  ///
  /// Checks the error's HTTP code, status and message, that its `localizedDescription` contains
  /// all three, and that bridging to `NSError` yields the expected domain and code.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      // Compare against the status code that was configured on the mock response rather than
      // repeating the literal.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      // The message names `BackendError`, the type the pattern above actually matches.
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Verifies that a 403 "API has not been used in project" backend failure is reported as
  /// `GenerateContentError.internalError` wrapping a `BackendError` with status
  /// `.permissionDenied`, and that `isVertexAIInFirebaseServiceDisabledError()` returns true.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(
        error.message.starts(with: "Vertex AI in Firebase API has not been used in project")
      )
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      // The message names `BackendError`, the type the pattern above actually matches.
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Verifies that the `unary-failure-empty-content` fixture makes `generateContent` throw
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.emptyContent`, whose
  /// own payload is a `Candidate.EmptyContentError`.
  func testGenerateContent_failure_emptyContent() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: candidateError as InvalidCandidateError) {
      guard case .emptyContent(let cause) = candidateError else {
        XCTFail("Should be an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      // The unwrapped value is not needed; the unwrap itself is the assertion.
      _ = try XCTUnwrap(
        cause as? Candidate.EmptyContentError,
        "Should be an empty content error: \(cause)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-finish-reason-safety` fixture makes `generateContent`
  /// throw `GenerateContentError.responseStoppedEarly` with reason `.safety`, and that the
  /// attached response still carries the text `"<redacted>"`.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-finish-reason-safety-no-content` fixture makes
  /// `generateContent` throw `GenerateContentError.responseStoppedEarly` with reason `.safety`,
  /// and that the attached response has no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety-no-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-image-rejected` fixture, served with HTTP 400, makes
  /// `generateContent` throw `GenerateContentError.internalError` wrapping a `BackendError` with
  /// status `.invalidArgument` and the backend's message.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-image-rejected",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      // Single source of truth: the same constant is asserted against below.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-prompt-blocked-safety` fixture makes `generateContent`
  /// throw `GenerateContentError.promptBlocked`, with no response text, a `.safety` block reason
  /// and no block reason message.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertNil(promptFeedback.blockReasonMessage)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-prompt-blocked-safety-with-message` fixture makes
  /// `generateContent` throw `GenerateContentError.promptBlocked`, with no response text, a
  /// `.safety` block reason and the block reason message `"Reasons"`.
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety-with-message",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that a finish reason this SDK does not know is preserved by its raw value.
  ///
  /// With the `unary-failure-unknown-enum-finish-reason` fixture, `generateContent` must throw
  /// `GenerateContentError.responseStoppedEarly` whose reason equals
  /// `FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")` and whose response text is `"Some text"`.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-unknown-enum-finish-reason",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that a block reason this SDK does not know is preserved by its raw value.
  ///
  /// With the `unary-failure-unknown-enum-prompt-blocked` fixture, `generateContent` must throw
  /// `GenerateContentError.promptBlocked` whose prompt feedback carries
  /// `PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-unknown-enum-prompt-blocked",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-unknown-model` fixture, served with HTTP 404, makes
  /// `generateContent` throw `GenerateContentError.internalError` wrapping a `BackendError` with
  /// status `.notFound` and a message starting with `"models/unknown is not found"`.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-unknown-model",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      // Single source of truth: the same constant is asserted against below.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies the error produced when the mock returns the response built by
  /// `GenerativeModelTestUtil.nonHTTPRequestHandler()`.
  ///
  /// `generateContent` must yield no content and throw `GenerateContentError.internalError`
  /// whose payload bridges to an `NSError` in `NSURLErrorDomain` with code
  /// `URLError.Code.badServerResponse`.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()

    var thrownError: Error?
    var result: GenerateContentResponse?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let wrapped) = contentError else {
      XCTFail("Should be an internal error: \(contentError)")
      return
    }

    XCTAssertEqual(wrapped.localizedDescription, "Response was not an HTTP response.")
    let bridged = wrapped as NSError
    XCTAssertEqual(bridged.domain, NSURLErrorDomain)
    XCTAssertEqual(bridged.code, URLError.Code.badServerResponse.rawValue)
  }
  /// Verifies that the `unary-failure-invalid-response` fixture makes `generateContent` throw
  /// `GenerateContentError.internalError` wrapping `DecodingError.dataCorrupted`, whose debug
  /// description starts with `"Failed to decode GenerateContentResponse"`.
  func testGenerateContent_failure_invalidResponse() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    var thrownError: Error?
    var result: GenerateContentResponse?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let wrapped) = contentError else {
      XCTFail("Should be an internal error: \(contentError)")
      return
    }

    let decodingFailure = try XCTUnwrap(wrapped as? DecodingError)
    guard case .dataCorrupted(let failureContext) = decodingFailure else {
      XCTFail("Should be a data corrupted error: \(decodingFailure)")
      return
    }
    XCTAssert(failureContext.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
  }
  /// Verifies that the `unary-failure-malformed-content` fixture makes `generateContent` throw
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.emptyContent`, whose
  /// own payload is a `Candidate.EmptyContentError`.
  func testGenerateContent_failure_malformedContent() async throws {
    // Note: Although this file does not contain `parts` in `content`, it is not actually
    // malformed. The `invalid-field` in the payload could be added, as a non-breaking change to
    // the proto API. Therefore, this test checks for the `emptyContent` error instead.
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    var thrownError: Error?
    var result: GenerateContentResponse?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let wrapped) = contentError else {
      XCTFail("Should be an internal error: \(contentError)")
      return
    }

    let candidateError = try XCTUnwrap(wrapped as? InvalidCandidateError)
    guard case .emptyContent(let cause) = candidateError else {
      XCTFail("Should be an empty content error: \(candidateError)")
      return
    }
    // The unwrapped value is not needed; the unwrap itself is the assertion.
    _ = try XCTUnwrap(
      cause as? Candidate.EmptyContentError,
      "Should be an empty content error: \(cause)"
    )
  }
  /// Verifies that the `unary-success-missing-safety-ratings` fixture decodes successfully: the
  /// prompt feedback is present with zero safety ratings and the response text is intact.
  func testGenerateContentMissingSafetyRatings() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)
    let feedback = try XCTUnwrap(reply.promptFeedback)

    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(reply.text, "This is the generated content.")
  }
  /// Verifies that `generateContent` succeeds on a model built with a custom
  /// `RequestOptions(timeout:)`.
  ///
  /// The same timeout value is handed to the mock request handler and to the model.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0

    // NOTE(review): presumably the handler checks the request's timeout against `timeout`;
    // confirm in `GenerativeModelTestUtil.httpRequestHandler`.
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: expectedTimeout
    )
    MockURLProtocol.requestHandler = handler

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)
    XCTAssertEqual(reply.candidates.count, 1)
  }
  1000. // MARK: - Generate Content (Streaming)
  /// Verifies that streaming the `unary-failure-api-key` fixture yields no content and throws
  /// `GenerateContentError.internalError` wrapping a `BackendError`.
  ///
  /// Checks the error's HTTP code, status and message, that its `localizedDescription` contains
  /// all three, and that bridging to `NSError` yields the expected domain and code. Any other
  /// thrown error propagates out of the test.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    // NOTE(review): no `statusCode` is passed to the handler here, yet a 400 is asserted below;
    // presumably the code comes from the fixture body. Confirm against the fixture.
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream ended without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 400)
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")

      let description = backendError.localizedDescription
      XCTAssertTrue(description.contains(backendError.message))
      XCTAssertTrue(description.contains(backendError.status.rawValue))
      XCTAssertTrue(description.contains("\(backendError.httpResponseCode)"))

      let bridged = backendError as NSError
      XCTAssertEqual(bridged.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridged.code, backendError.httpResponseCode)
    }
  }
  /// Verifies that streaming the `unary-failure-firebasevertexai-api-not-enabled` fixture,
  /// served with HTTP 403, yields no content and throws `GenerateContentError.internalError`
  /// wrapping a `BackendError` with status `.permissionDenied`.
  ///
  /// Any other thrown error propagates out of the test.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let expectedStatusCode = 403
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream(testPrompt)
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream ended without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.status, .permissionDenied)
      XCTAssertTrue(
        backendError.message.starts(with: "Vertex AI in Firebase API has not been used in project")
      )
      XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
    }
  }
  /// Verifies that streaming the `streaming-failure-empty-content` fixture yields no content and
  /// throws `GenerateContentError.internalError` wrapping an `InvalidCandidateError`.
  ///
  /// Any other thrown error propagates out of the test.
  func testGenerateContentStream_failureEmptyContent() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream ended without throwing.
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // The wrapped error has the expected type; there is nothing further to verify.
    }
  }
  /// Verifies that streaming the `streaming-failure-finish-reason-safety` fixture yields no
  /// content and throws `GenerateContentError.responseStoppedEarly` with reason `.safety`.
  ///
  /// The first candidate of the attached response must carry the same finish reason and at
  /// least one blocked safety rating. Any other thrown error propagates out of the test.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream ended without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(stopReason, partialResponse) {
      XCTAssertEqual(stopReason, .safety)
      let firstCandidate = try XCTUnwrap(partialResponse.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, stopReason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains { $0.blocked })
    }
  }
  /// Verifies that streaming the `streaming-failure-prompt-blocked-safety` fixture yields no
  /// content and throws `GenerateContentError.promptBlocked` with a `.safety` block reason and
  /// no block reason message.
  ///
  /// Any other thrown error propagates out of the test.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream ended without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
    }
  }
  /// Verifies that streaming the `streaming-failure-prompt-blocked-safety-with-message` fixture
  /// yields no content and throws `GenerateContentError.promptBlocked` with a `.safety` block
  /// reason and the block reason message `"Reasons"`.
  ///
  /// Any other thrown error propagates out of the test.
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety-with-message",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream ended without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
    }
  }
  /// Verifies that a finish reason this SDK does not know is preserved by its raw value while
  /// streaming.
  ///
  /// With the `streaming-failure-unknown-finish-enum` fixture, every chunk delivered must have
  /// text, and the stream must end by throwing `GenerateContentError.responseStoppedEarly` whose
  /// reason equals `FinishReason(rawValue: "FAKE_ENUM")`. Any other thrown error propagates out
  /// of the test.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler
    let unknownFinishReason = FinishReason(rawValue: "FAKE_ENUM")

    // Created outside the `do` so that an error from stream creation is never matched below.
    let stream = try model.generateContentStream("Hi")
    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
      }
      // Reaching this point means the stream ended without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(stopReason, _) {
      XCTAssertEqual(stopReason, unknownFinishReason)
    }
  }
  /// Verifies that streaming the `streaming-success-basic-reply-long` fixture delivers exactly
  /// four responses, each of which has text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    var chunkCount = 0
    for try await chunk in try model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 4)
  }
  /// Verifies that streaming the `streaming-success-basic-reply-short` fixture delivers exactly
  /// one response, which has text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    var chunkCount = 0
    for try await chunk in try model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  /// Verifies that a safety rating whose category and probability this SDK does not know is
  /// still decoded while streaming.
  ///
  /// With the `streaming-success-unknown-safety-enum` fixture, every chunk must have text, and
  /// the first candidate of at least one chunk must contain a rating equal to the expected one.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    var sawExpectedRating = false
    let stream = try model.generateContentStream("Hi")
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      // A chunk without candidates contributes no ratings.
      guard let ratings = chunk.candidates.first?.safetyRatings else { continue }
      if ratings.contains(where: { rating in rating == expectedRating }) {
        sawExpectedRating = true
      }
    }

    XCTAssertTrue(sawExpectedRating)
  }
  /// Verifies that citations are decoded while streaming the `streaming-success-citations`
  /// fixture.
  ///
  /// Every chunk must have text and a first candidate. The last chunk's candidate must finish
  /// with `.stop`, six citations must be collected in total, three specific citations must be
  /// among them, and no citation may carry an empty `uri`, `title` or `license`.
  func testGenerateContentStream_successWithCitations() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    let stream = try model.generateContentStream("Hi")
    var collectedCitations = [Citation]()
    var chunks = [GenerateContentResponse]()
    for try await chunk in stream {
      chunks.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      if let sources = firstCandidate.citationMetadata?.citations {
        collectedCitations.append(contentsOf: sources)
      }
    }

    let finalCandidate = try XCTUnwrap(chunks.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 6)

    // A citation identified by URI only.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil && citation.publicationDate == nil
    })
    // A citation identified by title, carrying a publication date.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
        && citation.publicationDate == expectedPublicationDate
    })
    // A citation identified by URI, carrying a license.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit" && citation.publicationDate == nil
    })

    // Optional string fields must be either absent or non-empty.
    XCTAssertFalse(collectedCitations.contains { $0.uri?.isEmpty ?? false })
    XCTAssertFalse(collectedCitations.contains { $0.title?.isEmpty ?? false })
    XCTAssertFalse(collectedCitations.contains { $0.license?.isEmpty ?? false })
  }
  /// Verifies that thought-summary parts and ordinary text parts are told apart while streaming
  /// the `streaming-success-thinking-reply-thought-summary` fixture.
  ///
  /// Every chunk must hold one candidate with exactly one `TextPart`. Thought parts are
  /// accumulated via `response.thoughtSummary`, the others via the part's text; both
  /// accumulated strings are then checked against their expected prefixes.
  func testGenerateContentStream_successWithThinking_thoughtSummary() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-thinking-reply-thought-summary",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    var accumulatedThoughts = ""
    var accumulatedText = ""
    let stream = try model.generateContentStream("Hi")
    for try await chunk in stream {
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      XCTAssertEqual(firstCandidate.content.parts.count, 1)
      let onlyPart = try XCTUnwrap(firstCandidate.content.parts.first)
      let textPart = try XCTUnwrap(onlyPart as? TextPart)

      guard textPart.isThought else {
        accumulatedText.append(textPart.text)
        continue
      }
      let thought = try XCTUnwrap(chunk.thoughtSummary)
      accumulatedThoughts.append(thought)
    }

    XCTAssertTrue(accumulatedThoughts.hasPrefix("**Understanding the Core Question**"))
    XCTAssertTrue(accumulatedText.hasPrefix("The sky is blue due to a phenomenon"))
  }
  /// Verifies the first response streamed from the
  /// `streaming-success-image-invalid-safety-ratings` fixture.
  ///
  /// Its single candidate must have safety ratings that, once sorted, equal
  /// `safetyRatingsInvalidIgnored`, and exactly one part: a non-empty `image/png` inline-data
  /// part.
  func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-image-invalid-safety-ratings",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let stream = try model.generateContentStream(testPrompt)
    var chunks = [GenerateContentResponse]()
    for try await chunk in stream {
      chunks.append(chunk)
    }

    let firstChunk = try XCTUnwrap(chunks.first)
    XCTAssertEqual(firstChunk.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(firstChunk.candidates.first)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)

    let inlineParts = firstChunk.inlineDataParts
    XCTAssertEqual(inlineParts.count, 1)
    let image = try XCTUnwrap(inlineParts.first)
    XCTAssertEqual(image.mimeType, "image/png")
    XCTAssertGreaterThan(image.data.count, 0)
  }
  /// Verifies that streaming completes without throwing when the App Check fake supplies a
  /// token.
  ///
  /// The shared `model` is replaced by one whose Firebase info wraps
  /// `AppCheckInteropFake(token:)`, and the same token is handed to the mock request handler.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheck = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: appCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    // NOTE(review): presumably the handler compares the request's App Check header against
    // `appCheckToken`; confirm in `GenerativeModelTestUtil.httpRequestHandler`.
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: appCheckToken
    )
    MockURLProtocol.requestHandler = handler

    // Drain the stream; only successful completion matters here.
    for try await _ in try model.generateContentStream(testPrompt) {}
  }
  /// Streams content with a model whose App Check fake is constructed with an error, using a
  /// request handler configured with `AppCheckInteropFake.placeholderTokenValue`.
  ///
  /// The test body makes no assertions of its own; it only requires that the stream completes
  /// without throwing.
  // NOTE(review): presumably the handler checks that the placeholder token is sent when the
  // token refresh fails — confirm in `GenerativeModelTestUtil`.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    // Consume every element; the streamed values themselves are not inspected.
    for try await _ in try model.generateContentStream(testPrompt) {}
  }
  /// Streams the `streaming-success-basic-reply-short` resource and verifies that usage metadata
  /// appears only on the final streamed response, with the expected token counts.
  ///
  /// The test fails (rather than passing vacuously) when the stream yields no responses at all.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Require at least one response so an empty stream cannot silently skip every assertion.
    let lastResponse = try XCTUnwrap(
      responses.last,
      "Expected at least one streamed response."
    )

    // Only the last streamed response contains usage metadata.
    for earlierResponse in responses.dropLast() {
      XCTAssertNil(earlierResponse.usageMetadata)
    }

    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Streams the `streaming-failure-error-mid-stream` resource and expects the stream to throw
  /// `GenerateContentError.internalError` wrapping a `BackendError` (HTTP 499, `.cancelled`)
  /// after exactly two responses with non-nil text have been delivered.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    // Counts the responses received before the stream fails.
    var receivedCount = 0
    do {
      for try await chunk in try model.generateContentStream("Hi") {
        XCTAssertNotNil(chunk.text)
        receivedCount += 1
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internalError with an RPCError.")
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)
      // Check the content count is correct.
      XCTAssertEqual(receivedCount, 2)
    }
  }
  /// Uses the non-HTTP request handler and expects the stream to yield no content and to throw
  /// `GenerateContentError.internalError` whose underlying error is described as
  /// "Response was not an HTTP response."
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()

    let stream = try model.generateContentStream("Hi")
    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(underlyingError) {
      XCTAssertEqual(
        underlyingError.localizedDescription,
        "Response was not an HTTP response."
      )
    }
  }
  /// Streams the `streaming-failure-invalid-json` resource and expects the stream to yield no
  /// content and to throw `GenerateContentError.internalError` wrapping a
  /// `DecodingError.dataCorrupted` whose debug description starts with
  /// "Failed to decode GenerateContentResponse".
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(context) = decodingError {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Should be a data corrupted error: \(decodingError)")
      }
    }
  }
  /// Streams the `streaming-failure-malformed-content` resource and expects the stream to yield
  /// no content and to throw `GenerateContentError.internalError` wrapping
  /// `InvalidCandidateError.emptyContent`, whose payload is a `Candidate.EmptyContentError`.
  func testGenerateContentStream_malformedContent() async throws {
    // Note: Although this file does not contain `parts` in `content`, it is not actually
    // malformed. The `invalid-field` in the payload could be added, as a non-breaking change to
    // the proto API. Therefore, this test checks for the `emptyContent` error instead.
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) {
      guard case let .emptyContent(contentError) = underlyingError else {
        XCTFail("Should be an empty content error: \(underlyingError)")
        return
      }
      XCTAssert(contentError is Candidate.EmptyContentError)
      return
    }

    // The stream completed without throwing. The message names the error this test actually
    // expects (empty content), not a decoding error.
    XCTFail("Expected an internalError wrapping an emptyContent error.")
  }
  /// Streams content with a model built from `RequestOptions(timeout: 150.0)` and a request
  /// handler configured with the same timeout; expects exactly one response with non-nil text.
  // NOTE(review): presumably `httpRequestHandler(timeout:)` asserts the timeout on the outgoing
  // request — confirm in `GenerativeModelTestUtil`.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0

    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      timeout: customTimeout
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      urlSession: urlSession
    )

    var chunkCount = 0
    for try await chunk in try model.generateContentStream(testPrompt) {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  1478. // MARK: - Count Tokens
  /// Serves the `unary-success-total-tokens` resource and expects `countTokens` to report a
  /// total of 6 tokens.
  func testCountTokens_succeeds() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let tokenResponse = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenResponse.totalTokens, 6)
  }
  /// Serves the `unary-success-detailed-token-response` resource and expects `countTokens` to
  /// report 1837 total tokens, split into two prompt-token details: image (1806) then text (31).
  func testCountTokens_succeeds_detailed() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-detailed-token-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(response.totalTokens, 1837)

    let details = response.promptTokensDetails
    XCTAssertEqual(details.count, 2)

    // Unwrap instead of subscripting: a failed count assertion does not stop the test unless
    // `continueAfterFailure` is disabled, and an out-of-range subscript would trap the whole
    // test process rather than report a failure.
    let imageDetails = try XCTUnwrap(details.first)
    XCTAssertEqual(imageDetails.modality, .image)
    XCTAssertEqual(imageDetails.tokenCount, 1806)

    let textDetails = try XCTUnwrap(details.dropFirst().first)
    XCTAssertEqual(textDetails.modality, .text)
    XCTAssertEqual(textDetails.tokenCount, 31)
  }
  /// Serves the `unary-success-total-tokens` resource and expects `countTokens` to report 6
  /// tokens when the model is configured with a generation config, a function-declaration tool,
  /// and a system instruction.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    // Build a model with every optional configuration used by this test.
    let config = GenerationConfig(
      temperature: 0.5,
      topP: 0.9,
      topK: 3,
      candidateCount: 1,
      maxOutputTokens: 1024,
      stopSequences: ["test-stop"],
      responseMIMEType: "text/plain"
    )
    let sumDeclaration = FunctionDeclaration(
      name: "sum",
      description: "Add two integers.",
      parameters: ["x": .integer(), "y": .integer()]
    )
    let calculatorTools = [Tool(functionDeclarations: [sumDeclaration])]
    let calculatorInstruction = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      generationConfig: config,
      tools: calculatorTools,
      systemInstruction: calculatorInstruction,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let tokenResponse = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenResponse.totalTokens, 6)
  }
  /// Serves the `unary-success-no-billable-characters` resource and expects `countTokens` on an
  /// empty `image/jpeg` inline-data part to report a total of 258 tokens.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    // The part carries no bytes; only the MIME type is populated.
    let emptyImagePart = InlineDataPart(data: Data(), mimeType: "image/jpeg")
    let tokenResponse = try await model.countTokens(emptyImagePart)

    XCTAssertEqual(tokenResponse.totalTokens, 258)
  }
  /// Serves the `unary-failure-model-not-found` resource with HTTP status 404 and expects
  /// `countTokens` to throw a `BackendError` with code 404, status `.notFound`, and a message
  /// starting with "models/test-model-name is not found".
  ///
  /// Any other thrown error propagates out of the test and is reported by XCTest.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      // Single failure for the unexpected-success path; there is deliberately no second
      // `XCTFail` after the `do`/`catch`, which would report the same problem twice.
      XCTFail("Request should not have succeeded.")
    } catch let backendError as BackendError {
      XCTAssertEqual(backendError.httpResponseCode, 404)
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssert(backendError.message.hasPrefix("models/test-model-name is not found"))
    }
  }
  /// Counts tokens with a model built from `RequestOptions(timeout: 150.0)` and a request handler
  /// configured with the same timeout; expects a total of 6 tokens.
  // NOTE(review): presumably `httpRequestHandler(timeout:)` asserts the timeout on the outgoing
  // request — confirm in `GenerativeModelTestUtil`.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0

    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: customTimeout
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      urlSession: urlSession
    )

    let tokenResponse = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenResponse.totalTokens, 6)
  }
  1588. }
  /// Gives `SafetyRating` an ordering so that collections of ratings can be sorted before being
  /// compared in assertions.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    /// Orders two ratings by the raw value of their `category`; no other property is consulted.
    public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
      lhs.category.rawValue < rhs.category.rawValue
    }
  }