GenerativeModelVertexAITests.swift 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelVertexAITests: XCTestCase {
  /// Prompt sent with every `generateContent` call in this suite. Responses come from mocked
  /// JSON fixtures, so the prompt text itself is never asserted on.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Expected safety ratings, all `.negligible` and unblocked, pre-sorted so they can be
  /// compared directly against `candidate.safetyRatings.sorted()` (for example in
  /// `testGenerateContent_success_basicReplyShort`).
  let safetyRatingsNegligible = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()
  /// Expected (sorted) safety ratings for the `unary-success-image-invalid-safety-ratings`
  /// fixture. That fixture also contains four empty `{}` rating objects, which should be
  /// dropped during decoding; only the four valid ratings below are expected to remain.
  let safetyRatingsInvalidIgnored: [SafetyRating] = [
    .init(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.00039444832,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    .init(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.0010654529,
      severity: .negligible,
      severityScore: 0.0049325973,
      blocked: false
    ),
    .init(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.00026658305,
      severity: .negligible,
      severityScore: 0.0,
      blocked: false
    ),
    .init(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.0013701695,
      severity: .negligible,
      severityScore: 0.07626295,
      blocked: false
    ),
  ].sorted()
  /// Short model name handed to every `GenerativeModel` built by this suite.
  let testModelName: String = "test-model"
  /// Fully-qualified resource name that pairs with `testModelName`.
  let testModelResourceName: String =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"
  /// Vertex AI backend configuration shared by all models in this suite.
  let apiConfig = FirebaseAI.defaultVertexAIAPIConfig
  /// Fixture subdirectory passed to `GenerativeModelTestUtil.httpRequestHandler`.
  let vertexSubdirectory: String = "mock-responses/vertexai"

  /// Session whose traffic is intercepted by `MockURLProtocol`; recreated in `setUp()`.
  var urlSession: URLSession!
  /// Model under test. `setUp()` recreates it, and some tests replace it with a model that
  /// uses custom App Check / Auth fakes.
  var model: GenerativeModel!
  /// Before every test, builds a fresh `URLSession` routed through `MockURLProtocol` and a
  /// default `GenerativeModel` (created with `testFirebaseInfo()` and no App Check/Auth fakes).
  override func setUp() async throws {
    try await super.setUp()

    let configuration = URLSessionConfiguration.default
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is non-failable, so there is nothing to unwrap.
    urlSession = URLSession(configuration: configuration)

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
  }
  /// Clears the static `MockURLProtocol.requestHandler` so one test's stub cannot leak into the
  /// next test, then lets `XCTestCase` run its own teardown.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    super.tearDown()
  }
  115. // MARK: - Generate Content
  /// Decodes a long, single-candidate text reply. Checks the finish reason, the safety-rating
  /// count, and that `text`, `functionCalls` and `inlineDataParts` reflect the single text part.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    let stopReason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)

    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let onlyPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let textPart = try XCTUnwrap(onlyPart as? TextPart)
    XCTAssertTrue(textPart.text.hasPrefix("1. **Use Freshly Ground Coffee**:"))

    XCTAssertEqual(reply.text, textPart.text)
    XCTAssertEqual(reply.functionCalls, [])
    XCTAssertEqual(reply.inlineDataParts, [])
  }
  /// Decodes a short, single-candidate text reply. Checks the finish reason, the sorted safety
  /// ratings, the reply text, and the `text` / `functionCalls` convenience accessors.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    let stopReason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = firstCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let onlyPart = try XCTUnwrap(parts.first)
    let textPart = try XCTUnwrap(onlyPart as? TextPart)
    XCTAssertEqual(textPart.text, "Mountain View, California")

    XCTAssertEqual(reply.text, textPart.text)
    XCTAssertEqual(reply.functionCalls, [])
  }
  /// Decodes a reply whose `usageMetadata` includes per-modality token breakdowns for both
  /// the prompt (image + text) and the candidates (text), and checks every entry.
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-response-long-usage-metadata",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    let finishReason = try XCTUnwrap(candidate.finishReason)
    XCTAssertEqual(finishReason, .stop)
    let usageMetadata = try XCTUnwrap(response.usageMetadata)

    // `XCTAssertEqual` does not halt the test, and an out-of-range subscript traps (crashing
    // the whole test process). Only index once enough elements are known to exist.
    let promptDetails = usageMetadata.promptTokensDetails
    XCTAssertEqual(promptDetails.count, 2)
    if promptDetails.count >= 2 {
      XCTAssertEqual(promptDetails[0].modality, .image)
      XCTAssertEqual(promptDetails[0].tokenCount, 1806)
      XCTAssertEqual(promptDetails[1].modality, .text)
      XCTAssertEqual(promptDetails[1].tokenCount, 76)
    }

    let candidateDetails = usageMetadata.candidatesTokensDetails
    XCTAssertEqual(candidateDetails.count, 1)
    if let textDetail = candidateDetails.first {
      XCTAssertEqual(textDetail.modality, .text)
      XCTAssertEqual(textDetail.tokenCount, 76)
    }
  }
  /// Decodes a reply with three citation sources and checks each source's URI, title,
  /// license, publication date and character range (`startIndex`/`endIndex`).
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citations = citationMetadata.citations
    XCTAssertEqual(citations.count, 3)
    // `XCTAssertEqual` keeps going after a failure, and subscripting past the end of an array
    // traps. Stop here (the count mismatch is already recorded) instead of crashing.
    guard citations.count >= 3 else { return }

    let citationSource1 = citations[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)
    XCTAssertNil(citationSource1.publicationDate)

    let citationSource2 = citations[1]
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.publicationDate, expectedPublicationDate)
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citations[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
    XCTAssertNil(citationSource3.publicationDate)
  }
  /// Decodes a single-candidate reply that begins with "Google". Checks the finish reason,
  /// the safety ratings, and that the prompt feedback has four ratings and no block reason.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    let stopReason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)

    let parts = firstCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let onlyPart = try XCTUnwrap(parts.first)
    let quote = try XCTUnwrap(onlyPart as? TextPart)
    XCTAssertTrue(quote.text.hasPrefix("Google"))
    XCTAssertEqual(reply.text, quote.text)

    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Checks that unrecognized enum raw values in safety ratings (an unknown harm category and
  /// an unknown harm probability) decode without error and are preserved verbatim. Both the
  /// candidate's ratings and the prompt feedback's ratings are checked.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    // Every expected rating has zero scores, an unspecified severity, and is not blocked.
    func expectedRating(
      _ category: HarmCategory,
      _ probability: SafetyRating.HarmProbability
    ) -> SafetyRating {
      SafetyRating(
        category: category,
        probability: probability,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      )
    }
    let expectedRatings = [
      expectedRating(.harassment, .medium),
      expectedRating(
        .dangerousContent,
        SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY")
      ),
      expectedRating(HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), .high),
    ]

    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.text, "Some text")
    XCTAssertEqual(reply.candidates.first?.safetyRatings, expectedRatings)
    XCTAssertEqual(reply.promptFeedback?.safetyRatings, expectedRatings)
  }
  /// Sends a request through a separately constructed model and expects it to succeed.
  ///
  /// NOTE(review): despite its name, this test builds the model from the same un-prefixed
  /// `testModelName` / `testModelResourceName` that `setUp()` uses. It therefore does not
  /// currently exercise a prefixed model name; confirm the intended input.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    // Local model, named so it does not shadow the `model` property.
    let separateModel = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    _ = try await separateModel.generateContent(testPrompt)
  }
  /// Decodes the "empty arguments" function-call fixture and expects one `current_time` call
  /// with no arguments, also surfaced through `functionCalls`.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let parts = try XCTUnwrap(reply.candidates.first).content.parts
    XCTAssertEqual(parts.count, 1)
    let onlyPart = try XCTUnwrap(parts.first)
    let call = try XCTUnwrap(onlyPart as? FunctionCallPart, "Part is not a FunctionCall.")

    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// Decodes the "no arguments" function-call fixture and expects one `current_time` call
  /// with an empty `args`, also surfaced through `functionCalls`.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let parts = try XCTUnwrap(reply.candidates.first).content.parts
    XCTAssertEqual(parts.count, 1)
    let onlyPart = try XCTUnwrap(parts.first)
    let call = try XCTUnwrap(onlyPart as? FunctionCallPart, "Part is not a FunctionCall.")

    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// Decodes a `sum` function call with two numeric arguments (`x` = 4, `y` = 5), also
  /// surfaced through `functionCalls`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let parts = try XCTUnwrap(reply.candidates.first).content.parts
    XCTAssertEqual(parts.count, 1)
    let onlyPart = try XCTUnwrap(parts.first)
    let call = try XCTUnwrap(onlyPart as? FunctionCallPart, "Part is not a FunctionCall.")

    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let xValue = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(xValue, .number(4))
    let yValue = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(yValue, .number(5))
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// Decodes a single candidate with three parts and checks that `functionCalls` returns
  /// three function calls.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 3)
    XCTAssertEqual(reply.functionCalls.count, 3)
  }
  /// Decodes a candidate with four parts, two of which are function calls. Checks the
  /// function-call count and the combined `text` accessor.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 4)
    XCTAssertEqual(reply.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(reply.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// Decodes an image reply whose safety ratings include empty (invalid) entries. Checks that
  /// only the valid ratings survive and that the single PNG inline-data part has non-empty data.
  func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-image-invalid-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)

    let inlineParts = reply.inlineDataParts
    XCTAssertEqual(inlineParts.count, 1)
    let image = try XCTUnwrap(inlineParts.first)
    XCTAssertEqual(image.mimeType, "image/png")
    XCTAssertFalse(image.data.isEmpty)
  }
  /// Replaces `model` with one whose App Check fake vends a fixed token and expects the
  /// request to succeed. The mock handler is configured with that same token (presumably so
  /// it can verify the token is attached to the request).
  func testGenerateContent_appCheck_validToken() async throws {
    let token = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Replaces `model` with one built from Firebase info marked `privateAppID: true` and
  /// expects the request to succeed. The mock handler is configured with
  /// `dataCollection: false`; presumably it checks that data collection is reported as off.
  func testGenerateContent_dataCollectionOff() async throws {
    let token = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token),
      privateAppID: true
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token,
      dataCollection: false
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Replaces `model` with one whose App Check fake fails to produce a token. Expects
  /// `generateContent` to still succeed, with the mock handler configured for
  /// `AppCheckInteropFake.placeholderTokenValue` instead of a real token.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Replaces `model` with one whose Auth fake vends a fixed token and expects the request,
  /// whose mock handler is configured with that same token, to succeed.
  func testGenerateContent_auth_validAuthToken() async throws {
    let token = "test-valid-token"
    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo(
      auth: AuthInteropFake(token: token)
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Replaces `model` with one whose Auth fake returns a `nil` token. Expects the request to
  /// succeed, with the mock handler configured for `authToken: nil`.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let signedOutAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(auth: signedOutAuth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Replaces `model` with one whose Auth fake fails. Expects `generateContent` to throw
  /// `GenerateContentError.internalError` wrapping that `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(auth: failingAuth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      return // Expected: the Auth failure is surfaced as an internal error.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
      return
    }
    XCTFail("Should throw internalError(AuthErrorFake); no error.")
  }
  /// Checks the aggregate token counts in `usageMetadata` for the short-reply fixture. That
  /// fixture has no per-modality breakdowns, so both detail arrays should be empty.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    let usageMetadata = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 7)
    XCTAssertEqual(usageMetadata.totalTokenCount, 13)
    // Assert the Bool directly rather than comparing it against `true`.
    XCTAssertTrue(usageMetadata.promptTokensDetails.isEmpty)
    XCTAssertTrue(usageMetadata.candidatesTokensDetails.isEmpty)
  }
  /// Checks that an HTTP 400 "API key not valid" response surfaces as
  /// `GenerateContentError.internalError(BackendError)`. Also checks the error's status,
  /// message, `localizedDescription`, and its `NSError` bridging (domain and code).
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      // Compare against the status code the mock was configured with, not a second literal
      // that could drift out of sync with it.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Checks that an HTTP 403 "API has not been used in project" response surfaces as
  /// `GenerateContentError.internalError(BackendError)` and is recognized as the
  /// Vertex AI in Firebase service-disabled error.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(error.message
        .starts(with: "Vertex AI in Firebase API has not been used in project"))
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      // The wrapped type is `BackendError`; the old message named a stale type.
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Expects the empty-content fixture to fail with `GenerateContentError.internalError`.
  /// The wrapped error must be `InvalidCandidateError.emptyContent`, whose own underlying
  /// error must be a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      guard case let .emptyContent(underlyingError) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(
        underlyingError as? DecodingError,
        "Not a DecodingError: \(underlyingError)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a unary reply whose candidate stopped for `.safety` throws
  /// `GenerateContentError.responseStoppedEarly` carrying the partial (redacted) response.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that a `.safety` stop with no candidate content throws
  /// `GenerateContentError.responseStoppedEarly` whose response has no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that a 400 "invalid argument" reply (image rejected) surfaces as
  /// `GenerateContentError.internalError` wrapping a `BackendError`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Use the same constant the assertions check, so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a prompt blocked for safety, with no block-reason message, throws
  /// `GenerateContentError.promptBlocked` whose response has no text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertNil(promptFeedback.blockReasonMessage)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that a prompt blocked for safety throws `GenerateContentError.promptBlocked` and
  /// exposes the server-provided block-reason message.
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety-with-message",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized finish-reason string still decodes into a `FinishReason` and is
  /// reported through `GenerateContentError.responseStoppedEarly` with the partial text intact.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized block-reason string still decodes into a
  /// `PromptFeedback.BlockReason` and is reported through `GenerateContentError.promptBlocked`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that a 404 reply for an unknown model surfaces as
  /// `GenerateContentError.internalError` wrapping a `.notFound` `BackendError`.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Use the same constant the assertions check, so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(backendError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A non-HTTP `URLResponse` must make the unary call fail with an internal error whose cause is a
  /// `URLError.badServerResponse` described as "Response was not an HTTP response.".
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()

    var generated: GenerateContentResponse?
    var thrown: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }
    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)

    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    switch contentError {
    case let .internalError(cause):
      XCTAssertEqual(cause.localizedDescription, "Response was not an HTTP response.")
      let causeAsNSError = cause as NSError
      XCTAssertEqual(causeAsNSError.domain, NSURLErrorDomain)
      XCTAssertEqual(causeAsNSError.code, URLError.Code.badServerResponse.rawValue)
    default:
      XCTFail("Not an internal error: \(contentError)")
    }
  }
  /// An undecodable unary payload must fail with an internal error whose cause is a
  /// `DecodingError.dataCorrupted` that mentions `GenerateContentResponse`.
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-invalid-response",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    var generated: GenerateContentResponse?
    var thrown: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }
    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)

    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    switch contentError {
    case let .internalError(cause):
      let decodingError = try XCTUnwrap(cause as? DecodingError)
      if case let .dataCorrupted(context) = decodingError {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
    default:
      XCTFail("Not an internal error: \(contentError)")
    }
  }
  /// A unary candidate whose `content` lacks `parts` must fail as an internal error wrapping
  /// `InvalidCandidateError.emptyContent` with a `DecodingError` cause.
  func testGenerateContent_failure_malformedContent() async throws {
    // Despite its name, this fixture is not truly malformed: it omits `parts` from `content`, and
    // its `invalid-field` could legitimately appear as a non-breaking addition to the proto API.
    // The expected outcome is therefore an `emptyContent` error.
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-failure-malformed-content",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    var generated: GenerateContentResponse?
    var thrown: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }
    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)

    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case let .internalError(cause) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(cause as? InvalidCandidateError)
    if case let .emptyContent(emptyContentCause) = candidateError {
      _ = try XCTUnwrap(
        emptyContentCause as? DecodingError,
        "Not a decoding error: \(emptyContentCause)"
      )
    } else {
      XCTFail("Not an empty content error: \(candidateError)")
    }
  }
  /// The `unary-success-missing-safety-ratings` fixture must decode with prompt feedback that has
  /// zero safety ratings, while the generated text is still available.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-success-missing-safety-ratings",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    let reply = try await model.generateContent(testPrompt)
    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(reply.text, "This is the generated content.")
  }
  /// A model built with a custom request timeout must complete a unary call against a mock
  /// handler configured with that same timeout.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let timeout = 150.0
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: timeout
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: timeout),
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)
    XCTAssertEqual(reply.candidates.count, 1)
  }
  // MARK: - Generate Content (Streaming)
  /// Streaming with an invalid API key must throw `GenerateContentError.internalError` wrapping a
  /// 400 `.invalidArgument` `BackendError` whose description and NSError bridging are populated.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 400)
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")

      // The human-readable description must mention the message, status, and HTTP code.
      let description = backendError.localizedDescription
      XCTAssertTrue(description.contains(backendError.message))
      XCTAssertTrue(description.contains(backendError.status.rawValue))
      XCTAssertTrue(description.contains("\(backendError.httpResponseCode)"))

      // When bridged to NSError, the domain embeds the type name and the code is the HTTP code.
      let bridged = backendError as NSError
      XCTAssertEqual(bridged.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridged.code, backendError.httpResponseCode)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Streaming against a project where the Vertex AI in Firebase API is unused/disabled must throw a
  /// 403 `.permissionDenied` `BackendError` recognized as the service-disabled error.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      let responseStream = try model.generateContentStream(testPrompt)
      for try await _ in responseStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.status, .permissionDenied)
      let disabledMessagePrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(backendError.message.starts(with: disabledMessagePrefix))
      XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A streamed chunk with empty content must end the stream with an internal error whose cause is
  /// an `InvalidCandidateError`; the specific case is not inspected.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(underlying)
      where underlying is InvalidCandidateError {
      // Matching the underlying type is the whole check.
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A stream that stops for `.safety` must throw `responseStoppedEarly` before yielding content,
  /// and the attached candidate must carry that finish reason plus at least one blocked rating.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.responseStoppedEarly(stopReason, partialResponse) {
      XCTAssertEqual(stopReason, .safety)
      let firstCandidate = try XCTUnwrap(partialResponse.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, stopReason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains(where: \.blocked))
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A stream whose prompt was blocked for safety, without a block-reason message, must throw
  /// `GenerateContentError.promptBlocked` before yielding any content.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A stream whose prompt was blocked for safety must throw `GenerateContentError.promptBlocked`
  /// whose feedback carries the server-provided block-reason message.
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety-with-message",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Chunks before an unrecognized finish reason must still have text, and the stream must then
  /// throw `responseStoppedEarly` carrying that unrecognized reason.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let responseStream = try model.generateContentStream("Hi")

    do {
      for try await chunk in responseStream {
        XCTAssertNotNil(chunk.text)
      }
    } catch let GenerateContentError.responseStoppedEarly(stopReason, _) {
      // The unrecognized raw value is expected to round-trip into a `FinishReason`.
      XCTAssertEqual(stopReason, FinishReason(rawValue: "FAKE_ENUM"))
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// The long basic-reply fixture must stream exactly four chunks, each with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let responseStream = try model.generateContentStream("Hi")
    let chunkCount = try await responseStream.reduce(0) { count, chunk in
      XCTAssertNotNil(chunk.text)
      return count + 1
    }
    XCTAssertEqual(chunkCount, 4)
  }
  /// The short basic-reply fixture must stream exactly one chunk, and that chunk must have text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let responseStream = try model.generateContentStream("Hi")
    var received: [GenerateContentResponse] = []
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      received.append(chunk)
    }
    XCTAssertEqual(received.count, 1)
  }
  /// Safety ratings with unrecognized category and probability strings must still decode, and at
  /// least one streamed chunk's first candidate must contain the expected unrecognized rating.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    var sawExpectedRating = false
    let responseStream = try model.generateContentStream("Hi")
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      let ratings = chunk.candidates.first?.safetyRatings ?? []
      if ratings.contains(expectedRating) {
        sawExpectedRating = true
      }
    }
    XCTAssertTrue(sawExpectedRating)
  }
  /// Citations gathered across all streamed chunks must total six, include three specific entries,
  /// and never contain a present-but-empty `uri`, `title`, or `license`. The final chunk's
  /// candidate must finish with `.stop`.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    let responseStream = try model.generateContentStream("Hi")
    var allCitations: [Citation] = []
    var received: [GenerateContentResponse] = []
    for try await chunk in responseStream {
      received.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      allCitations.append(contentsOf: firstCandidate.citationMetadata?.citations ?? [])
    }

    let finalCandidate = try XCTUnwrap(received.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(allCitations.count, 6)

    XCTAssertTrue(allCitations.contains(where: { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1"
        && citation.title == nil && citation.license == nil && citation.publicationDate == nil
    }))
    XCTAssertTrue(allCitations.contains(where: { citation in
      citation.startIndex == 130 && citation.endIndex == 265
        && citation.uri == nil && citation.title == "some-citation-2"
        && citation.license == nil && citation.publicationDate == expectedPublicationDate
    }))
    XCTAssertTrue(allCitations.contains(where: { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3"
        && citation.title == nil && citation.license == "mit" && citation.publicationDate == nil
    }))

    // Optional string fields are either absent or non-empty, never present-but-empty.
    XCTAssertFalse(allCitations.contains(where: { $0.uri?.isEmpty ?? false }))
    XCTAssertFalse(allCitations.contains(where: { $0.title?.isEmpty ?? false }))
    XCTAssertFalse(allCitations.contains(where: { $0.license?.isEmpty ?? false }))
  }
  /// An image reply with invalid safety ratings must still stream. The first chunk has one candidate
  /// whose sorted ratings equal `safetyRatingsInvalidIgnored`, and a single non-empty PNG inline part.
  func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-image-invalid-safety-ratings",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let responseStream = try model.generateContentStream(testPrompt)
    let received = try await responseStream.reduce(into: [GenerateContentResponse]()) {
      $0.append($1)
    }

    let firstResponse = try XCTUnwrap(received.first)
    XCTAssertEqual(firstResponse.candidates.count, 1)
    let candidate = try XCTUnwrap(firstResponse.candidates.first)
    XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsInvalidIgnored)
    XCTAssertEqual(candidate.content.parts.count, 1)

    let imageParts = firstResponse.inlineDataParts
    XCTAssertEqual(imageParts.count, 1)
    let imagePart = try XCTUnwrap(imageParts.first)
    XCTAssertEqual(imagePart.mimeType, "image/png")
    XCTAssertGreaterThan(imagePart.data.count, 0)
  }
  /// Streams with an App Check fake that returns a valid token. The mock handler is given the same
  /// token and presumably checks the outgoing request for it (TODO confirm in
  /// `GenerativeModelTestUtil`).
  func testGenerateContentStream_appCheck_validToken() async throws {
    let token = "test-valid-token"
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(
        appCheck: AppCheckInteropFake(token: token)
      ),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: token
    )

    // Drain the stream so the request is actually sent; the test passes if nothing throws.
    for try await _ in try model.generateContentStream(testPrompt) {}
  }
  /// Streams with an App Check fake whose token refresh fails. The mock handler is configured to
  /// expect `AppCheckInteropFake.placeholderTokenValue`, and the call must still complete.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream so the request is actually sent; the test passes if nothing throws.
    for try await _ in try model.generateContentStream(testPrompt) {}
  }
  /// Verifies that only the final streamed chunk carries usage metadata, and that its token counts
  /// match the fixture.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "streaming-success-basic-reply-short",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )
    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }

    // Unwrapping the last element makes an empty stream fail, instead of skipping every
    // assertion and passing vacuously.
    let lastResponse = try XCTUnwrap(responses.last, "Expected at least one streamed response.")
    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Verifies that an error arriving mid-stream is thrown as
  /// `GenerateContentError.internalError` wrapping a 499 `.cancelled` `BackendError`, after the
  /// chunks preceding it were yielded.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    var responseCount = 0
    do {
      let stream = try model.generateContentStream("Hi")
      for try await content in stream {
        XCTAssertNotNil(content.text)
        responseCount += 1
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)
      // Both chunks preceding the error must have been yielded before it was thrown.
      XCTAssertEqual(responseCount, 2)
      return
    }
    // The clause above matches `BackendError`, so name that type in the failure message.
    XCTFail("Expected an internalError with a BackendError.")
  }
  /// Verifies that a non-HTTP `URLResponse` ends the stream with
  /// `GenerateContentError.internalError` wrapping a `URLError.badServerResponse` whose description
  /// is "Response was not an HTTP response.".
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.nonHTTPRequestHandler()
    let stream = try model.generateContentStream("Hi")
    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch let GenerateContentError.internalError(underlying) {
      XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
      // Also pin the error's identity, not only its message, as the unary path already does.
      let underlyingNSError = underlying as NSError
      XCTAssertEqual(underlyingNSError.domain, NSURLErrorDomain)
      XCTAssertEqual(underlyingNSError.code, URLError.Code.badServerResponse.rawValue)
      return
    }
    XCTFail("Expected an internal error.")
  }
  /// An invalid JSON chunk must end the stream with an internal error whose cause is a
  /// `DecodingError.dataCorrupted` mentioning `GenerateContentResponse`.
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let responseStream = try model.generateContentStream(testPrompt)

    do {
      for try await unexpected in responseStream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(context) = decodingError {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// A streamed candidate whose `content` lacks `parts` must end the stream with an internal error
  /// wrapping `InvalidCandidateError.emptyContent` with a `DecodingError` cause.
  func testGenerateContentStream_malformedContent() async throws {
    // Despite its name, this fixture is not truly malformed: it omits `parts` from `content`, and
    // its `invalid-field` could legitimately appear as a non-breaking addition to the proto API.
    // The expected outcome is therefore an `emptyContent` error.
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let responseStream = try model.generateContentStream(testPrompt)

    do {
      for try await unexpected in responseStream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      if case let .emptyContent(cause) = candidateError {
        XCTAssert(cause is DecodingError)
      } else {
        XCTFail("Not an empty content error: \(candidateError)")
      }
      return
    }

    XCTFail("Expected an internal decoding error.")
  }
  /// A model built with a custom request timeout must stream exactly one text chunk from a mock
  /// handler configured with that same timeout.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let timeout = 150.0
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      timeout: timeout
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: timeout),
      urlSession: urlSession
    )

    let responseStream = try model.generateContentStream(testPrompt)
    let chunkCount = try await responseStream.reduce(into: 0) { count, chunk in
      XCTAssertNotNil(chunk.text)
      count += 1
    }
    XCTAssertEqual(chunkCount, 1)
  }
  // MARK: - Count Tokens
  /// Counting tokens for a short text prompt must decode the fixture's total of six tokens.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol
      .requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
        forResource: "unary-success-total-tokens",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    let tokenCount = try await model.countTokens("Why is the sky blue?")
    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  /// Counts tokens against the `unary-success-detailed-token-response` fixture.
  ///
  /// Checks the total and the per-modality breakdown in `promptTokensDetails`:
  /// an image entry followed by a text entry, with their individual counts.
  func testCountTokens_succeeds_detailed() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-detailed-token-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(response.totalTokens, 1837)
    // Compare whole projections instead of subscripting `[0]` and `[1]`.
    // XCTAssertEqual does not stop execution, so a decoded array shorter than
    // expected would trap on the subscript and crash the test runner rather
    // than record an assertion failure.
    let details = response.promptTokensDetails
    XCTAssertEqual(details.map(\.modality), [.image, .text])
    XCTAssertEqual(details.map(\.tokenCount), [1806, 31])
  }
  /// Counts tokens with a model configured with a generation config, a tool,
  /// a system instruction, and explicit request options.
  ///
  /// The fixture response is the same as the basic case, so the expected total
  /// is unchanged.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    // Model configuration: sampling settings, a single function tool, and a
    // system prompt.
    let config = GenerationConfig(
      temperature: 0.5,
      topP: 0.9,
      topK: 3,
      candidateCount: 1,
      maxOutputTokens: 1024,
      stopSequences: ["test-stop"],
      responseMIMEType: "text/plain"
    )
    let additionDeclaration = FunctionDeclaration(
      name: "sum",
      description: "Add two integers.",
      parameters: ["x": .integer(), "y": .integer()]
    )
    let calculatorTool = Tool(functionDeclarations: [additionDeclaration])
    let systemPrompt = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      generationConfig: config,
      tools: [calculatorTool],
      systemInstruction: systemPrompt,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?").totalTokens
    XCTAssertEqual(tokenCount, 6)
  }
  /// Counts tokens for an inline image part (empty JPEG data).
  ///
  /// The response comes from the `unary-success-no-billable-characters` fixture.
  /// Only the total token count is asserted.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let imagePart = InlineDataPart(data: Data(), mimeType: "image/jpeg")

    let tokenCount = try await model.countTokens(imagePart).totalTokens

    XCTAssertEqual(tokenCount, 258)
  }
  /// Checks that a 404 "model not found" response surfaces from `countTokens`
  /// as a `BackendError`.
  ///
  /// The error must carry the HTTP status code, the `.notFound` status, and
  /// the server-provided message.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      // Reaching this line means no error was thrown; record a single failure.
      XCTFail("Request should not have succeeded.")
    } catch let backendError as BackendError {
      XCTAssertEqual(backendError.httpResponseCode, 404)
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssert(
        backendError.message.hasPrefix("models/test-model-name is not found"),
        "Unexpected error message: \(backendError.message)"
      )
    }
    // Errors of any other type are not caught above. They propagate out of
    // this throwing test and fail it.
  }
  /// Counts tokens through a model configured with a non-default request timeout.
  ///
  /// The mock handler is created with the same timeout value that is passed in
  /// `RequestOptions`.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let timeoutSeconds = 150.0
    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: timeoutSeconds
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: timeoutSeconds),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt).totalTokens
    XCTAssertEqual(tokenCount, 6)
  }
  1459. }
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    /// Orders two ratings by the raw value of their `category`.
    ///
    /// Only the category participates in the ordering. Two ratings with the same
    /// category but different other fields compare as neither less nor greater.
    /// NOTE(review): presumably this exists so tests can sort ratings before
    /// comparing arrays; confirm against call sites.
    public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
      let lhsKey = lhs.category.rawValue
      let rhsKey = rhs.category.rawValue
      return lhsKey < rhsKey
    }
  }