GenerativeModelVertexAITests.swift 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelVertexAITests: XCTestCase {
  /// Prompt text passed to every `generateContent` call in this test case.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Sorted safety ratings, one per harm category, each with negligible probability and
  /// severity and `blocked` set to `false`. Compared against the candidate's ratings in
  /// `testGenerateContent_success_basicReplyShort`.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()
  /// Sorted safety ratings compared against the candidate's ratings in
  /// `testGenerateContent_success_image_invalidSafetyRatingsIgnored`.
  let safetyRatingsInvalidIgnored: [SafetyRating] = [
    .init(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.00039444832,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    .init(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.0010654529,
      severity: .negligible,
      severityScore: 0.0049325973,
      blocked: false
    ),
    .init(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.00026658305,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    .init(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.0013701695,
      severity: .negligible,
      severityScore: 0.07626295,
      blocked: false
    ),
    // Ignored Invalid Safety Ratings: {},{},{},{}
  ].sorted()
  /// Short model name handed to every `GenerativeModel` built in this test case.
  let testModelName = "test-model"

  /// Fully qualified resource name handed to every `GenerativeModel` built in this test case.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  /// API configuration used for all models in this test case.
  let apiConfig = FirebaseAI.defaultVertexAIAPIConfig

  /// Subdirectory passed to `GenerativeModelTestUtil.httpRequestHandler` when loading fixtures.
  let vertexSubdirectory = "mock-responses/vertexai"

  /// Session configured with `MockURLProtocol`; assigned in `setUp()`.
  var urlSession: URLSession!

  /// Model under test; assigned in `setUp()` and replaced by some individual tests.
  var model: GenerativeModel!
  /// Creates a `URLSession` that routes requests through `MockURLProtocol` and a default
  /// `GenerativeModel` that uses that session.
  override func setUp() async throws {
    let mockedConfiguration = URLSessionConfiguration.default
    mockedConfiguration.protocolClasses = [MockURLProtocol.self]
    let mockedSession = try XCTUnwrap(URLSession(configuration: mockedConfiguration))
    urlSession = mockedSession

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: mockedSession
    )
  }
  /// Clears the request handler installed on `MockURLProtocol` by the test that just ran.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
  }
  115. // MARK: - Generate Content
  /// The long basic-reply fixture yields one candidate with a `.stop` finish reason, four safety
  /// ratings, and a single text part that is also exposed through `response.text`.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let replyText = try XCTUnwrap(firstPart as? TextPart).text
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(response.text, replyText)

    // A text-only reply carries no function calls or inline data.
    XCTAssertEqual(response.functionCalls, [])
    XCTAssertEqual(response.inlineDataParts, [])
  }
  /// The short basic-reply fixture yields one candidate whose sorted safety ratings equal
  /// `safetyRatingsNegligible` and whose only part is the text "Mountain View, California".
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertEqual(reply.text, "Mountain View, California")
    XCTAssertEqual(response.text, reply.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// The long-usage-metadata fixture exposes per-modality token details: image and text entries
  /// for the prompt, and a single text entry for the candidates.
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-response-long-usage-metadata",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)

    let usage = try XCTUnwrap(response.usageMetadata)

    // Prompt details: image entry first, then text.
    XCTAssertEqual(usage.promptTokensDetails.count, 2)
    XCTAssertEqual(usage.promptTokensDetails[0].modality, .image)
    XCTAssertEqual(usage.promptTokensDetails[0].tokenCount, 1806)
    XCTAssertEqual(usage.promptTokensDetails[1].modality, .text)
    XCTAssertEqual(usage.promptTokensDetails[1].tokenCount, 76)

    // Candidate details: a single text entry.
    XCTAssertEqual(usage.candidatesTokensDetails.count, 1)
    XCTAssertEqual(usage.candidatesTokensDetails[0].modality, .text)
    XCTAssertEqual(usage.candidatesTokensDetails[0].tokenCount, 76)
  }
  /// The citations fixture yields three citations, each populating a different subset of
  /// `uri`, `title`, `license` and `publicationDate`.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let metadata = try XCTUnwrap(firstCandidate.citationMetadata)
    XCTAssertEqual(metadata.citations.count, 3)

    // First citation: URI only.
    let firstCitation = try XCTUnwrap(metadata.citations[0])
    XCTAssertEqual(firstCitation.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(firstCitation.startIndex, 0)
    XCTAssertEqual(firstCitation.endIndex, 128)
    XCTAssertNil(firstCitation.title)
    XCTAssertNil(firstCitation.license)
    XCTAssertNil(firstCitation.publicationDate)

    // Second citation: title and publication date, no URI.
    let secondCitation = try XCTUnwrap(metadata.citations[1])
    XCTAssertEqual(secondCitation.title, "some-citation-2")
    XCTAssertEqual(secondCitation.publicationDate, expectedPublicationDate)
    XCTAssertEqual(secondCitation.startIndex, 130)
    XCTAssertEqual(secondCitation.endIndex, 265)
    XCTAssertNil(secondCitation.uri)
    XCTAssertNil(secondCitation.license)

    // Third citation: URI and license.
    let thirdCitation = try XCTUnwrap(metadata.citations[2])
    XCTAssertEqual(thirdCitation.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(thirdCitation.startIndex, 272)
    XCTAssertEqual(thirdCitation.endIndex, 431)
    XCTAssertEqual(thirdCitation.license, "mit")
    XCTAssertNil(thirdCitation.title)
    XCTAssertNil(thirdCitation.publicationDate)
  }
  /// The quote-reply fixture yields one candidate with a text part starting with "Google" and
  /// prompt feedback that has no block reason and four safety ratings.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(reply.text.hasPrefix("Google"))
    XCTAssertEqual(response.text, reply.text)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Safety ratings containing raw enum values without a named case in this test
  /// ("FAKE_NEW_HARM_PROBABILITY", "FAKE_NEW_HARM_CATEGORY", "HARM_SEVERITY_UNSPECIFIED") are
  /// decoded and returned on both the candidate and the prompt feedback.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    let expectedSafetyRatings: [SafetyRating] = [
      .init(
        category: .harassment,
        probability: .medium,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      .init(
        category: .dangerousContent,
        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      .init(
        category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
        probability: .high,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
    ]
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedSafetyRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedSafetyRatings)
  }
  /// Builds a separate `GenerativeModel` and checks that `generateContent` completes without
  /// throwing against the short basic-reply fixture.
  ///
  /// NOTE(review): despite the test name, the model is built with `testModelName`
  /// ("test-model"), which carries no visible prefix, so this looks identical to the model
  /// created in `setUp()` — confirm whether a prefixed name was intended.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let localModel = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    _ = try await localModel.generateContent(testPrompt)
  }
  /// The empty-arguments fixture yields a single `FunctionCallPart` named "current_time" with an
  /// empty `args` dictionary, also exposed through `response.functionCalls`.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Unwrap the downcast the same way the text-part tests do; a failed cast fails the test.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// The no-arguments fixture yields a single `FunctionCallPart` named "current_time" whose
  /// `args` dictionary is empty, also exposed through `response.functionCalls`.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Unwrap the downcast the same way the text-part tests do; a failed cast fails the test.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// The with-arguments fixture yields a single `FunctionCallPart` named "sum" whose `args`
  /// contain `x == .number(4)` and `y == .number(5)`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Unwrap the downcast the same way the text-part tests do; a failed cast fails the test.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "sum")
    XCTAssertEqual(functionCall.args.count, 2)
    let argX = try XCTUnwrap(functionCall.args["x"])
    XCTAssertEqual(argX, .number(4))
    let argY = try XCTUnwrap(functionCall.args["y"])
    XCTAssertEqual(argY, .number(5))
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// The parallel-calls fixture yields one candidate with three parts, all of which are
  /// reported through `response.functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// The mixed-content fixture yields one candidate with four parts: two are reported as
  /// function calls and `response.text` equals "The sum of [1, 2, 3] is".
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let replyText = try XCTUnwrap(response.text)
    XCTAssertEqual(replyText, "The sum of [1, 2, 3] is")
  }
  /// The image fixture yields one candidate whose sorted safety ratings equal
  /// `safetyRatingsInvalidIgnored` and a single non-empty "image/png" inline data part.
  func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-image-invalid-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)

    let images = response.inlineDataParts
    XCTAssertEqual(images.count, 1)
    let firstImage = try XCTUnwrap(images.first)
    XCTAssertEqual(firstImage.mimeType, "image/png")
    XCTAssertGreaterThan(firstImage.data.count, 0)
  }
  /// Uses a model whose App Check fake returns a token and a request handler configured with the
  /// same token; the call is expected to complete without throwing.
  func testGenerateContent_appCheck_validToken() async throws {
    let expectedToken = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: expectedToken)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Uses Firebase info created with `privateAppID: true` and a request handler configured with
  /// `dataCollection: false`; the call is expected to complete without throwing.
  func testGenerateContent_dataCollectionOff() async throws {
    let expectedToken = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: expectedToken),
      privateAppID: true
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: expectedToken,
      dataCollection: false
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Uses a model whose App Check fake is created with an error and a request handler configured
  /// with `AppCheckInteropFake.placeholderTokenValue`; the call is expected to complete without
  /// throwing.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(error: AppCheckErrorFake())
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Uses a model whose Auth fake returns a token and a request handler configured with the same
  /// token; the call is expected to complete without throwing.
  func testGenerateContent_auth_validAuthToken() async throws {
    let expectedToken = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(token: expectedToken)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Uses a model whose Auth fake returns a `nil` token and a request handler configured with
  /// `authToken: nil`; the call is expected to complete without throwing.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(token: nil)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Uses a model whose Auth fake is created with an `AuthErrorFake`; `generateContent` must
  /// throw `GenerateContentError.internalError` wrapping that error type.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(error: AuthErrorFake())
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected error; nothing further to verify.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// The short basic-reply fixture reports 6 prompt tokens, 7 candidate tokens and 13 total
  /// tokens, with no per-modality token details.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
    XCTAssertTrue(usage.promptTokensDetails.isEmpty)
    XCTAssertTrue(usage.candidatesTokensDetails.isEmpty)
  }
  /// A 400 response with the invalid-API-key fixture must surface as
  /// `GenerateContentError.internalError` wrapping a `BackendError` whose status code, status,
  /// message, localized description and `NSError` bridging are all checked.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      // Compare against the status code the handler was configured with, not a second literal.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// A 403 response with the API-not-enabled fixture must surface as
  /// `GenerateContentError.internalError` wrapping a `BackendError` with status
  /// `.permissionDenied` for which `isVertexAIInFirebaseServiceDisabledError()` returns true.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(
        error.message.starts(with: "Vertex AI in Firebase API has not been used in project")
      )
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// The empty-content fixture must surface as `GenerateContentError.internalError` wrapping
  /// `InvalidCandidateError.emptyContent`, whose associated value is a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      guard case let .emptyContent(underlyingError) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(
        underlyingError as? DecodingError,
        "Not a DecodingError: \(underlyingError)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-finish-reason-safety` payload throws
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety` while still exposing the
  /// response text.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-finish-reason-safety-no-content` payload throws
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety` and a response that has
  /// no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety-no-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that a 400 response carrying the `unary-failure-image-rejected` payload is
  /// surfaced as `GenerateContentError.internalError` wrapping a `BackendError` with status
  /// `.invalidArgument`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-image-rejected",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      // Use the shared constant so the mocked status and the asserted status cannot drift.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-prompt-blocked-safety` payload throws
  /// `GenerateContentError.promptBlocked` with a `.safety` block reason and no block reason
  /// message.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertNil(feedback.blockReasonMessage)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that the `unary-failure-prompt-blocked-safety-with-message` payload throws
  /// `GenerateContentError.promptBlocked` with a `.safety` block reason and the block reason
  /// message `"Reasons"`.
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety-with-message",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized finish reason (`"FAKE_NEW_FINISH_REASON"`) in the
  /// `unary-failure-unknown-enum-finish-reason` payload is preserved in the thrown
  /// `GenerateContentError.responseStoppedEarly`, along with the response text.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-unknown-enum-finish-reason",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized block reason (`"FAKE_NEW_BLOCK_REASON"`) in the
  /// `unary-failure-unknown-enum-prompt-blocked` payload is preserved in the prompt feedback of
  /// the thrown `GenerateContentError.promptBlocked`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-unknown-enum-prompt-blocked",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, unknownBlockReason)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that a 404 response carrying the `unary-failure-unknown-model` payload is
  /// surfaced as `GenerateContentError.internalError` wrapping a `BackendError` with status
  /// `.notFound`.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-unknown-model",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      // Use the shared constant so the mocked status and the asserted status cannot drift.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(backendError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a non-HTTP response yields `GenerateContentError.internalError` whose
  /// underlying error bridges to an `NSError` in `NSURLErrorDomain` with the
  /// `badServerResponse` code.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()

    var thrownError: Error?
    var result: GenerateContentResponse?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let generateContentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(wrappedError) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }

    XCTAssertEqual(wrappedError.localizedDescription, "Response was not an HTTP response.")
    // Bridge to NSError to inspect the domain and code of the underlying failure.
    let bridgedError = wrappedError as NSError
    XCTAssertEqual(bridgedError.domain, NSURLErrorDomain)
    XCTAssertEqual(bridgedError.code, URLError.Code.badServerResponse.rawValue)
  }
  /// Verifies that the `unary-failure-invalid-response` payload yields
  /// `GenerateContentError.internalError` wrapping a `DecodingError.dataCorrupted` whose debug
  /// description starts with "Failed to decode GenerateContentResponse".
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    var thrownError: Error?
    var result: GenerateContentResponse?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let generateContentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(wrappedError) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }

    let decodingError = try XCTUnwrap(wrappedError as? DecodingError)
    if case let .dataCorrupted(context) = decodingError {
      XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
    } else {
      XCTFail("Not a data corrupted error: \(decodingError)")
    }
  }
  /// Verifies that the `unary-failure-malformed-content` payload yields
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.emptyContent`, whose
  /// own underlying error is a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    // Note: the fixture has no `parts` in `content`, but it is not truly malformed. Its
    // `invalid-field` could be added to the proto API as a non-breaking change, so this test
    // expects the `emptyContent` error rather than a decoding failure of the whole response.
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    var thrownError: Error?
    var result: GenerateContentResponse?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let generateContentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(wrappedError) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }

    let candidateError = try XCTUnwrap(wrappedError as? InvalidCandidateError)
    if case let .emptyContent(emptyContentCause) = candidateError {
      _ = try XCTUnwrap(
        emptyContentCause as? DecodingError,
        "Not a decoding error: \(emptyContentCause)"
      )
    } else {
      XCTFail("Not an empty content error: \(candidateError)")
    }
  }
  /// Verifies that the `unary-success-missing-safety-ratings` payload decodes successfully,
  /// with prompt feedback holding zero safety ratings and the expected response text.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(response.text, "This is the generated content.")
  }
  /// Generates content with a model configured with a 150-second `RequestOptions` timeout and
  /// checks that one candidate is returned. The same timeout is passed to the mock request
  /// handler (presumably so the handler can check the request's timeout — confirm in
  /// `GenerativeModelTestUtil`).
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
  }
  895. // MARK: - Generate Content (Streaming)
  /// Verifies that streaming against the `unary-failure-api-key` payload throws
  /// `GenerateContentError.internalError` wrapping a `BackendError` (400, `.invalidArgument`),
  /// and that the error's description and `NSError` bridging expose its message, status and
  /// HTTP code.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 400)
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")

      let description = backendError.localizedDescription
      XCTAssertTrue(description.contains(backendError.message))
      XCTAssertTrue(description.contains(backendError.status.rawValue))
      XCTAssertTrue(description.contains("\(backendError.httpResponseCode)"))

      let bridgedError = backendError as NSError
      XCTAssertEqual(bridgedError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridgedError.code, backendError.httpResponseCode)
    }
  }
  /// Verifies that streaming against a 403 response carrying the
  /// `unary-failure-firebasevertexai-api-not-enabled` payload throws
  /// `GenerateContentError.internalError` wrapping a `BackendError` with `.permissionDenied`
  /// that reports the service as disabled.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      let stream = try model.generateContentStream(testPrompt)
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.status, .permissionDenied)
      XCTAssertTrue(
        backendError.message.starts(with: "Vertex AI in Firebase API has not been used in project")
      )
      XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
    }
  }
  /// Verifies that streaming the `streaming-failure-empty-content` fixture yields no content
  /// and throws `GenerateContentError.internalError` wrapping an `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // Underlying error is as expected, nothing else to check.
    }
  }
  /// Verifies that streaming the `streaming-failure-finish-reason-safety` fixture throws
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety`, and that the first
  /// candidate has the same finish reason and at least one blocked safety rating.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      let firstCandidate = try XCTUnwrap(response.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, reason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains { rating in rating.blocked })
    }
  }
  /// Verifies that streaming the `streaming-failure-prompt-blocked-safety` fixture throws
  /// `GenerateContentError.promptBlocked` with a `.safety` block reason and no block reason
  /// message.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.promptBlocked(response) {
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
    }
  }
  /// Verifies that streaming the `streaming-failure-prompt-blocked-safety-with-message`
  /// fixture throws `GenerateContentError.promptBlocked` with a `.safety` block reason and the
  /// block reason message `"Reasons"`.
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety-with-message",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.promptBlocked(response) {
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
    }
  }
  /// Verifies that streaming the `streaming-failure-unknown-finish-enum` fixture yields
  /// content with text and then throws `GenerateContentError.responseStoppedEarly` carrying the
  /// unrecognized finish reason `"FAKE_ENUM"`.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedReason = FinishReason(rawValue: "FAKE_ENUM")

    let stream = try model.generateContentStream("Hi")
    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(reason, _) {
      XCTAssertEqual(reason, expectedReason)
    }
  }
  /// Verifies that streaming the `streaming-success-basic-reply-long` fixture yields exactly
  /// four responses, each with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let stream = try model.generateContentStream("Hi")
    var chunkCount = 0
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 4)
  }
  /// Verifies that streaming the `streaming-success-basic-reply-short` fixture yields exactly
  /// one response, with text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let stream = try model.generateContentStream("Hi")
    var chunkCount = 0
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  /// Verifies that streaming the `streaming-success-unknown-safety-enum` fixture succeeds and
  /// that at least one streamed response's first candidate carries a safety rating built from
  /// unrecognized category and probability raw values.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    var sawExpectedRating = false
    let stream = try model.generateContentStream("Hi")
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      // A chunk without candidates contributes no ratings.
      let ratings = chunk.candidates.first?.safetyRatings ?? []
      let chunkHasRating = ratings.contains { rating in rating == expectedRating }
      if chunkHasRating {
        sawExpectedRating = true
      }
    }

    XCTAssertTrue(sawExpectedRating)
  }
  /// Verifies that streaming the `streaming-success-citations` fixture finishes with `.stop`
  /// and yields six citations in total, including three with specific index ranges and
  /// metadata, and that no citation has an empty `uri`, `title` or `license` string.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    let stream = try model.generateContentStream("Hi")
    var collectedCitations: [Citation] = []
    var streamedResponses: [GenerateContentResponse] = []
    for try await chunk in stream {
      streamedResponses.append(chunk)
      XCTAssertNotNil(chunk.text)
      let candidate = try XCTUnwrap(chunk.candidates.first)
      // Chunks without citation metadata contribute nothing.
      collectedCitations.append(contentsOf: candidate.citationMetadata?.citations ?? [])
    }

    let finalCandidate = try XCTUnwrap(streamedResponses.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 6)

    // Citation identified only by its URI.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil && citation.publicationDate == nil
    })
    // Citation with a title and publication date but no URI.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
        && citation.publicationDate == expectedPublicationDate
    })
    // Citation with a URI and a license.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit" && citation.publicationDate == nil
    })

    // Optional string fields may be nil but must never be empty strings.
    XCTAssertFalse(collectedCitations.contains { citation in citation.uri?.isEmpty == true })
    XCTAssertFalse(collectedCitations.contains { citation in citation.title?.isEmpty == true })
    XCTAssertFalse(collectedCitations.contains { citation in citation.license?.isEmpty == true })
  }
  /// Verifies that streaming the `streaming-success-image-invalid-safety-ratings` fixture
  /// yields a first response with one candidate whose sorted safety ratings equal
  /// `safetyRatingsInvalidIgnored`, and whose single part is non-empty inline PNG data.
  func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-image-invalid-safety-ratings",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var streamedResponses: [GenerateContentResponse] = []
    let stream = try model.generateContentStream(testPrompt)
    for try await chunk in stream {
      streamedResponses.append(chunk)
    }

    let firstResponse = try XCTUnwrap(streamedResponses.first)
    XCTAssertEqual(firstResponse.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(firstResponse.candidates.first)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)

    let imageParts = firstResponse.inlineDataParts
    XCTAssertEqual(imageParts.count, 1)
    let firstImage = try XCTUnwrap(imageParts.first)
    XCTAssertEqual(firstImage.mimeType, "image/png")
    XCTAssertGreaterThan(firstImage.data.count, 0)
  }
  /// Streams content from a model whose Firebase info uses an App Check fake returning a
  /// fixed token. The same token is handed to the mock request handler (presumably so the
  /// handler can check the outgoing request — confirm in `GenerativeModelTestUtil`). The test
  /// passes if the stream completes without throwing.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let expectedToken = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: expectedToken)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: expectedToken
    )

    // Drain the stream; any thrown error fails the test.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Streams content from a model whose Firebase info uses an App Check fake configured with
  /// an error. The mock request handler is given the fake's placeholder token value
  /// (presumably the value expected on the outgoing request — confirm in
  /// `GenerativeModelTestUtil`). The test passes if the stream completes without throwing.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(error: AppCheckErrorFake())
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream; any thrown error fails the test.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Verifies that, when streaming the `streaming-success-basic-reply-short` fixture, only
  /// the last streamed response carries usage metadata and that its token counts match the
  /// fixture (6 prompt, 4 candidates, 10 total).
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Fail explicitly on an empty stream; otherwise no assertion would run and the test
    // would pass vacuously.
    let finalResponse = try XCTUnwrap(responses.last, "Stream produced no responses.")

    // Only the last streamed response contains usage metadata.
    for earlierResponse in responses.dropLast() {
      XCTAssertNil(earlierResponse.usageMetadata)
    }

    let usageMetadata = try XCTUnwrap(finalResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Verifies that streaming the `streaming-failure-error-mid-stream` fixture yields two
  /// responses with text and then throws `GenerateContentError.internalError` wrapping a
  /// `BackendError` with HTTP code 499 and status `.cancelled`.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var responseCount = 0
    do {
      let stream = try model.generateContentStream("Hi")
      for try await content in stream {
        XCTAssertNotNil(content.text)
        responseCount += 1
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)

      // Check the content count is correct.
      XCTAssertEqual(responseCount, 2)
      return
    }

    // Only reached when the stream finished without throwing. The message names
    // `BackendError`, the type matched above, rather than the stale `RPCError`.
    XCTFail("Expected an internalError with a BackendError.")
  }
  /// Verifies that streaming against a non-HTTP response yields no content and throws
  /// `GenerateContentError.internalError` whose underlying error describes the response as
  /// not being an HTTP response.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()

    let stream = try model.generateContentStream("Hi")
    do {
      for try await chunk in stream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(wrappedError) {
      XCTAssertEqual(wrappedError.localizedDescription, "Response was not an HTTP response.")
    }
  }
  /// Verifies that streaming the `streaming-failure-invalid-json` fixture yields no content
  /// and throws `GenerateContentError.internalError` wrapping a `DecodingError.dataCorrupted`
  /// whose debug description starts with "Failed to decode GenerateContentResponse".
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await chunk in stream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(context) = decodingError {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
    }
  }
  /// Verifies that streaming the `streaming-failure-malformed-content` fixture yields no
  /// content and throws `GenerateContentError.internalError` wrapping
  /// `InvalidCandidateError.emptyContent`, whose own underlying error is a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    // Note: the fixture has no `parts` in `content`, but it is not truly malformed. Its
    // `invalid-field` could be added to the proto API as a non-breaking change, so this test
    // expects the `emptyContent` error rather than a decoding failure of the whole response.
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await chunk in stream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal decoding error.")
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      if case let .emptyContent(emptyContentCause) = candidateError {
        XCTAssert(emptyContentCause is DecodingError)
      } else {
        XCTFail("Not an empty content error: \(candidateError)")
      }
    }
  }
  /// Streams content from a model configured with a 150-second `RequestOptions` timeout and
  /// checks that one response with text is received. The same timeout is passed to the mock
  /// request handler (presumably so the handler can check the request's timeout — confirm in
  /// `GenerativeModelTestUtil`).
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    var chunkCount = 0
    let stream = try model.generateContentStream(testPrompt)
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  1349. // MARK: - Count Tokens
  /// Verifies that `countTokens` decodes the `unary-success-total-tokens` payload into a
  /// total of 6 tokens and 16 billable characters.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.deprecated.totalBillableCharacters, 16)
  }
  /// Counts tokens against the `unary-success-detailed-token-response` fixture and checks the
  /// totals plus the per-modality breakdown (an image entry followed by a text entry).
  func testCountTokens_succeeds_detailed() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-detailed-token-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 1837)
    XCTAssertEqual(tokenCount.deprecated.totalBillableCharacters, 117)

    let details = tokenCount.promptTokensDetails
    XCTAssertEqual(details.count, 2)

    // First entry: image modality.
    let imageDetails = details[0]
    XCTAssertEqual(imageDetails.modality, .image)
    XCTAssertEqual(imageDetails.tokenCount, 1806)

    // Second entry: text modality.
    let textDetails = details[1]
    XCTAssertEqual(textDetails.modality, .text)
    XCTAssertEqual(textDetails.tokenCount, 31)
  }
  /// Counts tokens with a model that has a generation config, a function-calling tool, a system
  /// instruction and request options all set, and expects the same totals as the basic
  /// `unary-success-total-tokens` case.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    // Optional model settings supplied to the initializer below.
    let config = GenerationConfig(
      temperature: 0.5,
      topP: 0.9,
      topK: 3,
      candidateCount: 1,
      maxOutputTokens: 1024,
      stopSequences: ["test-stop"],
      responseMIMEType: "text/plain"
    )
    let addDeclaration = FunctionDeclaration(
      name: "sum",
      description: "Add two integers.",
      parameters: ["x": .integer(), "y": .integer()]
    )
    let calculatorInstruction = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      generationConfig: config,
      tools: [Tool(functionDeclarations: [addDeclaration])],
      systemInstruction: calculatorInstruction,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.deprecated.totalBillableCharacters, 16)
  }
  /// Counts tokens for an inline JPEG part with empty data against the
  /// `unary-success-no-billable-characters` fixture and expects the deprecated
  /// billable-character count to be `nil`.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let emptyImagePart = InlineDataPart(data: Data(), mimeType: "image/jpeg")
    let tokenCount = try await model.countTokens(emptyImagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertNil(tokenCount.deprecated.totalBillableCharacters)
  }
  /// Expects `countTokens` to throw a `BackendError` when the mock handler is configured with
  /// the `unary-failure-model-not-found` fixture and status code 404.
  ///
  /// Outcomes:
  /// - The call succeeds: exactly one failure is recorded.
  /// - A `BackendError` is thrown: its HTTP code, status and message prefix are verified.
  /// - Any other error is thrown: it propagates out of this throwing test method.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      // Reaching this line means no error was thrown. This is the only failure recorded for
      // that case; no second failure follows the `do`/`catch`.
      XCTFail("Request should not have succeeded.")
    } catch let rpcError as BackendError {
      XCTAssertEqual(rpcError.httpResponseCode, 404)
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
    }
  }
  /// Counts tokens with a model built with a non-default request timeout and expects the
  /// `unary-success-total-tokens` total.
  ///
  /// The same timeout value is also handed to the mock request handler; presumably the handler
  /// checks it against the outgoing request (confirm in `GenerativeModelTestUtil`).
  func testCountTokens_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    let handler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: expectedTimeout
    )
    MockURLProtocol.requestHandler = handler

    // Replace the shared model with one that carries the custom timeout.
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1463. }
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    /// Orders two safety ratings by the raw value of their `category`.
    ///
    /// Only the category takes part in the comparison; other fields of the rating are ignored.
    /// Presumably this exists so tests can sort ratings into a stable order before comparing
    /// them (confirm against the call sites).
    public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
      let (leftKey, rightKey) = (lhs.category.rawValue, rhs.category.rawValue)
      return leftKey < rightKey
    }
  }