GenerativeModelTests.swift 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable public import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt text sent by every `generateContent` call in this suite.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Expected safety ratings for the "unary-success-basic-reply-short" fixture.
  /// Stored pre-sorted so they can be compared against `candidate.safetyRatings.sorted()`.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()

  // Model identity used when constructing every `GenerativeModel` in this suite.
  let testModelName = "test-model"
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  /// API configuration handed to each `GenerativeModel`.
  let apiConfig = VertexAI.defaultVertexAIAPIConfig

  /// Bundle subdirectory holding the JSON response fixtures.
  let vertexSubdirectory = "vertexai"

  // Recreated in `setUp()` before every test.
  var urlSession: URLSession!
  var model: GenerativeModel!
  /// Creates a fresh `URLSession` backed by `MockURLProtocol` and a default `GenerativeModel`
  /// that uses it, so each test starts from the same configuration.
  ///
  /// - Throws: Rethrows any error from `XCTestCase.setUp()`.
  override func setUp() async throws {
    try await super.setUp()

    let configuration = URLSessionConfiguration.default
    // Route requests through the mock protocol so responses come from whatever
    // `MockURLProtocol.requestHandler` the individual test installs.
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is non-failable, so no unwrap is needed.
    urlSession = URLSession(configuration: configuration)

    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
  }
  /// Clears per-test state so a mock handler, model, or session from one test cannot leak
  /// into the next. XCTest keeps test-case instances alive for the whole run, so the stored
  /// references are released explicitly here.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    model = nil
    urlSession = nil
    super.tearDown()
  }
  80. // MARK: - Generate Content
  /// The long basic-reply fixture decodes to one `.stop` candidate with four safety ratings,
  /// a single text part mirrored by `response.text`, and no function calls.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let body = try XCTUnwrap(firstPart as? TextPart).text
    XCTAssertTrue(body.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(reply.text, body)
    XCTAssertEqual(reply.functionCalls, [])
  }
  /// The short basic-reply fixture decodes to a single `.stop` candidate whose safety ratings
  /// match `safetyRatingsNegligible` and whose only part is the text "Mountain View, California".
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    // Ordering is not part of the contract, so compare sorted ratings.
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let answer = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertEqual(answer.text, "Mountain View, California")
    XCTAssertEqual(reply.text, answer.text)
    XCTAssertEqual(reply.functionCalls, [])
  }
  /// Verifies that per-modality token details in `usageMetadata` are decoded in order:
  /// image then text for the prompt, and text for the candidates.
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-basic-response-long-usage-metadata",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    let finishReason = try XCTUnwrap(candidate.finishReason)
    XCTAssertEqual(finishReason, .stop)

    let usageMetadata = try XCTUnwrap(response.usageMetadata)
    let promptDetails = usageMetadata.promptTokensDetails
    let candidateDetails = usageMetadata.candidatesTokensDetails
    XCTAssertEqual(promptDetails.count, 2)
    XCTAssertEqual(candidateDetails.count, 1)
    // The subscripts below trap on a short array, which would crash the whole test process.
    // Stop here once the count assertions above have recorded the mismatch.
    guard promptDetails.count == 2, candidateDetails.count == 1 else { return }

    XCTAssertEqual(promptDetails[0].modality, .image)
    XCTAssertEqual(promptDetails[0].tokenCount, 1806)
    XCTAssertEqual(promptDetails[1].modality, .text)
    XCTAssertEqual(promptDetails[1].tokenCount, 76)
    XCTAssertEqual(candidateDetails[0].modality, .text)
    XCTAssertEqual(candidateDetails[0].tokenCount, 76)
  }
  /// Verifies that all three citation sources in the fixture decode, including the optional
  /// `title`, `license`, and `publicationDate` fields that only some sources carry.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-citations",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citations = citationMetadata.citations
    XCTAssertEqual(citations.count, 3)
    // Array subscripts trap on out-of-range indices; `XCTUnwrap` around a non-optional
    // subscript offers no protection. Bail out after the recorded count failure instead.
    guard citations.count == 3 else { return }

    let citationSource1 = citations[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)
    XCTAssertNil(citationSource1.publicationDate)

    let citationSource2 = citations[1]
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.publicationDate, expectedPublicationDate)
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citations[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
    XCTAssertNil(citationSource3.publicationDate)
  }
  /// The quote-reply fixture yields one `.stop` candidate whose single text part begins with
  /// "Google", plus prompt feedback with four safety ratings and no block reason.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let quote = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(quote.text.hasPrefix("Google"))
    XCTAssertEqual(reply.text, quote.text)

    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Unrecognized enum raw values in safety ratings (probability and category) decode as
  /// raw-value-backed cases instead of failing, for both candidates and prompt feedback.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let harassmentRating = SafetyRating(
      category: .harassment,
      probability: .medium,
      probabilityScore: 0.0,
      severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )
    let unknownProbabilityRating = SafetyRating(
      category: .dangerousContent,
      probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
      probabilityScore: 0.0,
      severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )
    let unknownCategoryRating = SafetyRating(
      category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
      probability: .high,
      probabilityScore: 0.0,
      severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )
    let expected: [SafetyRating] = [
      harassmentRating,
      unknownProbabilityRating,
      unknownCategoryRating,
    ]

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.text, "Some text")
    XCTAssertEqual(reply.candidates.first?.safetyRatings, expected)
    XCTAssertEqual(reply.promptFeedback?.safetyRatings, expected)
  }
  /// Sends a request through a locally constructed model and passes if no error is thrown.
  ///
  /// NOTE(review): despite the test name, `testModelName` is passed unchanged, so this exercises
  /// the same configuration that `setUp()` builds. Confirm whether a prefixed model name was
  /// intended here.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    // Local model; intentionally shadows the `model` property set up in `setUp()`.
    let localModel = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    _ = try await localModel.generateContent(testPrompt)
  }
  /// A function call whose arguments object is empty decodes with an empty `args` map and is
  /// surfaced through `response.functionCalls`.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    guard let call = firstPart as? FunctionCallPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }

    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// A function call with no arguments field decodes with an empty `args` map and is
  /// surfaced through `response.functionCalls`.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    guard let call = firstPart as? FunctionCallPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }

    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// A `sum` function call decodes its numeric `x` and `y` arguments and is surfaced through
  /// `response.functionCalls`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    guard let call = firstPart as? FunctionCallPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }

    XCTAssertEqual(call.name, "sum")
    let arguments = call.args
    XCTAssertEqual(arguments.count, 2)
    let x = try XCTUnwrap(arguments["x"])
    XCTAssertEqual(x, .number(4))
    let y = try XCTUnwrap(arguments["y"])
    XCTAssertEqual(y, .number(5))
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// A reply with three parallel function-call parts exposes all three via
  /// `response.functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(reply.functionCalls.count, 3)
  }
  /// A reply mixing text and function-call parts exposes the two function calls and the
  /// text "The sum of [1, 2, 3] is" through `response.text`.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(reply.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(reply.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// A model configured with a working App Check fake completes a request. The mock handler is
  /// built expecting that token; the header check itself lives in `httpRequestHandler`.
  func testGenerateContent_appCheck_validToken() async throws {
    let token = "test-valid-token"
    let appCheck = AppCheckInteropFake(token: token)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: appCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With `privateAppID: true`, the request is issued against a handler configured with
  /// `dataCollection: false`. What that flag checks is defined in `httpRequestHandler`
  /// (not shown here).
  func testGenerateContent_dataCollectionOff() async throws {
    let token = "test-valid-token"
    let firebaseInfo = testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token),
      privateAppID: true
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token,
      dataCollection: false
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the App Check fake fails to produce a token, the request still completes without
  /// throwing. The mock handler expects `AppCheckInteropFake.placeholderTokenValue`, presumably
  /// the value sent in place of a real token (confirm in `AppCheckInteropFake`).
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// A model configured with a working Auth fake completes a request against a handler that
  /// expects that auth token.
  func testGenerateContent_auth_validAuthToken() async throws {
    let token = "test-valid-token"
    let auth = AuthInteropFake(token: token)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(auth: auth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An Auth fake that yields no token (signed-out user) still lets the request complete,
  /// against a handler configured with `authToken: nil`.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let signedOutAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(auth: signedOutAuth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// If the Auth fake fails to fetch a token, `generateContent` throws
  /// `GenerateContentError.internalError` wrapping that `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(auth: failingAuth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected: the auth failure is propagated rather than swallowed.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// The short basic-reply fixture reports aggregate token counts (6 prompt, 7 candidate,
  /// 13 total) and no per-modality breakdowns.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-basic-reply-short",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    let response = try await model.generateContent(testPrompt)

    let usageMetadata = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 7)
    XCTAssertEqual(usageMetadata.totalTokenCount, 13)
    // Boolean conditions read better as XCTAssertTrue than XCTAssertEqual(_, true).
    XCTAssertTrue(usageMetadata.promptTokensDetails.isEmpty)
    XCTAssertTrue(usageMetadata.candidatesTokensDetails.isEmpty)
  }
  /// A 400 invalid-API-key response surfaces as `GenerateContentError.internalError` wrapping
  /// a `BackendError`. The test checks the error's fields, that `localizedDescription` mentions
  /// each of them, and that its `NSError` bridging uses the backend error domain and HTTP code.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      // Use the same constant the mock was configured with, rather than a duplicated literal.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// A 403 "API not enabled" response surfaces as `GenerateContentError.internalError`
  /// wrapping a `permissionDenied` `BackendError` that
  /// `isVertexAIInFirebaseServiceDisabledError()` recognizes.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-firebasevertexai-api-not-enabled",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(error.message
        .starts(with: "Vertex AI in Firebase API has not been used in project"))
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// A candidate with empty content surfaces as `GenerateContentError.internalError` wrapping
  /// `InvalidCandidateError.emptyContent`, whose payload is a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: candidateError as InvalidCandidateError) {
      guard case let .emptyContent(underlying) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(underlying as? DecodingError, "Not a DecodingError: \(underlying)")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A candidate stopped for safety surfaces as `GenerateContentError.responseStoppedEarly`
  /// with reason `.safety`, and the partial response text is still accessible.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Include the unexpected error so failures are diagnosable, as the sibling tests do.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A safety-stopped candidate with no content surfaces as
  /// `GenerateContentError.responseStoppedEarly(.safety, _)` whose response has no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Include the unexpected error so failures are diagnosable, as the sibling tests do.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A 400 "image rejected" response surfaces as `GenerateContentError.internalError`
  /// wrapping an `invalidArgument` `BackendError` that carries the backend's message.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Configure the mock from the same constant the assertion below checks.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a prompt blocked for safety is surfaced as `promptBlocked`, with a
  /// `.safety` block reason, no block-reason message, and no response text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertNil(promptFeedback.blockReasonMessage)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that a prompt blocked for safety is surfaced as `promptBlocked` and that the
  /// server-provided block-reason message ("Reasons") is preserved.
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety-with-message",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized finish reason still produces `responseStoppedEarly`,
  /// with a reason equal to one built from the raw value, and keeps the partial text.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")
    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that an unrecognized block reason still produces `promptBlocked`, with a
  /// block reason equal to one built from the raw value.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")
    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that an HTTP 404 for an unknown model is surfaced as an `internalError`
  /// wrapping a `.notFound` `BackendError` with the matching code and message prefix.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Serve the same code the assertions expect so the two cannot drift apart.
        statusCode: expectedStatusCode
      )
    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a non-HTTP `URLResponse` makes `generateContent` throw an `internalError`
  /// whose underlying error is a URL `badServerResponse`.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    // Record both outcomes so each can be asserted independently.
    var thrown: Error?
    var generated: GenerateContentResponse?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }
    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)

    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")

    // Bridged to NSError, the underlying error must be NSURLErrorDomain / badServerResponse.
    let bridged = underlying as NSError
    XCTAssertEqual(bridged.domain, NSURLErrorDomain)
    XCTAssertEqual(bridged.code, URLError.Code.badServerResponse.rawValue)
  }
  /// Verifies that an undecodable unary response body surfaces as an `internalError`
  /// wrapping a `DecodingError.dataCorrupted` that names `GenerateContentResponse`.
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    var thrown: Error?
    var generated: GenerateContentResponse?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }
    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)

    // Peel the layers: GenerateContentError.internalError -> DecodingError.dataCorrupted.
    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let decoding = try XCTUnwrap(underlying as? DecodingError)
    guard case let .dataCorrupted(corruption) = decoding else {
      XCTFail("Not a data corrupted error: \(decoding)")
      return
    }
    XCTAssert(corruption.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
  }
  /// Verifies that malformed candidate content surfaces as an `internalError` wrapping
  /// `InvalidCandidateError.malformedContent`, which in turn wraps a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-malformed-content",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    var thrown: Error?
    var generated: GenerateContentResponse?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }
    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)

    // Peel the layers: internalError -> InvalidCandidateError.malformedContent -> DecodingError.
    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case let .malformedContent(contentCause) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(contentCause as? DecodingError, "Not a decoding error: \(contentCause)")
  }
  /// Verifies that a response whose prompt feedback omits safety ratings still decodes,
  /// yielding an empty ratings list alongside the generated text.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let generated = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(generated.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(generated.text, "This is the generated content.")
  }
  /// Verifies a unary request made with a custom `RequestOptions` timeout; the mock
  /// handler is configured with the same timeout value.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-basic-reply-short",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        timeout: expectedTimeout
      )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    let generated = try await model.generateContent(testPrompt)

    XCTAssertEqual(generated.candidates.count, 1)
  }
  830. // MARK: - Generate Content (Streaming)
  /// Verifies that an invalid-API-key error served to a streaming request surfaces as an
  /// `internalError` wrapping a descriptive, NSError-bridgeable `BackendError`.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 400)
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")

      // The human-readable description must mention the message, status and HTTP code.
      let description = backendError.localizedDescription
      XCTAssertTrue(description.contains(backendError.message))
      XCTAssertTrue(description.contains(backendError.status.rawValue))
      XCTAssertTrue(description.contains("\(backendError.httpResponseCode)"))

      // As an NSError, the domain is namespaced by type and the code is the HTTP code.
      let bridged = backendError as NSError
      XCTAssertEqual(bridged.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridged.code, backendError.httpResponseCode)
    }
  }
  /// Verifies that a 403 "API not enabled" response to a streaming request surfaces as a
  /// `.permissionDenied` `BackendError` recognized as a service-disabled error.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-firebasevertexai-api-not-enabled",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        statusCode: expectedStatusCode
      )

    do {
      let stream = try model.generateContentStream(testPrompt)
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      let expectedMessagePrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.status, .permissionDenied)
      XCTAssertTrue(backendError.message.starts(with: expectedMessagePrefix))
      XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
    }
  }
  /// Verifies that a stream whose candidate has empty content fails with an
  /// `internalError` wrapping an `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-failure-empty-content",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // Expected; the type of the underlying error is the only thing under test.
    }
  }
  /// Verifies that a streamed safety stop yields no content and surfaces as
  /// `responseStoppedEarly` with a candidate carrying at least one blocked rating.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-failure-finish-reason-safety",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(stopReason, stoppedResponse) {
      XCTAssertEqual(stopReason, .safety)
      let firstCandidate = try XCTUnwrap(stoppedResponse.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, stopReason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains(where: \.blocked))
    }
  }
  /// Verifies that a streamed prompt block surfaces as `promptBlocked` with a `.safety`
  /// reason and no block-reason message.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-failure-prompt-blocked-safety",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
    }
  }
  /// Verifies that a streamed prompt block surfaces as `promptBlocked` with a `.safety`
  /// reason and preserves the server-provided message ("Reasons").
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-failure-prompt-blocked-safety-with-message",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
    }
  }
  /// Verifies that an unrecognized streamed finish reason ends the stream with
  /// `responseStoppedEarly`, whose reason equals one built from the raw value.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-failure-unknown-finish-enum",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_ENUM")
    let stream = try model.generateContentStream("Hi")

    do {
      // Chunks before the stop are still delivered and must carry text.
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
      }
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(stopReason, _) {
      XCTAssertEqual(stopReason, unknownFinishReason)
    }
  }
  /// Verifies that the long basic-reply fixture streams six text-bearing chunks.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-success-basic-reply-long",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )
    let stream = try model.generateContentStream("Hi")

    // Count the chunks while checking that every one of them carries text.
    let responses: Int = try await stream.reduce(0) { count, chunk in
      XCTAssertNotNil(chunk.text)
      return count + 1
    }
    XCTAssertEqual(responses, 6)
  }
  /// Verifies that the short basic-reply fixture streams exactly one text-bearing chunk.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-success-basic-reply-short",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )
    let stream = try model.generateContentStream("Hi")

    var chunkCount = 0
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }
    XCTAssertEqual(chunkCount, 1)
  }
  /// Verifies that safety ratings with unrecognized enum values still decode, and that one
  /// of the streamed chunks carries a rating equal to one built from those raw values.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-success-unknown-safety-enum",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )
    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    let stream = try model.generateContentStream("Hi")
    var sawExpectedRating = false
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      // A chunk without a first candidate contributes no ratings.
      let ratings = chunk.candidates.first?.safetyRatings ?? []
      if ratings.contains(where: { $0 == expectedRating }) {
        sawExpectedRating = true
      }
    }
    XCTAssertTrue(sawExpectedRating)
  }
  /// Verifies that citations spread across streamed chunks are all decoded with their
  /// optional fields (URI, title, license, publication date) intact, and that the final
  /// chunk finishes with `.stop`.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-success-citations",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    // Accumulate citations from every chunk and remember the final chunk.
    let stream = try model.generateContentStream("Hi")
    var citations: [Citation] = []
    var finalResponse: GenerateContentResponse?
    for try await chunk in stream {
      finalResponse = chunk
      XCTAssertNotNil(chunk.text)
      let candidate = try XCTUnwrap(chunk.candidates.first)
      citations += candidate.citationMetadata?.citations ?? []
    }

    let lastCandidate = try XCTUnwrap(finalResponse?.candidates.first)
    XCTAssertEqual(lastCandidate.finishReason, .stop)
    XCTAssertEqual(citations.count, 6)

    // URI only; every optional metadata field absent.
    XCTAssertTrue(citations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil && citation.publicationDate == nil
    })
    // Title and publication date, no URI.
    XCTAssertTrue(citations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
        && citation.publicationDate == expectedPublicationDate
    })
    // URI and license.
    XCTAssertTrue(citations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit" && citation.publicationDate == nil
    })

    // Optional string fields that are present must never be empty.
    XCTAssertFalse(citations.contains { $0.uri?.isEmpty ?? false })
    XCTAssertFalse(citations.contains { $0.title?.isEmpty ?? false })
    XCTAssertFalse(citations.contains { $0.license?.isEmpty ?? false })
  }
  /// Streams with a model whose App Check fake returns a valid token; the mock request
  /// handler is configured with that same token.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheck = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: appCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-success-basic-reply-short",
        withExtension: "txt",
        subdirectory: vertexSubdirectory,
        appCheckToken: appCheckToken
      )

    // The stream is only drained; there are no assertions on its contents here.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Streams with a model whose App Check fake fails to produce a token; the mock request
  /// handler is configured with `AppCheckInteropFake.placeholderTokenValue`.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-success-basic-reply-short",
        withExtension: "txt",
        subdirectory: vertexSubdirectory,
        appCheckToken: AppCheckInteropFake.placeholderTokenValue
      )

    // The stream is only drained; there are no assertions on its contents here.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Verifies that usage metadata is attached only to the final streamed response and
  /// carries the expected token counts.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-success-basic-reply-short",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )
    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }

    // Unwrap explicitly: with an empty stream the checks above would pass vacuously.
    let lastResponse = try XCTUnwrap(responses.last, "Expected at least one streamed response.")
    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Verifies that a backend error arriving mid-stream is surfaced as an `internalError`
  /// wrapping a `.cancelled` `BackendError` (HTTP 499), after two chunks were delivered.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var received = 0
    do {
      let stream = try model.generateContentStream("Hi")
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
        received += 1
      }
      XCTFail("Expected an internalError with an RPCError.")
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)
      // Chunks delivered before the error must have been counted.
      XCTAssertEqual(received, 2)
    }
  }
  /// Verifies that a non-HTTP `URLResponse` makes the stream fail with an `internalError`
  /// describing the problem, without yielding any content.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
    let stream = try model.generateContentStream("Hi")

    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(underlyingError) {
      XCTAssertEqual(underlyingError.localizedDescription, "Response was not an HTTP response.")
    }
  }
  /// Verifies that invalid JSON in a stream fails with an `internalError` wrapping a
  /// `DecodingError.dataCorrupted` that names `GenerateContentResponse`.
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-failure-invalid-json",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )
    let stream = try model.generateContentStream(testPrompt)

    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(corruption) = decodingError {
        XCTAssert(corruption.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
    }
  }
  /// Verifies that malformed streamed content fails with an `internalError` wrapping
  /// `InvalidCandidateError.malformedContent`, itself wrapping a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-failure-malformed-content",
        withExtension: "txt",
        subdirectory: vertexSubdirectory
      )
    let stream = try model.generateContentStream(testPrompt)

    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      XCTFail("Expected an internal decoding error.")
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      if case let .malformedContent(contentCause) = candidateError {
        XCTAssert(contentCause is DecodingError)
      } else {
        XCTFail("Not a malformed content error: \(candidateError)")
      }
    }
  }
  /// Verifies a streaming request made with a custom `RequestOptions` timeout; the mock
  /// handler is configured with the same timeout, and one text chunk is expected.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-success-basic-reply-short",
        withExtension: "txt",
        subdirectory: vertexSubdirectory,
        timeout: expectedTimeout
      )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    var chunkCount = 0
    let stream = try model.generateContentStream(testPrompt)
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }
    XCTAssertEqual(chunkCount, 1)
  }
  1255. // MARK: - Count Tokens
  /// Verifies that a basic `countTokens` call decodes total tokens and billable characters.
  func testCountTokens_succeeds() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    MockURLProtocol.requestHandler = handler

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Verifies that a detailed `countTokens` response decodes per-modality prompt token
  /// details in addition to the totals.
  func testCountTokens_succeeds_detailed() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-detailed-token-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")
    XCTAssertEqual(tokenCount.totalTokens, 1837)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 117)

    // Two entries are expected: image first, then text.
    let details = tokenCount.promptTokensDetails
    XCTAssertEqual(details.count, 2)
    XCTAssertEqual(details[0].modality, .image)
    XCTAssertEqual(details[0].tokenCount, 1806)
    XCTAssertEqual(details[1].modality, .text)
    XCTAssertEqual(details[1].tokenCount, 31)
  }
  /// Verifies `countTokens` on a model configured with a generation config, a function
  /// tool, and a system instruction.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let config = GenerationConfig(
      temperature: 0.5, topP: 0.9, topK: 3,
      candidateCount: 1, maxOutputTokens: 1024,
      stopSequences: ["test-stop"], responseMIMEType: "text/plain"
    )
    let sumDeclaration = FunctionDeclaration(
      name: "sum",
      description: "Add two integers.",
      parameters: ["x": .integer(), "y": .integer()]
    )
    let calculatorInstruction = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      generationConfig: config,
      tools: [Tool(functionDeclarations: [sumDeclaration])],
      systemInstruction: calculatorInstruction,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Verifies that a `countTokens` response without billable characters (an image-only
  /// prompt) decodes with a `nil` `totalBillableCharacters`.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let imagePart = InlineDataPart(data: Data(), mimeType: "image/jpeg")

    let tokenCount = try await model.countTokens(imagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertNil(tokenCount.totalBillableCharacters)
  }
  /// Verifies that `countTokens` against a missing model throws a `.notFound`
  /// `BackendError` carrying the HTTP 404 code and a descriptive message.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch let backendError as BackendError {
      XCTAssertEqual(backendError.httpResponseCode, 404)
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssert(backendError.message.hasPrefix("models/test-model-name is not found"))
      return
    }
    // Reached only when the call succeeded; other error types propagate out of the test.
    XCTFail("Expected internal RPCError.")
  }
  /// Verifies a `countTokens` request made with a custom `RequestOptions` timeout; the mock
  /// handler is configured with the same timeout value.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-total-tokens",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        timeout: expectedTimeout
      )
    model = GenerativeModel(
      modelName: testModelName,
      modelResourceName: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1369. // MARK: - Helpers
  /// Builds a `FirebaseInfo` for tests, backed by a fresh "testApp" `FirebaseApp` with
  /// placeholder options.
  ///
  /// - Parameters:
  ///   - appCheck: Optional App Check interop to attach.
  ///   - auth: Optional Auth interop to attach.
  ///   - privateAppID: When `true`, the app's default data collection is disabled.
  /// - Returns: A `FirebaseInfo` with fixed project ID, API key and app ID.
  private func testFirebaseInfo(appCheck: AppCheckInterop? = nil,
                                auth: AuthInterop? = nil,
                                privateAppID: Bool = false) -> FirebaseInfo {
    let placeholderOptions = FirebaseOptions(googleAppID: "ignore", gcmSenderID: "ignore")
    let firebaseApp = FirebaseApp(instanceWithName: "testApp", options: placeholderOptions)
    firebaseApp.isDataCollectionDefaultEnabled = !privateAppID
    return FirebaseInfo(
      appCheck: appCheck,
      auth: auth,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      firebaseAppID: "My app ID",
      firebaseApp: firebaseApp
    )
  }
  /// Returns a mock request handler that answers every request with a plain `URLResponse`
  /// (deliberately not an `HTTPURLResponse`) and no body.
  ///
  /// - Throws: `XCTSkip` on watchOS, where custom URL protocols are unsupported (watchOS 2+;
  ///   see https://developer.apple.com/documentation/foundation/urlprotocol).
  private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    return { request in
      let plainResponse = URLResponse(
        url: request.url!,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (plainResponse, nil)
    }
  }
  1406. private func httpRequestHandler(forResource name: String,
  1407. withExtension ext: String,
  1408. subdirectory subpath: String,
  1409. statusCode: Int = 200,
  1410. timeout: TimeInterval = RequestOptions().timeout,
  1411. appCheckToken: String? = nil,
  1412. authToken: String? = nil,
  1413. dataCollection: Bool = true) throws -> ((URLRequest) throws -> (
  1414. URLResponse,
  1415. AsyncLineSequence<URL.AsyncBytes>?
  1416. )) {
  1417. // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
  1418. // https://developer.apple.com/documentation/foundation/urlprotocol for details.
  1419. #if os(watchOS)
  1420. throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
  1421. #endif // os(watchOS)
  1422. let bundle = BundleTestUtil.bundle()
  1423. let fileURL = try XCTUnwrap(
  1424. bundle.url(forResource: name, withExtension: ext, subdirectory: subpath)
  1425. )
  1426. return { request in
  1427. let requestURL = try XCTUnwrap(request.url)
  1428. XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
  1429. XCTAssertEqual(request.timeoutInterval, timeout)
  1430. let apiClientTags = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
  1431. .components(separatedBy: " ")
  1432. XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag))
  1433. XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag))
  1434. XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
  1435. let firebaseAppID = request.value(forHTTPHeaderField: "X-Firebase-AppId")
  1436. let appVersion = request.value(forHTTPHeaderField: "X-Firebase-AppVersion")
  1437. let expectedAppVersion =
  1438. try? XCTUnwrap(Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String)
  1439. XCTAssertEqual(firebaseAppID, dataCollection ? "My app ID" : nil)
  1440. XCTAssertEqual(appVersion, dataCollection ? expectedAppVersion : nil)
  1441. if let authToken {
  1442. XCTAssertEqual(request.value(forHTTPHeaderField: "Authorization"), "Firebase \(authToken)")
  1443. } else {
  1444. XCTAssertNil(request.value(forHTTPHeaderField: "Authorization"))
  1445. }
  1446. let response = try XCTUnwrap(HTTPURLResponse(
  1447. url: requestURL,
  1448. statusCode: statusCode,
  1449. httpVersion: nil,
  1450. headerFields: nil
  1451. ))
  1452. return (response, fileURL.lines)
  1453. }
  1454. }
  1455. }
  1456. private extension String {
  1457. /// Returns the number of occurrences of `substring` in the `String`.
  1458. func occurrenceCount(of substring: String) -> Int {
  1459. return components(separatedBy: substring).count - 1
  1460. }
  1461. }
  1462. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  1463. class AppCheckInteropFake: NSObject, AppCheckInterop {
  1464. /// The placeholder token value returned when an error occurs
  1465. static let placeholderTokenValue = "placeholder-token"
  1466. var token: String
  1467. var error: Error?
  1468. private init(token: String, error: Error?) {
  1469. self.token = token
  1470. self.error = error
  1471. }
  1472. convenience init(token: String) {
  1473. self.init(token: token, error: nil)
  1474. }
  1475. convenience init(error: Error) {
  1476. self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
  1477. }
  1478. func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
  1479. return AppCheckTokenResultInteropFake(token: token, error: error)
  1480. }
  1481. func tokenDidChangeNotificationName() -> String {
  1482. fatalError("\(#function) not implemented.")
  1483. }
  1484. func notificationTokenKey() -> String {
  1485. fatalError("\(#function) not implemented.")
  1486. }
  1487. func notificationAppNameKey() -> String {
  1488. fatalError("\(#function) not implemented.")
  1489. }
  1490. private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
  1491. var token: String
  1492. var error: Error?
  1493. init(token: String, error: Error?) {
  1494. self.token = token
  1495. self.error = error
  1496. }
  1497. }
  1498. }
  1499. struct AppCheckErrorFake: Error {}
  1500. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  1501. extension SafetyRating: Swift.Comparable {
  1502. public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
  1503. return lhs.category.rawValue < rhs.category.rawValue
  1504. }
  1505. }