GenerativeModelVertexAITests.swift 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelVertexAITests: XCTestCase {
  /// Prompt passed to `generateContent(_:)` by the tests in this suite.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Expected safety ratings for the `unary-success-basic-reply-short` mock, pre-sorted so tests
  /// can compare them against `candidate.safetyRatings.sorted()` regardless of JSON order.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()

  /// Expected safety ratings for the `unary-success-image-invalid-safety-ratings` mock, pre-sorted.
  /// That mock also contains four empty (`{}`) ratings, which are expected to be ignored.
  let safetyRatingsInvalidIgnored: [SafetyRating] = [
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.00039444832,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.0010654529,
      severity: .negligible,
      severityScore: 0.0049325973,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.00026658305,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.0013701695,
      severity: .negligible,
      severityScore: 0.07626295,
      blocked: false
    ),
  ].sorted()

  /// Short model name handed to every `GenerativeModel` built in this suite.
  let testModelName = "test-model"
  /// Fully-qualified Vertex AI resource name whose final component is `testModelName`.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"
  /// API configuration shared by every model in this suite (Vertex AI defaults).
  let apiConfig = FirebaseAI.defaultVertexAIAPIConfig
  /// Subdirectory passed to `GenerativeModelTestUtil.httpRequestHandler` to locate Vertex AI mocks.
  let vertexSubdirectory = "mock-responses/vertexai"

  /// Session whose traffic is routed through `MockURLProtocol`; rebuilt in `setUp()`.
  var urlSession: URLSession!
  /// Default model under test; rebuilt in `setUp()` and replaced by some individual tests.
  var model: GenerativeModel!
  /// Builds a fresh mock-backed `URLSession` and a default `GenerativeModel` before each test.
  override func setUp() async throws {
    try await super.setUp()

    // Every request made through this session is served by `MockURLProtocol`; each test installs
    // the handler (and therefore the canned response) it needs.
    let configuration = URLSessionConfiguration.default
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is non-failable, so there is nothing to unwrap here.
    urlSession = URLSession(configuration: configuration)

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
  }
  /// Clears per-test global state and releases the per-test session after each test.
  override func tearDown() {
    // `requestHandler` is static, so a stale stub would otherwise leak into the next test.
    MockURLProtocol.requestHandler = nil
    // `setUp()` creates a new session for every test; invalidate the old one so it is released.
    urlSession?.invalidateAndCancel()
    super.tearDown()
  }
  115. // MARK: - Generate Content
  /// Decodes the long basic reply: one text part, a `.stop` finish reason, and four safety ratings.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let textPart = try XCTUnwrap(firstPart as? TextPart)
    let replyText = textPart.text
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))

    // The convenience accessors must agree with the single text part and report nothing else.
    XCTAssertEqual(response.text, replyText)
    XCTAssertEqual(response.functionCalls, [])
    XCTAssertEqual(response.inlineDataParts, [])
  }
  /// Decodes the short basic reply and checks its finish reason, safety ratings, and text.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    // Sort before comparing so rating order in the mock JSON does not matter.
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let replyText = try XCTUnwrap(firstPart as? TextPart).text
    XCTAssertEqual(replyText, "Mountain View, California")
    XCTAssertEqual(response.text, replyText)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Decodes a reply whose usage metadata includes per-modality token breakdowns.
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-response-long-usage-metadata",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    let finishReason = try XCTUnwrap(candidate.finishReason)
    XCTAssertEqual(finishReason, .stop)
    let usageMetadata = try XCTUnwrap(response.usageMetadata)

    // Unwrap entries by position instead of subscripting: a failed count assertion does not stop
    // the test, so `details[1]` on a short array would trap instead of failing cleanly.
    let promptDetails = usageMetadata.promptTokensDetails
    XCTAssertEqual(promptDetails.count, 2)
    let promptImageDetails = try XCTUnwrap(promptDetails.first)
    XCTAssertEqual(promptImageDetails.modality, .image)
    XCTAssertEqual(promptImageDetails.tokenCount, 1806)
    let promptTextDetails = try XCTUnwrap(promptDetails.dropFirst().first)
    XCTAssertEqual(promptTextDetails.modality, .text)
    XCTAssertEqual(promptTextDetails.tokenCount, 76)

    let candidatesDetails = usageMetadata.candidatesTokensDetails
    XCTAssertEqual(candidatesDetails.count, 1)
    let candidateTextDetails = try XCTUnwrap(candidatesDetails.first)
    XCTAssertEqual(candidateTextDetails.modality, .text)
    XCTAssertEqual(candidateTextDetails.tokenCount, 76)
  }
  /// Decodes citation metadata with three sources, each populating a different subset of fields.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citations = citationMetadata.citations
    XCTAssertEqual(citations.count, 3)

    // Unwrap by position rather than subscripting: the count assertion above does not stop the
    // test, so `citations[2]` on a short array would trap instead of failing cleanly.
    let citationSource1 = try XCTUnwrap(citations.first)
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)
    XCTAssertNil(citationSource1.publicationDate)

    let citationSource2 = try XCTUnwrap(citations.dropFirst(1).first)
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.publicationDate, expectedPublicationDate)
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = try XCTUnwrap(citations.dropFirst(2).first)
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
    XCTAssertNil(citationSource3.publicationDate)
  }
  /// Decodes a quote reply and checks both candidate-level and prompt-level safety feedback.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let quoteText = try XCTUnwrap(firstPart as? TextPart).text
    XCTAssertTrue(quoteText.hasPrefix("Google"))
    XCTAssertEqual(response.text, quoteText)

    // The prompt was not blocked but still carries its own safety ratings.
    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Unrecognized enum raw values in safety ratings must decode into raw-value cases, not fail.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    // Each rating exercises an unknown/unspecified raw value in a different field.
    let unspecifiedSeverity = "HARM_SEVERITY_UNSPECIFIED"
    let expectedSafetyRatings = [
      SafetyRating(
        category: .harassment,
        probability: .medium,
        probabilityScore: 0.0,
        severity: .init(rawValue: unspecifiedSeverity),
        severityScore: 0.0,
        blocked: false
      ),
      SafetyRating(
        category: .dangerousContent,
        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
        probabilityScore: 0.0,
        severity: .init(rawValue: unspecifiedSeverity),
        severityScore: 0.0,
        blocked: false
      ),
      SafetyRating(
        category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
        probability: .high,
        probabilityScore: 0.0,
        severity: .init(rawValue: unspecifiedSeverity),
        severityScore: 0.0,
        blocked: false
      ),
    ]
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    // The same ratings appear on both the candidate and the prompt feedback.
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedSafetyRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedSafetyRatings)
  }
  /// Sends a request through a separately constructed model and expects it to succeed.
  ///
  /// NOTE(review): despite the name, this model uses the same un-prefixed `testModelName` and
  /// `testModelResourceName` as `setUp()`, so no prefixed model name is exercised — confirm the
  /// intended input.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let localModel = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    // Success is the assertion: any thrown error fails the test.
    _ = try await localModel.generateContent(testPrompt)
  }
  /// A function call whose `args` object is empty decodes with an empty argument dictionary.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // `XCTUnwrap` fails the test and stops it if the part is not a function call, matching how the
    // other tests unwrap `TextPart`.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// A function call with no `args` field decodes with an empty argument dictionary.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // `XCTUnwrap` fails the test and stops it if the part is not a function call, matching how the
    // other tests unwrap `TextPart`.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// A function call with numeric arguments decodes each argument as a JSON number.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // `XCTUnwrap` fails the test and stops it if the part is not a function call, matching how the
    // other tests unwrap `TextPart`.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "sum")
    XCTAssertEqual(functionCall.args.count, 2)
    let argX = try XCTUnwrap(functionCall.args["x"])
    XCTAssertEqual(argX, .number(4))
    let argY = try XCTUnwrap(functionCall.args["y"])
    XCTAssertEqual(argY, .number(5))
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// A reply with three parallel function calls exposes all three via `functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// A reply mixing text and function-call parts: calls are collected and text is concatenated.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    // Four parts total, two of which are function calls.
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(response.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// An image reply whose mock contains empty safety ratings: the empties are ignored and the
  /// inline PNG data is exposed via `inlineDataParts`.
  func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-image-invalid-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)

    let imageParts = response.inlineDataParts
    XCTAssertEqual(imageParts.count, 1)
    let pngPart = try XCTUnwrap(imageParts.first)
    XCTAssertEqual(pngPart.mimeType, "image/png")
    XCTAssertGreaterThan(pngPart.data.count, 0)
  }
  /// A valid App Check token from the fake provider is expected by the mock handler.
  func testGenerateContent_appCheck_validToken() async throws {
    let token = "test-valid-token"
    let appCheckInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: appCheckInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With `privateAppID: true`, the mock handler is configured to expect `dataCollection: false`.
  func testGenerateContent_dataCollectionOff() async throws {
    let token = "test-valid-token"
    let privateAppInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token), privateAppID: true
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: privateAppInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token,
      dataCollection: false
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the App Check token refresh fails, the request is still expected to be sent, carrying
  /// `AppCheckInteropFake.placeholderTokenValue`, and to succeed.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheckInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(error: AppCheckErrorFake())
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: failingAppCheckInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// A valid Firebase Auth token from the fake provider is expected by the mock handler.
  func testGenerateContent_auth_validAuthToken() async throws {
    let token = "test-valid-token"
    let authInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(token: token)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: authInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// A fake Auth provider that yields no token; the mock handler is configured with
  /// `authToken: nil` and the request is expected to succeed.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let tokenlessAuthInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(token: nil)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: tokenlessAuthInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An Auth token refresh failure must surface as `GenerateContentError.internalError` wrapping
  /// the underlying `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuthInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(error: AuthErrorFake())
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: failingAuthInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected: the Auth failure was wrapped and rethrown.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// The short basic reply carries aggregate token counts but no per-modality breakdowns.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
    XCTAssertTrue(usage.promptTokensDetails.isEmpty)
    XCTAssertTrue(usage.candidatesTokensDetails.isEmpty)
  }
  /// Decodes Google Search grounding metadata: queries, entry point, chunks, and supports.
  func testGenerateContent_groundingMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-google-search-grounding",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let grounding = try XCTUnwrap(onlyCandidate.groundingMetadata)
    XCTAssertEqual(grounding.webSearchQueries, ["current weather in London"])
    XCTAssertNotNil(grounding.searchEntryPoint)
    XCTAssertNotNil(grounding.searchEntryPoint?.renderedContent)

    XCTAssertEqual(grounding.groundingChunks.count, 2)
    let webChunk = try XCTUnwrap(grounding.groundingChunks.first?.web)
    XCTAssertEqual(webChunk.title, "accuweather.com")
    XCTAssertNotNil(webChunk.uri)
    // The mock's first web chunk is expected to have no `domain`. NOTE(review): an earlier comment
    // attributed this to the Google AI backend, but this suite targets Vertex AI — confirm.
    XCTAssertNil(webChunk.domain)

    XCTAssertEqual(grounding.groundingSupports.count, 3)
    let support = try XCTUnwrap(grounding.groundingSupports.first)
    let segment = try XCTUnwrap(support.segment)
    XCTAssertEqual(segment.text, "The current weather in London, United Kingdom is cloudy.")
    XCTAssertEqual(segment.startIndex, 0)
    XCTAssertEqual(segment.partIndex, 0)
    XCTAssertEqual(segment.endIndex, 56)
    XCTAssertEqual(support.groundingChunkIndices, [0])
  }
  /// A model configured with the `.googleSearch()` tool can still complete a request.
  ///
  /// NOTE(review): nothing here asserts that the tool is serialized into the request; confirm
  /// whether `httpRequestHandler` checks it.
  func testGenerateContent_withGoogleSearchTool() async throws {
    let searchModel = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: [.googleSearch()],
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    _ = try await searchModel.generateContent(testPrompt)
  }
  /// An invalid-API-key response (HTTP 400) surfaces as `GenerateContentError.internalError`
  /// wrapping a `BackendError` whose fields, description, and `NSError` bridging are checked.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      // Compare against the same constant the mock was configured with.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Verifies that a 403 "API not enabled" response surfaces as a `.permissionDenied`
  /// `BackendError` that `isVertexAIInFirebaseServiceDisabledError()` recognizes.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-firebasevertexai-api-not-enabled",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(error.message
        .starts(with: "Vertex AI in Firebase API has not been used in project"))
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      // The matched type is `BackendError`; the old message referred to the stale `RPCError`.
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Verifies that a response whose candidate has no content throws an internal
  /// `InvalidCandidateError.emptyContent` that itself wraps a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch GenerateContentError
      .internalError(underlying: let invalidCandidateError as InvalidCandidateError) {
      if case let .emptyContent(decodingError) = invalidCandidateError {
        _ = try XCTUnwrap(decodingError as? DecodingError, "Not a DecodingError: \(decodingError)")
      } else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(invalidCandidateError)")
      }
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a unary response stopped for safety throws `responseStoppedEarly` and still
  /// exposes the partial response, whose text is `"<redacted>"` in this fixture.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch GenerateContentError.responseStoppedEarly(let stopReason, let partialResponse) {
      XCTAssertEqual(stopReason, .safety)
      XCTAssertEqual(partialResponse.text, "<redacted>")
    } catch {
      XCTFail("Should throw a responseStoppedEarly")
    }
  }
  /// Verifies that a safety stop with no candidate content still throws `responseStoppedEarly`,
  /// and the attached response has no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety-no-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch GenerateContentError.responseStoppedEarly(let stopReason, let partialResponse) {
      XCTAssertEqual(stopReason, .safety)
      XCTAssertNil(partialResponse.text)
    } catch {
      XCTFail("Should throw a responseStoppedEarly")
    }
  }
  /// Verifies that a rejected image request (HTTP 400) surfaces as an `.invalidArgument`
  /// `BackendError` carrying the backend's message.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Configure the mock from the same constant the assertions check against.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a prompt blocked for safety throws `promptBlocked` with no candidate text and
  /// a `.safety` block reason that has no accompanying message.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch GenerateContentError.promptBlocked(let blockedResponse) {
      XCTAssertNil(blockedResponse.text)
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
    } catch {
      XCTFail("Should throw a promptBlocked")
    }
  }
  /// Verifies that a safety-blocked prompt carrying a block reason message surfaces that message
  /// (`"Reasons"` in this fixture) on the thrown `promptBlocked` error.
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety-with-message",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch GenerateContentError.promptBlocked(let blockedResponse) {
      XCTAssertNil(blockedResponse.text)
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
    } catch {
      XCTFail("Should throw a promptBlocked")
    }
  }
  /// Verifies that an unrecognized finish reason is reported on `responseStoppedEarly` as a
  /// `FinishReason` built from its raw value, alongside the partial response text.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-unknown-enum-finish-reason",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler
    let expectedReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch GenerateContentError.responseStoppedEarly(let stopReason, let partialResponse) {
      XCTAssertEqual(stopReason, expectedReason)
      XCTAssertEqual(partialResponse.text, "Some text")
    } catch {
      XCTFail("Should throw a responseStoppedEarly")
    }
  }
  /// Verifies that an unrecognized block reason is reported on `promptBlocked` as a
  /// `PromptFeedback.BlockReason` built from its raw value.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-unknown-enum-prompt-blocked",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler
    let expectedBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch GenerateContentError.promptBlocked(let blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, expectedBlockReason)
    } catch {
      XCTFail("Should throw a promptBlocked")
    }
  }
  /// Verifies that requesting an unknown model (HTTP 404) surfaces as a `.notFound`
  /// `BackendError` whose message starts with the missing model's name.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Configure the mock from the same constant the assertions check against.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(backendError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a non-HTTP URL response is reported as an internal error whose underlying
  /// error bridges to `NSURLErrorDomain` / `URLError.badServerResponse`.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()

    // Capture either the response or the thrown error; exactly one is expected to be set.
    let content: GenerateContentResponse?
    let responseError: Error?
    do {
      content = try await model.generateContent(testPrompt)
      responseError = nil
    } catch {
      content = nil
      responseError = error
    }
    XCTAssertNil(content)
    XCTAssertNotNil(responseError)

    let generateContentError = try XCTUnwrap(responseError as? GenerateContentError)
    guard case let .internalError(underlyingError) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }
    XCTAssertEqual(underlyingError.localizedDescription, "Response was not an HTTP response.")

    let bridgedError = underlyingError as NSError
    XCTAssertEqual(bridgedError.domain, NSURLErrorDomain)
    XCTAssertEqual(bridgedError.code, URLError.Code.badServerResponse.rawValue)
  }
  /// Verifies that an undecodable unary response is reported as an internal
  /// `DecodingError.dataCorrupted` whose description names `GenerateContentResponse`.
  func testGenerateContent_failure_invalidResponse() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    // Capture either the response or the thrown error; exactly one is expected to be set.
    let content: GenerateContentResponse?
    let responseError: Error?
    do {
      content = try await model.generateContent(testPrompt)
      responseError = nil
    } catch {
      content = nil
      responseError = error
    }
    XCTAssertNil(content)
    XCTAssertNotNil(responseError)

    let generateContentError = try XCTUnwrap(responseError as? GenerateContentError)
    guard case let .internalError(underlyingError) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }
    let decodingError = try XCTUnwrap(underlyingError as? DecodingError)
    if case let .dataCorrupted(corruptionContext) = decodingError {
      XCTAssert(corruptionContext.debugDescription
        .hasPrefix("Failed to decode GenerateContentResponse"))
    } else {
      XCTFail("Not a data corrupted error: \(decodingError)")
    }
  }
  /// Verifies how a candidate whose `content` lacks `parts` is handled.
  ///
  /// The fixture also contains an `invalid-field`. Unknown fields could legitimately be added to
  /// the proto API as a non-breaking change, so the payload is not treated as malformed. Instead,
  /// the expected outcome is an internal `InvalidCandidateError.emptyContent` wrapping a
  /// `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    // Capture either the response or the thrown error; exactly one is expected to be set.
    let content: GenerateContentResponse?
    let responseError: Error?
    do {
      content = try await model.generateContent(testPrompt)
      responseError = nil
    } catch {
      content = nil
      responseError = error
    }
    XCTAssertNil(content)
    XCTAssertNotNil(responseError)

    let generateContentError = try XCTUnwrap(responseError as? GenerateContentError)
    guard case let .internalError(underlyingError) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }
    let invalidCandidateError = try XCTUnwrap(underlyingError as? InvalidCandidateError)
    if case let .emptyContent(emptyContentUnderlyingError) = invalidCandidateError {
      _ = try XCTUnwrap(
        emptyContentUnderlyingError as? DecodingError,
        "Not a decoding error: \(emptyContentUnderlyingError)"
      )
    } else {
      XCTFail("Not an empty content error: \(invalidCandidateError)")
    }
  }
  /// Verifies that a response whose prompt feedback omits safety ratings decodes with an empty
  /// ratings list and the expected text.
  func testGenerateContentMissingSafetyRatings() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let response = try await model.generateContent(testPrompt)
    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(response.text, "This is the generated content.")
  }
  /// Verifies that a model built with a custom `RequestOptions` timeout completes a unary request.
  /// The mock handler receives the same timeout (presumably to check it against the outgoing
  /// request; see `GenerativeModelTestUtil`).
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let timeout = 150.0
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: timeout
    )
    MockURLProtocol.requestHandler = handler

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: timeout),
      urlSession: urlSession
    )

    let response = try await model.generateContent(testPrompt)
    XCTAssertEqual(response.candidates.count, 1)
  }
  939. // MARK: - Generate Content (Streaming)
  /// Verifies that an invalid API key error is thrown from a content stream as a `BackendError`.
  /// Its description and `NSError` bridging must expose the status, message and HTTP code.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, 400)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")

      let description = error.localizedDescription
      XCTAssertTrue(description.contains(error.message))
      XCTAssertTrue(description.contains(error.status.rawValue))
      XCTAssertTrue(description.contains("\(error.httpResponseCode)"))

      let bridgedError = error as NSError
      XCTAssertEqual(bridgedError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridgedError.code, error.httpResponseCode)
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Verifies that a 403 "API not enabled" payload is thrown from a content stream as a
  /// `.permissionDenied` `BackendError` recognized by `isVertexAIInFirebaseServiceDisabledError()`.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let expectedStatusCode = 403
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream(testPrompt)
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      let messagePrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(error.message.starts(with: messagePrefix))
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Verifies that a streamed candidate with empty content throws an internal error wrapping an
  /// `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(is InvalidCandidateError) {
      // Matching the underlying error type is the whole assertion.
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Verifies that a stream stopped for safety throws `responseStoppedEarly`. The attached
  /// candidate must repeat the finish reason and include at least one blocked safety rating.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.responseStoppedEarly(let stopReason, let partialResponse) {
      XCTAssertEqual(stopReason, .safety)
      let firstCandidate = try XCTUnwrap(partialResponse.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, stopReason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains(where: \.blocked))
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Verifies that a streamed prompt blocked for safety throws `promptBlocked` with a `.safety`
  /// block reason and no block reason message.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.promptBlocked(let blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Verifies that a streamed safety-blocked prompt surfaces its block reason message
  /// (`"Reasons"` in this fixture) on the thrown `promptBlocked` error.
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety-with-message",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.promptBlocked(let blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Verifies that a stream can deliver text chunks and then stop with an unrecognized finish
  /// reason. That reason is reported as a `FinishReason` built from its raw value.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler
    let expectedReason = FinishReason(rawValue: "FAKE_ENUM")
    let stream = try model.generateContentStream("Hi")

    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
      }
    } catch GenerateContentError.responseStoppedEarly(let stopReason, _) {
      XCTAssertEqual(stopReason, expectedReason)
      return
    }
    XCTFail("Should have caught an error.")
  }
  /// Verifies that the long streaming fixture yields exactly four chunks, each carrying text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let stream = try model.generateContentStream("Hi")
    var chunkCount = 0
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }
    XCTAssertEqual(chunkCount, 4)
  }
  /// Verifies that the short streaming fixture yields exactly one chunk carrying text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let stream = try model.generateContentStream("Hi")
    var chunkCount = 0
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }
    XCTAssertEqual(chunkCount, 1)
  }
  /// Verifies that a safety rating with unrecognized category/probability values does not break
  /// decoding. The rating must appear in some chunk's first candidate, built from its raw values.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler
    let unknownSafetyRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    var foundUnknownSafetyRating = false
    let stream = try model.generateContentStream("Hi")
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      if chunk.candidates.first?.safetyRatings.contains(unknownSafetyRating) == true {
        foundUnknownSafetyRating = true
      }
    }
    XCTAssertTrue(foundUnknownSafetyRating)
  }
  /// Verifies that citations spread across streamed chunks are collected and decoded correctly.
  /// Covered: URI/title/license/publication date, a `.stop` finish on the last chunk, and no
  /// empty-string citation fields.
  func testGenerateContentStream_successWithCitations() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    let stream = try model.generateContentStream("Hi")
    var citations = [Citation]()
    var streamedResponses = [GenerateContentResponse]()
    for try await chunk in stream {
      streamedResponses.append(chunk)
      XCTAssertNotNil(chunk.text)
      let candidate = try XCTUnwrap(chunk.candidates.first)
      citations += candidate.citationMetadata?.citations ?? []
    }

    let lastCandidate = try XCTUnwrap(streamedResponses.last?.candidates.first)
    XCTAssertEqual(lastCandidate.finishReason, .stop)
    XCTAssertEqual(citations.count, 6)
    XCTAssertTrue(citations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil && citation.publicationDate == nil
    })
    XCTAssertTrue(citations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
        && citation.publicationDate == expectedPublicationDate
    })
    XCTAssertTrue(citations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit" && citation.publicationDate == nil
    })

    // Optional string fields must be either absent or non-empty.
    XCTAssertFalse(citations.contains { $0.uri?.isEmpty ?? false })
    XCTAssertFalse(citations.contains { $0.title?.isEmpty ?? false })
    XCTAssertFalse(citations.contains { $0.license?.isEmpty ?? false })
  }
  /// Verifies that invalid safety ratings in a streamed image response are dropped. The result
  /// must match `safetyRatingsInvalidIgnored`, and the single inline PNG part must be preserved.
  func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-image-invalid-safety-ratings",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let stream = try model.generateContentStream(testPrompt)
    var streamedResponses = [GenerateContentResponse]()
    for try await chunk in stream {
      streamedResponses.append(chunk)
    }

    let firstResponse = try XCTUnwrap(streamedResponses.first)
    XCTAssertEqual(firstResponse.candidates.count, 1)
    let candidate = try XCTUnwrap(firstResponse.candidates.first)
    XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
    XCTAssertEqual(candidate.content.parts.count, 1)

    let imageParts = firstResponse.inlineDataParts
    XCTAssertEqual(imageParts.count, 1)
    let imagePart = try XCTUnwrap(imageParts.first)
    XCTAssertEqual(imagePart.mimeType, "image/png")
    XCTAssertGreaterThan(imagePart.data.count, 0)
  }
  /// Verifies that a streaming request succeeds when App Check provides a valid token. The mock
  /// handler is given the same token (presumably to match it against the request; see
  /// `GenerativeModelTestUtil`).
  func testGenerateContentStream_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: appCheckFake),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: appCheckToken
    )
    MockURLProtocol.requestHandler = handler

    // Drain the stream; success means no error was thrown.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Verifies that a streaming request still completes when the App Check token refresh fails.
  /// The mock handler is given `AppCheckInteropFake.placeholderTokenValue` as the expected token.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )
    MockURLProtocol.requestHandler = handler

    // Drain the stream; success means no error was thrown.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Verifies that usage metadata (prompt/candidates/total token counts) is attached only to the
  /// final chunk of a streamed response.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "streaming-success-basic-reply-short",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )

    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Fail explicitly on an empty stream; iterating an empty array would otherwise pass vacuously.
    let lastResponse = try XCTUnwrap(responses.last, "Expected at least one streamed response.")

    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }

    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Verifies that an error emitted partway through a stream is thrown as a `BackendError`
  /// (HTTP 499, `.cancelled`) after the preceding chunks have been delivered.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var responseCount = 0
    do {
      let stream = try model.generateContentStream("Hi")
      for try await content in stream {
        XCTAssertNotNil(content.text)
        responseCount += 1
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)
      // The fixture delivers two content chunks before the error.
      XCTAssertEqual(responseCount, 2)
      return
    }
    // The matched type is `BackendError`; the old message referred to the stale `RPCError`.
    XCTFail("Expected an internalError with a BackendError.")
  }
  /// Verifies that a non-HTTP URL response makes the content stream throw an internal error
  /// with a descriptive message.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    let handler = try GenerativeModelTestUtil.nonHTTPRequestHandler()
    MockURLProtocol.requestHandler = handler
    let stream = try model.generateContentStream("Hi")

    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch GenerateContentError.internalError(let underlying) {
      XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
      return
    }
    XCTFail("Expected an internal error.")
  }
  /// Verifies that invalid JSON in a stream is thrown as an internal
  /// `DecodingError.dataCorrupted` whose description names `GenerateContentResponse`.
  func testGenerateContentStream_invalidResponse() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler
    let stream = try model.generateContentStream(testPrompt)

    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch GenerateContentError.internalError(let underlying as DecodingError) {
      if case let .dataCorrupted(corruptionContext) = underlying {
        XCTAssert(corruptionContext.debugDescription
          .hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(underlying)")
      }
      return
    }
    XCTFail("Expected an internal error.")
  }
  1341. func testGenerateContentStream_malformedContent() async throws {
  1342. MockURLProtocol
  1343. .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  1344. // Note: Although this file does not contain `parts` in `content`, it is not actually
  1345. // malformed. The `invalid-field` in the payload could be added, as a non-breaking change to
  1346. // the proto API. Therefore, this test checks for the `emptyContent` error instead.
  1347. forResource: "streaming-failure-malformed-content",
  1348. withExtension: "txt",
  1349. subdirectory: vertexSubdirectory
  1350. )
  1351. let stream = try model.generateContentStream(testPrompt)
  1352. do {
  1353. for try await content in stream {
  1354. XCTFail("Unexpected content in stream: \(content)")
  1355. }
  1356. } catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) {
  1357. guard case let .emptyContent(contentError) = underlyingError else {
  1358. XCTFail("Not an empty content error: \(underlyingError)")
  1359. return
  1360. }
  1361. XCTAssert(contentError is DecodingError)
  1362. return
  1363. }
  1364. XCTFail("Expected an internal decoding error.")
  1365. }
  1366. func testGenerateContentStream_requestOptions_customTimeout() async throws {
  1367. let expectedTimeout = 150.0
  1368. MockURLProtocol
  1369. .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  1370. forResource: "streaming-success-basic-reply-short",
  1371. withExtension: "txt",
  1372. subdirectory: vertexSubdirectory,
  1373. timeout: expectedTimeout
  1374. )
  1375. let requestOptions = RequestOptions(timeout: expectedTimeout)
  1376. model = GenerativeModel(
  1377. modelName: testModelName,
  1378. modelResourceName: testModelResourceName,
  1379. firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
  1380. apiConfig: apiConfig,
  1381. tools: nil,
  1382. requestOptions: requestOptions,
  1383. urlSession: urlSession
  1384. )
  1385. var responses = 0
  1386. let stream = try model.generateContentStream(testPrompt)
  1387. for try await content in stream {
  1388. XCTAssertNotNil(content.text)
  1389. responses += 1
  1390. }
  1391. XCTAssertEqual(responses, 1)
  1392. }
  1393. // MARK: - Count Tokens
  1394. func testCountTokens_succeeds() async throws {
  1395. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  1396. forResource: "unary-success-total-tokens",
  1397. withExtension: "json",
  1398. subdirectory: vertexSubdirectory
  1399. )
  1400. let response = try await model.countTokens("Why is the sky blue?")
  1401. XCTAssertEqual(response.totalTokens, 6)
  1402. }
  1403. func testCountTokens_succeeds_detailed() async throws {
  1404. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  1405. forResource: "unary-success-detailed-token-response",
  1406. withExtension: "json",
  1407. subdirectory: vertexSubdirectory
  1408. )
  1409. let response = try await model.countTokens("Why is the sky blue?")
  1410. XCTAssertEqual(response.totalTokens, 1837)
  1411. XCTAssertEqual(response.promptTokensDetails.count, 2)
  1412. XCTAssertEqual(response.promptTokensDetails[0].modality, .image)
  1413. XCTAssertEqual(response.promptTokensDetails[0].tokenCount, 1806)
  1414. XCTAssertEqual(response.promptTokensDetails[1].modality, .text)
  1415. XCTAssertEqual(response.promptTokensDetails[1].tokenCount, 31)
  1416. }
  1417. func testCountTokens_succeeds_allOptions() async throws {
  1418. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  1419. forResource: "unary-success-total-tokens",
  1420. withExtension: "json",
  1421. subdirectory: vertexSubdirectory
  1422. )
  1423. let generationConfig = GenerationConfig(
  1424. temperature: 0.5,
  1425. topP: 0.9,
  1426. topK: 3,
  1427. candidateCount: 1,
  1428. maxOutputTokens: 1024,
  1429. stopSequences: ["test-stop"],
  1430. responseMIMEType: "text/plain"
  1431. )
  1432. let sumFunction = FunctionDeclaration(
  1433. name: "sum",
  1434. description: "Add two integers.",
  1435. parameters: ["x": .integer(), "y": .integer()]
  1436. )
  1437. let systemInstruction = ModelContent(
  1438. role: "system",
  1439. parts: "You are a calculator. Use the provided tools."
  1440. )
  1441. model = GenerativeModel(
  1442. modelName: testModelName,
  1443. modelResourceName: testModelResourceName,
  1444. firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
  1445. apiConfig: apiConfig,
  1446. generationConfig: generationConfig,
  1447. tools: [Tool(functionDeclarations: [sumFunction])],
  1448. systemInstruction: systemInstruction,
  1449. requestOptions: RequestOptions(),
  1450. urlSession: urlSession
  1451. )
  1452. let response = try await model.countTokens("Why is the sky blue?")
  1453. XCTAssertEqual(response.totalTokens, 6)
  1454. }
  1455. func testCountTokens_succeeds_noBillableCharacters() async throws {
  1456. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  1457. forResource: "unary-success-no-billable-characters",
  1458. withExtension: "json",
  1459. subdirectory: vertexSubdirectory
  1460. )
  1461. let response = try await model.countTokens(InlineDataPart(data: Data(), mimeType: "image/jpeg"))
  1462. XCTAssertEqual(response.totalTokens, 258)
  1463. }
  1464. func testCountTokens_modelNotFound() async throws {
  1465. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  1466. forResource: "unary-failure-model-not-found", withExtension: "json",
  1467. subdirectory: vertexSubdirectory,
  1468. statusCode: 404
  1469. )
  1470. do {
  1471. _ = try await model.countTokens("Why is the sky blue?")
  1472. XCTFail("Request should not have succeeded.")
  1473. } catch let rpcError as BackendError {
  1474. XCTAssertEqual(rpcError.httpResponseCode, 404)
  1475. XCTAssertEqual(rpcError.status, .notFound)
  1476. XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
  1477. return
  1478. }
  1479. XCTFail("Expected internal RPCError.")
  1480. }
  1481. func testCountTokens_requestOptions_customTimeout() async throws {
  1482. let expectedTimeout = 150.0
  1483. MockURLProtocol
  1484. .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  1485. forResource: "unary-success-total-tokens",
  1486. withExtension: "json",
  1487. subdirectory: vertexSubdirectory,
  1488. timeout: expectedTimeout
  1489. )
  1490. let requestOptions = RequestOptions(timeout: expectedTimeout)
  1491. model = GenerativeModel(
  1492. modelName: testModelName,
  1493. modelResourceName: testModelResourceName,
  1494. firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
  1495. apiConfig: apiConfig,
  1496. tools: nil,
  1497. requestOptions: requestOptions,
  1498. urlSession: urlSession
  1499. )
  1500. let response = try await model.countTokens(testPrompt)
  1501. XCTAssertEqual(response.totalTokens, 6)
  1502. }
  1503. }
  1504. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  1505. extension SafetyRating: Swift.Comparable {
  1506. public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
  1507. return lhs.category.rawValue < rhs.category.rawValue
  1508. }
  1509. }