GenerativeModelVertexAITests.swift 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelVertexAITests: XCTestCase {
  /// Prompt text passed to `generateContent` by the tests in this suite.
  let testPrompt: String = "What sorts of questions can I ask you?"
  /// Expected safety ratings for the `unary-success-basic-reply-short` mock response: every
  /// category is rated `.negligible` and none is blocked. Stored pre-sorted so tests can compare
  /// it directly against `candidate.safetyRatings.sorted()`.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()
  /// Expected safety ratings for the `unary-success-image-invalid-safety-ratings` mock response,
  /// sorted for order-insensitive comparison. Only the four well-formed ratings are listed; the
  /// response's empty `{}` rating entries are expected to be ignored during decoding.
  let safetyRatingsInvalidIgnored: [SafetyRating] = [
    .init(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.00039444832,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    .init(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.0010654529,
      severity: .negligible,
      severityScore: 0.0049325973,
      blocked: false
    ),
    .init(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.00026658305,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    .init(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.0013701695,
      severity: .negligible,
      severityScore: 0.07626295,
      blocked: false
    ),
    // Ignored Invalid Safety Ratings: {},{},{},{}
  ].sorted()
  /// Short model name passed to every `GenerativeModel` built in this suite.
  let testModelName: String = "test-model"
  /// Fully qualified Vertex AI resource name corresponding to `testModelName`.
  let testModelResourceName: String =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"
  /// API configuration selecting the Vertex AI backend for every model in this suite.
  let apiConfig = FirebaseAI.defaultVertexAIAPIConfig
  /// Subdirectory handed to `httpRequestHandler` to locate the Vertex AI mock JSON responses.
  let vertexSubdirectory: String = "mock-responses/vertexai"

  /// Session whose requests are intercepted by `MockURLProtocol`; rebuilt in `setUp()`.
  var urlSession: URLSession!
  /// Default model under test; rebuilt in `setUp()` and reassigned by some tests.
  var model: GenerativeModel!
  /// Builds a fresh mocked `URLSession` and a default Vertex AI `GenerativeModel` before each
  /// test.
  override func setUp() async throws {
    // Registering `MockURLProtocol` routes this session's traffic through the mock; each test
    // installs the canned response via `MockURLProtocol.requestHandler`.
    let sessionConfiguration = URLSessionConfiguration.default
    sessionConfiguration.protocolClasses = [MockURLProtocol.self]
    let mockedSession = URLSession(configuration: sessionConfiguration)
    urlSession = mockedSession

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: mockedSession
    )
  }
  /// Resets shared mock state and releases per-test objects after each test.
  override func tearDown() {
    // Clear the canned response so it cannot leak into a later test.
    MockURLProtocol.requestHandler = nil
    // XCTest may keep test case instances alive after their test finishes, so drop the
    // references explicitly instead of holding a session and model for the rest of the run.
    model = nil
    urlSession = nil
    super.tearDown()
  }
  115. // MARK: - Generate Content
  /// A long single-candidate text reply decodes with a `.stop` finish reason, four safety
  /// ratings and one text part, which is also surfaced through `response.text`.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let replyText = try XCTUnwrap(firstPart as? TextPart).text
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(response.text, replyText)
    // A plain text reply carries neither function calls nor inline data.
    XCTAssertEqual(response.functionCalls, [])
    XCTAssertEqual(response.inlineDataParts, [])
  }
  /// A short single-candidate reply decodes its exact text and the expected negligible safety
  /// ratings.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    // Both sides are sorted, so the comparison ignores the order the ratings arrive in.
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let onlyPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let replyPart = try XCTUnwrap(onlyPart as? TextPart)
    XCTAssertEqual(replyPart.text, "Mountain View, California")
    XCTAssertEqual(response.text, replyPart.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Per-modality token breakdowns in `usageMetadata` decode with the expected modalities and
  /// token counts, in response order.
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-response-long-usage-metadata",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    let finishReason = try XCTUnwrap(candidate.finishReason)
    XCTAssertEqual(finishReason, .stop)
    let usageMetadata = try XCTUnwrap(response.usageMetadata)

    // Compare projected arrays instead of subscripting by index. A failed count assertion does
    // not stop the test, so `details[1]` on a shorter array would trap and abort the entire
    // test run. An array comparison just records a failure for this test.
    let promptDetails = usageMetadata.promptTokensDetails
    XCTAssertEqual(promptDetails.count, 2)
    XCTAssertEqual(promptDetails.map(\.modality), [.image, .text])
    XCTAssertEqual(promptDetails.map(\.tokenCount), [1806, 76])

    let candidatesDetails = usageMetadata.candidatesTokensDetails
    XCTAssertEqual(candidatesDetails.count, 1)
    XCTAssertEqual(candidatesDetails.map(\.modality), [.text])
    XCTAssertEqual(candidatesDetails.map(\.tokenCount), [76])
  }
  /// Citation metadata decodes all three sources, each with its own combination of populated
  /// and absent optional fields.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citations = citationMetadata.citations
    XCTAssertEqual(citations.count, 3)
    // Stop before subscripting. Indexing past the end traps and crashes the whole test
    // process; returning here leaves the count assertion above as the reported failure.
    guard citations.count == 3 else { return }

    let citationSource1 = citations[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)
    XCTAssertNil(citationSource1.publicationDate)

    let citationSource2 = citations[1]
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.publicationDate, expectedPublicationDate)
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citations[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
    XCTAssertNil(citationSource3.publicationDate)
  }
  /// A quoted reply decodes its text, plus prompt feedback that has safety ratings but no block
  /// reason.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let quotePart = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(quotePart.text.hasPrefix("Google"))
    XCTAssertEqual(response.text, quotePart.text)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Unrecognized raw values for harm category and probability decode to cases that carry the
  /// raw value from the payload, in both candidate and prompt-feedback safety ratings.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let expectedRatings = [
      SafetyRating(
        category: .harassment,
        probability: .medium,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      SafetyRating(
        category: .dangerousContent,
        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      SafetyRating(
        category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
        probability: .high,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
    ]

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedRatings)
  }
  /// A model built with the fully qualified `testModelResourceName` can complete a request
  /// against the mocked backend.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let prefixedModel = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// The `unary-success-function-call-empty-arguments` mock decodes to a single
  /// `FunctionCallPart` named `current_time` whose `args` is empty.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Unwrap the downcast the same way the other tests in this file unwrap `TextPart`s: a wrong
    // part type records a failure and stops the test.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// The `unary-success-function-call-no-arguments` mock decodes to a single
  /// `FunctionCallPart` named `current_time` whose `args` is empty.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Unwrap the downcast the same way the other tests in this file unwrap `TextPart`s: a wrong
    // part type records a failure and stops the test.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// A `sum` function call decodes with its two numeric arguments, `x` = 4 and `y` = 5.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Unwrap the downcast the same way the other tests in this file unwrap `TextPart`s: a wrong
    // part type records a failure and stops the test.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "sum")
    XCTAssertEqual(functionCall.args.count, 2)
    let argX = try XCTUnwrap(functionCall.args["x"])
    XCTAssertEqual(argX, .number(4))
    let argY = try XCTUnwrap(functionCall.args["y"])
    XCTAssertEqual(argY, .number(5))
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// Three parallel function calls decode as three parts, all exposed through `functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// A four-part reply mixing text and function calls exposes the two calls through
  /// `functionCalls` and the text through `text`.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let replyText = try XCTUnwrap(response.text)
    XCTAssertEqual(replyText, "The sum of [1, 2, 3] is")
  }
  /// A thinking reply decodes a thought-summary part followed by the answer part. The summary is
  /// surfaced through `thoughtSummary`, and `text` contains only the answer.
  func testGenerateContent_success_thinking_thoughtSummary() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-thinking-reply-thought-summary",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.finishReason, .stop)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 2)

    let summaryPart = try XCTUnwrap(parts.first as? TextPart)
    XCTAssertTrue(summaryPart.isThought)
    XCTAssertTrue(summaryPart.text.hasPrefix("Right, someone needs the city where Google"))
    XCTAssertEqual(response.thoughtSummary, summaryPart.text)

    let answerPart = try XCTUnwrap(parts.last as? TextPart)
    XCTAssertFalse(answerPart.isThought)
    XCTAssertEqual(answerPart.text, "Mountain View")
    XCTAssertEqual(response.text, answerPart.text)
  }
  /// An image reply decodes its PNG inline data. Malformed safety-rating entries are dropped,
  /// leaving only `safetyRatingsInvalidIgnored`.
  func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-image-invalid-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)

    let imageParts = response.inlineDataParts
    XCTAssertEqual(imageParts.count, 1)
    let pngPart = try XCTUnwrap(imageParts.first)
    XCTAssertEqual(pngPart.mimeType, "image/png")
    XCTAssertFalse(pngPart.data.isEmpty)
  }
  /// A request made with an App Check provider that returns a valid token succeeds. The mock
  /// handler is configured with that same token.
  func testGenerateContent_appCheck_validToken() async throws {
    let expectedToken = "test-valid-token"
    let appCheck = AppCheckInteropFake(token: expectedToken)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: appCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    // NOTE(review): presumably the handler checks the outgoing App Check header against
    // `appCheckToken`; confirm in `GenerativeModelTestUtil`.
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With `useLimitedUseAppCheckTokens` enabled, the mock handler expects the fake's
  /// limited-use token, which is `limited_use_` followed by the base token.
  func testGenerateContent_appCheck_validToken_limitedUse() async throws {
    let baseToken = "test-valid-token"
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(
        appCheck: AppCheckInteropFake(token: baseToken),
        useLimitedUseAppCheckTokens: true
      ),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: "limited_use_\(baseToken)"
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With `privateAppID: true`, the mock handler is configured to expect
  /// `dataCollection: false` together with the App Check token.
  func testGenerateContent_dataCollectionOff() async throws {
    let appCheckToken = "test-valid-token"
    let privateAppInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: appCheckToken),
      privateAppID: true
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: privateAppInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: appCheckToken,
      dataCollection: false
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the App Check provider fails, the request still succeeds. The mock handler expects
  /// `AppCheckInteropFake.placeholderTokenValue` in place of a real token.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// A request made with an Auth provider that returns a valid token succeeds. The mock handler
  /// is configured with that same token.
  func testGenerateContent_auth_validAuthToken() async throws {
    let expectedToken = "test-valid-token"
    let auth = AuthInteropFake(token: expectedToken)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(auth: auth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the Auth provider yields no token (a signed-out user, for example), the request still
  /// succeeds and the mock handler expects no auth token.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let signedOutAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(auth: signedOutAuth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An error from the Auth token provider makes `generateContent` throw
  /// `GenerateContentError.internalError` wrapping that provider error.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(auth: failingAuth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      return // Expected: the Auth failure is surfaced as an internal error.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
      return
    }
    XCTFail("Should throw internalError(AuthErrorFake); no error.")
  }
  /// Basic usage metadata decodes the prompt, candidate and total token counts. This response
  /// has no per-modality breakdowns.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
    XCTAssertTrue(usage.promptTokensDetails.isEmpty)
    XCTAssertTrue(usage.candidatesTokensDetails.isEmpty)
  }
  /// Google Search grounding metadata decodes the search queries, the search entry point, the
  /// web chunks and the grounding supports with their segment offsets.
  func testGenerateContent_groundingMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-google-search-grounding",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    let groundingMetadata = try XCTUnwrap(candidate.groundingMetadata)
    XCTAssertEqual(groundingMetadata.webSearchQueries, ["current weather in London"])
    XCTAssertNotNil(groundingMetadata.searchEntryPoint)
    XCTAssertNotNil(groundingMetadata.searchEntryPoint?.renderedContent)

    XCTAssertEqual(groundingMetadata.groundingChunks.count, 2)
    let firstChunk = try XCTUnwrap(groundingMetadata.groundingChunks.first?.web)
    XCTAssertEqual(firstChunk.title, "accuweather.com")
    XCTAssertNotNil(firstChunk.uri)
    // This suite targets the Vertex AI backend (see `apiConfig`). The mock response used here
    // yields no `domain` for the first chunk.
    XCTAssertNil(firstChunk.domain)

    XCTAssertEqual(groundingMetadata.groundingSupports.count, 3)
    let firstSupport = try XCTUnwrap(groundingMetadata.groundingSupports.first)
    let segment = try XCTUnwrap(firstSupport.segment)
    XCTAssertEqual(segment.text, "The current weather in London, United Kingdom is cloudy.")
    XCTAssertEqual(segment.startIndex, 0)
    XCTAssertEqual(segment.partIndex, 0)
    XCTAssertEqual(segment.endIndex, 56)
    XCTAssertEqual(firstSupport.groundingChunkIndices, [0])
  }
  /// Verifies that a model configured with the Google Search tool can complete a unary
  /// `generateContent` call without throwing.
  func testGenerateContent_withGoogleSearchTool() async throws {
    let searchEnabledModel = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: [.googleSearch()],
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-success-basic-reply-short",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    // The test passes as long as the call completes without an error.
    _ = try await searchEnabledModel.generateContent(testPrompt)
  }
  /// Verifies that an invalid-API-key response makes `generateContent` throw
  /// `GenerateContentError.internalError` wrapping a `BackendError`. The error must expose
  /// the HTTP status, RPC status and message, and bridge to `NSError` with the SDK's
  /// error domain.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      // Assert against the same constant the mock handler was configured with.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Verifies that a 403 "API not enabled" response makes `generateContent` throw
  /// `GenerateContentError.internalError` wrapping a `BackendError` that is recognized as a
  /// "Vertex AI in Firebase service disabled" error.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-firebasevertexai-api-not-enabled",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(error.message
        .starts(with: "Vertex AI in Firebase API has not been used in project"))
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Verifies that a response with empty candidate content throws
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.emptyContent`,
  /// whose own underlying error is a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: invalidCandidateError as InvalidCandidateError) {
      if case let .emptyContent(decodingError) = invalidCandidateError {
        _ = try XCTUnwrap(
          decodingError as? DecodingError,
          "Not a DecodingError: \(decodingError)"
        )
      } else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(invalidCandidateError)")
      }
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a response stopped for safety throws
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety`, and that the
  /// partial response text is still exposed.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Include the actual error so an unexpected failure is diagnosable.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that a safety-stopped response with no content throws
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety` and a `nil` text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Include the actual error so an unexpected failure is diagnosable.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that a rejected-image response throws `GenerateContentError.internalError`
  /// wrapping a `BackendError` with status `.invalidArgument` and the expected HTTP code.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Use the same constant the assertions check, so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a prompt blocked for safety throws `GenerateContentError.promptBlocked`
  /// with block reason `.safety`, no text and no block-reason message.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertNil(promptFeedback.blockReasonMessage)
    } catch {
      // Include the actual error so an unexpected failure is diagnosable.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that a prompt blocked for safety with an explanatory message throws
  /// `GenerateContentError.promptBlocked` and exposes that message.
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety-with-message",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
    } catch {
      // Include the actual error so an unexpected failure is diagnosable.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized finish reason is preserved as a raw-value
  /// `FinishReason` and reported through `GenerateContentError.responseStoppedEarly`.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Include the actual error so an unexpected failure is diagnosable.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized block reason is preserved as a raw-value
  /// `PromptFeedback.BlockReason` and reported through `GenerateContentError.promptBlocked`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
    } catch {
      // Include the actual error so an unexpected failure is diagnosable.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that requesting an unknown model throws `GenerateContentError.internalError`
  /// wrapping a `BackendError` with status `.notFound` and the expected HTTP code.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Use the same constant the assertions check, so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a non-HTTP response throws `GenerateContentError.internalError` whose
  /// underlying error is an `NSURLErrorDomain` `badServerResponse` error.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()

    var receivedResponse: GenerateContentResponse?
    var thrownError: Error?
    do {
      receivedResponse = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(receivedResponse)
    XCTAssertNotNil(thrownError)
    let generateContentError = try XCTUnwrap(thrownError as? GenerateContentError)
    if case let .internalError(underlying) = generateContentError {
      XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
      let bridgedError = underlying as NSError
      XCTAssertEqual(bridgedError.domain, NSURLErrorDomain)
      XCTAssertEqual(bridgedError.code, URLError.Code.badServerResponse.rawValue)
    } else {
      XCTFail("Not an internal error: \(generateContentError)")
    }
  }
  /// Verifies that an undecodable response throws `GenerateContentError.internalError`
  /// wrapping a `DecodingError.dataCorrupted` that names `GenerateContentResponse`.
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-invalid-response",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    var receivedResponse: GenerateContentResponse?
    var thrownError: Error?
    do {
      receivedResponse = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(receivedResponse)
    XCTAssertNotNil(thrownError)
    let generateContentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }
    let decodingError = try XCTUnwrap(underlying as? DecodingError)
    switch decodingError {
    case let .dataCorrupted(context):
      XCTAssertTrue(
        context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse")
      )
    default:
      XCTFail("Not a data corrupted error: \(decodingError)")
    }
  }
  /// Verifies that a candidate whose `content` lacks `parts` throws
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.emptyContent`,
  /// whose own underlying error is a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    // Although this fixture has no `parts` in `content`, it is not actually malformed: the
    // `invalid-field` it contains could be added to the proto API as a non-breaking change.
    // The test therefore expects an `emptyContent` error rather than a decoding failure.
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    var receivedResponse: GenerateContentResponse?
    var thrownError: Error?
    do {
      receivedResponse = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(receivedResponse)
    XCTAssertNotNil(thrownError)
    let generateContentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }
    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    if case let .emptyContent(nestedError) = candidateError {
      _ = try XCTUnwrap(
        nestedError as? DecodingError,
        "Not a decoding error: \(nestedError)"
      )
    } else {
      XCTFail("Not an empty content error: \(candidateError)")
    }
  }
  /// Verifies that a response with no safety ratings still decodes: prompt feedback has an
  /// empty ratings list, and the generated text is available.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-success-missing-safety-ratings",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    let response = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(response.text, "This is the generated content.")
  }
  /// Verifies that a custom `RequestOptions.timeout` is honored. The mock handler is
  /// configured with the same timeout, and the call must return a single candidate.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let timeoutSeconds = 150.0
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: timeoutSeconds
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: timeoutSeconds),
      urlSession: urlSession
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
  }
  982. // MARK: - Generate Content (Streaming)
  /// Verifies that streaming with an invalid API key yields no content and throws
  /// `GenerateContentError.internalError` wrapping a `BackendError` that bridges to
  /// `NSError` with the SDK's error domain.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    var caughtBackendError = false
    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      caughtBackendError = true
      XCTAssertEqual(backendError.httpResponseCode, 400)
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")
      let description = backendError.localizedDescription
      XCTAssertTrue(description.contains(backendError.message))
      XCTAssertTrue(description.contains(backendError.status.rawValue))
      XCTAssertTrue(description.contains("\(backendError.httpResponseCode)"))
      let bridgedError = backendError as NSError
      XCTAssertEqual(bridgedError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridgedError.code, backendError.httpResponseCode)
    }

    if !caughtBackendError {
      XCTFail("Should have caught an error.")
    }
  }
  /// Verifies that streaming against a project with the API disabled throws
  /// `GenerateContentError.internalError` wrapping a `BackendError` that is recognized as a
  /// "Vertex AI in Firebase service disabled" error.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let statusCode = 403
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: statusCode
    )

    var caughtBackendError = false
    do {
      let stream = try model.generateContentStream(testPrompt)
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      caughtBackendError = true
      XCTAssertEqual(backendError.httpResponseCode, statusCode)
      XCTAssertEqual(backendError.status, .permissionDenied)
      let expectedPrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(backendError.message.starts(with: expectedPrefix))
      XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
    }

    if !caughtBackendError {
      XCTFail("Should have caught an error.")
    }
  }
  /// Verifies that a stream chunk with empty content yields nothing and throws
  /// `GenerateContentError.internalError` wrapping an `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without the expected error.
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // The underlying error type is all this test checks.
    }
  }
  /// Verifies that a stream stopped for safety throws
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety`. The attached
  /// candidate must carry the same finish reason and at least one blocked safety rating.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var caughtStoppedEarly = false
    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      caughtStoppedEarly = true
      XCTAssertEqual(reason, .safety)
      let firstCandidate = try XCTUnwrap(response.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, reason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains(where: \.blocked))
    }

    if !caughtStoppedEarly {
      XCTFail("Should have caught an error.")
    }
  }
  /// Verifies that a streamed prompt blocked for safety throws
  /// `GenerateContentError.promptBlocked` with block reason `.safety` and no message.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var blockedFeedback: PromptFeedback?
    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(response) {
      blockedFeedback = try XCTUnwrap(response.promptFeedback)
    }

    guard let feedback = blockedFeedback else {
      XCTFail("Should have caught an error.")
      return
    }
    XCTAssertEqual(feedback.blockReason, .safety)
    XCTAssertNil(feedback.blockReasonMessage)
  }
  /// Verifies that a streamed prompt blocked for safety with an explanatory message throws
  /// `GenerateContentError.promptBlocked` and exposes that message.
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety-with-message",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var blockedFeedback: PromptFeedback?
    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(response) {
      blockedFeedback = try XCTUnwrap(response.promptFeedback)
    }

    guard let feedback = blockedFeedback else {
      XCTFail("Should have caught an error.")
      return
    }
    XCTAssertEqual(feedback.blockReason, .safety)
    XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
  }
  /// Verifies that a stream ending with an unrecognized finish reason throws
  /// `GenerateContentError.responseStoppedEarly` carrying the raw-value `FinishReason`.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedReason = FinishReason(rawValue: "FAKE_ENUM")
    let stream = try model.generateContentStream("Hi")

    var observedReason: FinishReason?
    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
      }
    } catch let GenerateContentError.responseStoppedEarly(reason, _) {
      observedReason = reason
    }

    guard let reason = observedReason else {
      XCTFail("Should have caught an error.")
      return
    }
    XCTAssertEqual(reason, expectedReason)
  }
  /// Verifies that the long basic-reply stream yields four chunks, each with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let stream = try model.generateContentStream("Hi")
    var chunkCount = 0
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 4)
  }
  /// Verifies that the short basic-reply stream yields exactly one chunk with text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let stream = try model.generateContentStream("Hi")
    var chunkCount = 0
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  /// Verifies that safety ratings with unrecognized category or probability values decode
  /// into raw-value enums instead of failing the stream.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    let stream = try model.generateContentStream("Hi")
    var sawExpectedRating = false
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      if chunk.candidates.first?.safetyRatings.contains(expectedRating) == true {
        sawExpectedRating = true
      }
    }

    XCTAssertTrue(sawExpectedRating)
  }
  /// Verifies that citation metadata is collected across streamed chunks. The stream ends
  /// with `.stop` and yields six citations, including three specific ones. No citation may
  /// have an empty `uri`, `title` or `license` string.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    let stream = try model.generateContentStream("Hi")
    var allCitations: [Citation] = []
    var streamedResponses: [GenerateContentResponse] = []
    for try await chunk in stream {
      streamedResponses.append(chunk)
      XCTAssertNotNil(chunk.text)
      let chunkCandidate = try XCTUnwrap(chunk.candidates.first)
      allCitations += chunkCandidate.citationMetadata?.citations ?? []
    }

    let finalCandidate = try XCTUnwrap(streamedResponses.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(allCitations.count, 6)
    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil && citation.publicationDate == nil
    })
    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
        && citation.publicationDate == expectedPublicationDate
    })
    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit" && citation.publicationDate == nil
    })
    // Optional fields may be absent (`nil`) but never present-and-empty.
    XCTAssertFalse(allCitations.contains { $0.uri == "" })
    XCTAssertFalse(allCitations.contains { $0.title == "" })
    XCTAssertFalse(allCitations.contains { $0.license == "" })
  }
  /// Verifies that a thinking-model stream separates thought-summary parts from answer
  /// text. Each chunk has exactly one `TextPart`; thought parts feed `thoughtSummary`.
  func testGenerateContentStream_successWithThinking_thoughtSummary() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "streaming-success-thinking-reply-thought-summary",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )

    var collectedThoughts = ""
    var collectedAnswer = ""
    let stream = try model.generateContentStream("Hi")
    for try await chunk in stream {
      let chunkCandidate = try XCTUnwrap(chunk.candidates.first)
      XCTAssertEqual(chunkCandidate.content.parts.count, 1)
      let onlyPart = try XCTUnwrap(chunkCandidate.content.parts.first)
      let textPart = try XCTUnwrap(onlyPart as? TextPart)
      guard textPart.isThought else {
        collectedAnswer.append(textPart.text)
        continue
      }
      let thought = try XCTUnwrap(chunk.thoughtSummary)
      collectedThoughts.append(thought)
    }

    XCTAssertTrue(collectedThoughts.hasPrefix("**Understanding the Core Question**"))
    XCTAssertTrue(collectedAnswer.hasPrefix("The sky is blue due to a phenomenon"))
  }
  /// Verifies that invalid safety ratings in an image stream are ignored. The first chunk
  /// must keep only the ratings in `safetyRatingsInvalidIgnored` and carry one non-empty
  /// PNG inline-data part.
  func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "streaming-success-image-invalid-safety-ratings",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )

    let stream = try model.generateContentStream(testPrompt)
    var received: [GenerateContentResponse] = []
    for try await chunk in stream {
      received.append(chunk)
    }

    let firstResponse = try XCTUnwrap(received.first)
    XCTAssertEqual(firstResponse.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(firstResponse.candidates.first)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let imageParts = firstResponse.inlineDataParts
    XCTAssertEqual(imageParts.count, 1)
    let image = try XCTUnwrap(imageParts.first)
    XCTAssertEqual(image.mimeType, "image/png")
    XCTAssertGreaterThan(image.data.count, 0)
  }
  /// Verifies that streaming succeeds when App Check provides a valid token. The mock
  /// handler is configured to expect that token.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let validToken = "test-valid-token"
    let appCheck = AppCheckInteropFake(token: validToken)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: appCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: validToken
    )

    // Drain the stream; the test passes if no error is thrown.
    let responseStream = try model.generateContentStream(testPrompt)
    for try await _ in responseStream {}
  }
  /// Verifies that streaming still succeeds when the App Check token refresh fails. The
  /// mock handler is configured to expect `AppCheckInteropFake.placeholderTokenValue`.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream; the test passes if no error is thrown.
    let responseStream = try model.generateContentStream(testPrompt)
    for try await _ in responseStream {}
  }
  /// Verifies that only the final streamed response carries usage metadata, and that its
  /// token counts match the `streaming-success-basic-reply-short` fixture.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Without this check an empty stream would skip every assertion below and pass vacuously.
    let lastResponse = try XCTUnwrap(responses.last, "Expected at least one streamed response.")

    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }
    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Streams the `streaming-failure-error-mid-stream` fixture and expects the stream to deliver
  /// two responses before failing with an internal `BackendError` (HTTP 499, `.cancelled`).
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    var responseCount = 0
    do {
      let stream = try model.generateContentStream("Hi")
      for try await content in stream {
        XCTAssertNotNil(content.text)
        responseCount += 1
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)
      // Responses received before the error must still have been delivered.
      XCTAssertEqual(responseCount, 2)
      return
    }
    // Reached only if the stream finished without throwing; other error types propagate.
    XCTFail("Expected an internalError with a BackendError.")
  }
  /// Installs `GenerativeModelTestUtil.nonHTTPRequestHandler()` and expects the stream to fail
  /// with an internal error whose description is "Response was not an HTTP response."
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()
    let stream = try model.generateContentStream("Hi")
    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch let GenerateContentError.internalError(underlyingError) {
      XCTAssertEqual(underlyingError.localizedDescription, "Response was not an HTTP response.")
      return
    }
    // Only reached when the stream completes without throwing.
    XCTFail("Expected an internal error.")
  }
  /// Streams the `streaming-failure-invalid-json` fixture and expects an internal
  /// `DecodingError.dataCorrupted` whose description starts with
  /// "Failed to decode GenerateContentResponse".
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(context) = decodingError {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
      return
    }
    XCTFail("Expected an internal error.")
  }
  /// Streams the `streaming-failure-malformed-content` fixture and expects an internal
  /// `InvalidCandidateError.emptyContent` that wraps a `DecodingError`.
  ///
  /// Note: although this file does not contain `parts` in `content`, it is not actually
  /// malformed. The `invalid-field` in the payload could be added as a non-breaking change to
  /// the proto API. Therefore, this test checks for the `emptyContent` error instead.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      if case let .emptyContent(contentError) = candidateError {
        XCTAssert(contentError is DecodingError)
      } else {
        XCTFail("Not an empty content error: \(candidateError)")
      }
      return
    }
    XCTFail("Expected an internal decoding error.")
  }
  /// Streams with a model configured with a 150-second `RequestOptions` timeout.
  ///
  /// The same timeout is passed to the mock request handler, and exactly one response with
  /// non-nil text is expected from the fixture.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    var receivedCount = 0
    let stream = try model.generateContentStream(testPrompt)
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }
    XCTAssertEqual(receivedCount, 1)
  }
  1460. // MARK: - Count Tokens
  /// Counts tokens for a text prompt against the `unary-success-total-tokens` fixture,
  /// which reports a total of 6 tokens.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  /// Counts tokens against the `unary-success-detailed-token-response` fixture and checks the
  /// per-modality breakdown: 1806 image tokens followed by 31 text tokens (1837 total).
  func testCountTokens_succeeds_detailed() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-detailed-token-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let response = try await model.countTokens("Why is the sky blue?")
    XCTAssertEqual(response.totalTokens, 1837)

    let details = response.promptTokensDetails
    XCTAssertEqual(details.count, 2)
    // Stop before subscripting: an out-of-range index traps and aborts the whole test process
    // instead of reporting the failed count assertion above.
    guard details.count == 2 else { return }
    XCTAssertEqual(details[0].modality, .image)
    XCTAssertEqual(details[0].tokenCount, 1806)
    XCTAssertEqual(details[1].modality, .text)
    XCTAssertEqual(details[1].tokenCount, 31)
  }
  /// Counts tokens using a model configured with a generation config, a function-calling tool,
  /// and a system instruction; the fixture reports a total of 6 tokens.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    // A single "sum" function exposed to the model as a tool.
    let sumFunction = FunctionDeclaration(
      name: "sum",
      description: "Add two integers.",
      parameters: ["x": .integer(), "y": .integer()]
    )
    let calculatorTools = [Tool(functionDeclarations: [sumFunction])]

    let generationConfig = GenerationConfig(
      temperature: 0.5,
      topP: 0.9,
      topK: 3,
      candidateCount: 1,
      maxOutputTokens: 1024,
      stopSequences: ["test-stop"],
      responseMIMEType: "text/plain"
    )
    let systemInstruction = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      generationConfig: generationConfig,
      tools: calculatorTools,
      systemInstruction: systemInstruction,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")
    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  /// Counts tokens for an empty JPEG inline-data part against the
  /// `unary-success-no-billable-characters` fixture, which reports 258 tokens.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let emptyImage = InlineDataPart(data: Data(), mimeType: "image/jpeg")

    let tokenCount = try await model.countTokens(emptyImage)

    XCTAssertEqual(tokenCount.totalTokens, 258)
  }
  /// A 404 response from the `unary-failure-model-not-found` fixture must surface from
  /// `countTokens` as a `BackendError` with status `.notFound`.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: 404
    )
    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch let backendError as BackendError {
      XCTAssertEqual(backendError.httpResponseCode, 404)
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssert(backendError.message.hasPrefix("models/test-model-name is not found"))
    }
    // Errors of any other type are not caught above; they propagate out of this `throws`
    // test, and XCTest reports them as a failure.
  }
  /// Counts tokens with a model configured with a 150-second `RequestOptions` timeout.
  ///
  /// The same timeout is passed to the mock request handler; the fixture reports 6 tokens.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)
    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1570. }
  /// Test-only ordering for `SafetyRating`: ratings are ordered solely by their category's
  /// raw value. Other fields are ignored.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
      let (leftKey, rightKey) = (lhs.category.rawValue, rhs.category.rawValue)
      return leftKey < rightKey
    }
  }