GenerativeModelTests.swift 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt passed to `generateContent` by the tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Four non-blocked ratings with negligible probability and severity, kept in sorted order so
  /// they can be compared against a sorted list of decoded ratings.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()

  /// Fully qualified model resource name used when constructing the model under test.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  /// Session configured in `setUp` to route requests through `MockURLProtocol`.
  var urlSession: URLSession!

  /// Model under test; rebuilt in `setUp` and replaced by tests that need App Check or Auth fakes.
  var model: GenerativeModel!
  /// Creates a `URLSession` whose only protocol class is `MockURLProtocol`, then builds the
  /// default `GenerativeModel` (no tools, App Check, or Auth) on top of that session.
  override func setUp() async throws {
    let mockedConfiguration = URLSessionConfiguration.default
    mockedConfiguration.protocolClasses = [MockURLProtocol.self]
    urlSession = try XCTUnwrap(URLSession(configuration: mockedConfiguration))

    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )
  }
  /// Clears the static mock request handler so it does not carry over into the next test.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
  }
  78. // MARK: - Generate Content
  /// Expects the "unary-success-basic-reply-long" resource to produce one candidate that stopped
  /// normally, with four safety ratings and a single text part mirrored by `response.text`.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let textPart = try XCTUnwrap(firstPart as? TextPart)
    let replyText = textPart.text
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(response.text, replyText)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Expects the "unary-success-basic-reply-short" resource to produce one stopped candidate whose
  /// sorted safety ratings equal `safetyRatingsNegligible` and whose text is the expected city.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    // Sort before comparing because the expected fixture is stored sorted.
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertEqual(reply.text, "Mountain View, California")
    XCTAssertEqual(response.text, reply.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Expects per-modality token details to be decoded for both the prompt (image and text) and
  /// the candidates (text) from the "unary-success-basic-response-long-usage-metadata" resource.
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-response-long-usage-metadata",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)

    let usage = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usage.promptTokensDetails.count, 2)
    XCTAssertEqual(usage.promptTokensDetails[0].modality, .image)
    XCTAssertEqual(usage.promptTokensDetails[0].tokenCount, 1806)
    XCTAssertEqual(usage.promptTokensDetails[1].modality, .text)
    XCTAssertEqual(usage.promptTokensDetails[1].tokenCount, 76)
    XCTAssertEqual(usage.candidatesTokensDetails.count, 1)
    XCTAssertEqual(usage.candidatesTokensDetails[0].modality, .text)
    XCTAssertEqual(usage.candidatesTokensDetails[0].tokenCount, 76)
  }
  /// Expects three citations to be decoded from the "unary-success-citations" resource, each
  /// with a different mix of present and absent optional fields.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json"
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let metadata = try XCTUnwrap(firstCandidate.citationMetadata)
    XCTAssertEqual(metadata.citations.count, 3)

    // First citation: URI and index range only.
    let uriOnlyCitation = try XCTUnwrap(metadata.citations[0])
    XCTAssertEqual(uriOnlyCitation.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(uriOnlyCitation.startIndex, 0)
    XCTAssertEqual(uriOnlyCitation.endIndex, 128)
    XCTAssertNil(uriOnlyCitation.title)
    XCTAssertNil(uriOnlyCitation.license)
    XCTAssertNil(uriOnlyCitation.publicationDate)

    // Second citation: title and publication date, no URI or license.
    let titledCitation = try XCTUnwrap(metadata.citations[1])
    XCTAssertEqual(titledCitation.title, "some-citation-2")
    XCTAssertEqual(titledCitation.publicationDate, expectedPublicationDate)
    XCTAssertEqual(titledCitation.startIndex, 130)
    XCTAssertEqual(titledCitation.endIndex, 265)
    XCTAssertNil(titledCitation.uri)
    XCTAssertNil(titledCitation.license)

    // Third citation: URI and license, no title or publication date.
    let licensedCitation = try XCTUnwrap(metadata.citations[2])
    XCTAssertEqual(licensedCitation.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(licensedCitation.startIndex, 272)
    XCTAssertEqual(licensedCitation.endIndex, 431)
    XCTAssertEqual(licensedCitation.license, "mit")
    XCTAssertNil(licensedCitation.title)
    XCTAssertNil(licensedCitation.publicationDate)
  }
  /// Expects the "unary-success-quote-reply" resource to decode into one stopped candidate with
  /// text starting with "Google", plus prompt feedback that has no block reason.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(reply.text.hasPrefix("Google"))
    XCTAssertEqual(response.text, reply.text)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Expects safety ratings carrying unrecognized probability and category raw values to decode
  /// successfully, both on the first candidate and on the prompt feedback.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )
    let expectedRatings = [
      SafetyRating(
        category: .harassment,
        probability: .medium,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      // Probability value that has no named constant.
      SafetyRating(
        category: .dangerousContent,
        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      // Category value that has no named constant.
      SafetyRating(
        category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
        probability: .high,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
    ]

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedRatings)
  }
  /// Expects `generateContent` not to throw when the model is created with a name that already
  /// carries the "models/" prefix.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    // A separate local model leaves the fixture's `model` property untouched.
    let prefixedModel = GenerativeModel(
      // Model name is prefixed with "models/".
      name: "models/test-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// Expects the "unary-success-function-call-empty-arguments" resource to decode into a single
  /// `current_time` function call whose arguments are empty.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard let call = firstPart as? FunctionCallPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// Expects the "unary-success-function-call-no-arguments" resource to decode into a single
  /// `current_time` function call whose arguments are empty.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard let call = firstPart as? FunctionCallPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// Expects the "unary-success-function-call-with-arguments" resource to decode into a single
  /// `sum` function call with numeric arguments `x` = 4 and `y` = 5.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard let call = firstPart as? FunctionCallPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let xArgument = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(xArgument, .number(4))
    let yArgument = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(yArgument, .number(5))
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// Expects the "unary-success-function-call-parallel-calls" resource to decode into one
  /// candidate with three parts, all surfaced through `response.functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 3)
    let calls = response.functionCalls
    XCTAssertEqual(calls.count, 3)
  }
  /// Expects the "unary-success-function-call-mixed-content" resource to decode into four parts,
  /// of which two are function calls, with the text exposed through `response.text`.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 4)
    let calls = response.functionCalls
    XCTAssertEqual(calls.count, 2)
    let replyText = try XCTUnwrap(response.text)
    XCTAssertEqual(replyText, "The sum of [1, 2, 3] is")
  }
  /// Rebuilds the model with an App Check fake that returns a token and passes the same token
  /// to `httpRequestHandler`; the request is expected to complete without throwing.
  func testGenerateContent_appCheck_validToken() async throws {
    let expectedToken = "test-valid-token"
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: expectedToken),
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Rebuilds the model with an App Check fake configured with an error and passes the fake's
  /// placeholder token value to `httpRequestHandler`; the request is expected not to throw.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Rebuilds the model with an Auth fake that returns a token and passes the same token to
  /// `httpRequestHandler`; the request is expected to complete without throwing.
  func testGenerateContent_auth_validAuthToken() async throws {
    let expectedToken = "test-valid-token"
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: AuthInteropFake(token: expectedToken),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Rebuilds the model with an Auth fake whose token is `nil` and passes a `nil` auth token to
  /// `httpRequestHandler`; the request is expected to complete without throwing.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let tokenlessAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: tokenlessAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Rebuilds the model with an Auth fake configured with an error and expects
  /// `generateContent` to throw `GenerateContentError.internalError` wrapping an `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: failingAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected outcome: matching this pattern is the assertion, so there is nothing to check.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Expects the "unary-success-basic-reply-short" resource to yield token counts of 6 (prompt),
  /// 7 (candidates), and 13 (total), with no per-modality token details.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
    XCTAssertTrue(usage.promptTokensDetails.isEmpty)
    XCTAssertTrue(usage.candidatesTokensDetails.isEmpty)
  }
  /// Expects an HTTP 400 "unary-failure-api-key" response to surface as
  /// `GenerateContentError.internalError` wrapping a `BackendError`, and checks that the error's
  /// description and `NSError` bridging reflect the status, message, and HTTP code.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      // Compare against the same constant used to configure the mock so the two cannot drift.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Expects an HTTP 403 "unary-failure-firebasevertexai-api-not-enabled" response to surface as
  /// `GenerateContentError.internalError` wrapping a `BackendError` with `.permissionDenied`
  /// status that `isVertexAIInFirebaseServiceDisabledError()` reports as true.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(
        error.message.starts(with: "Vertex AI in Firebase API has not been used in project")
      )
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Expects the "unary-failure-empty-content" resource to make `generateContent` throw
  /// `internalError` wrapping `InvalidCandidateError.emptyContent`, whose payload is a
  /// `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      guard case let .emptyContent(underlyingError) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(
        underlyingError as? DecodingError,
        "Not a DecodingError: \(underlyingError)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Expects the "unary-failure-finish-reason-safety" resource to make `generateContent` throw
  /// `responseStoppedEarly` with reason `.safety` and the redacted text still available.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Expects the "unary-failure-finish-reason-safety-no-content" resource to make
  /// `generateContent` throw `responseStoppedEarly` with reason `.safety` and no response text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety-no-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Expects an HTTP 400 "unary-failure-image-rejected" response to surface as
  /// `GenerateContentError.internalError` wrapping a `BackendError` with `.invalidArgument`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-image-rejected",
      withExtension: "json",
      // Use the same constant that the assertion below checks, so the two cannot drift apart.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Expects the "unary-failure-prompt-blocked-safety" resource to make `generateContent` throw
  /// `promptBlocked` with a `.safety` block reason and no block reason message.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertNil(promptFeedback.blockReasonMessage)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Expects the "unary-failure-prompt-blocked-safety-with-message" resource to make
  /// `generateContent` throw `promptBlocked` with a `.safety` block reason and the message
  /// "Reasons".
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety-with-message",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Expects a finish reason with an unrecognized raw value to make `generateContent` throw
  /// `responseStoppedEarly` carrying that raw value, with the response text still available.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-enum-finish-reason",
      withExtension: "json"
    )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Verifies that a block reason the SDK does not know about is surfaced through
  /// `GenerateContentError.promptBlocked` with its raw value intact.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-enum-prompt-blocked",
      withExtension: "json"
    )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
    } catch {
      // Report the unexpected error so the failure can be diagnosed from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Verifies that an HTTP 404 reply for an unknown model is reported as a
  /// `GenerateContentError.internalError` wrapping a `BackendError` with the
  /// `.notFound` status and the same HTTP status code the mock returned.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    // Feed the same constant to the mock and to the assertion so they cannot drift apart.
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-model",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Verifies that a reply which is not an `HTTPURLResponse` makes `generateContent`
  /// throw an internal error whose underlying error is a bad-server-response URL error.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var result: GenerateContentResponse?
    var thrownError: Error?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }

    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
    // Bridge to NSError to inspect the error domain and code.
    let bridgedError = underlying as NSError
    XCTAssertEqual(bridgedError.domain, NSURLErrorDomain)
    XCTAssertEqual(bridgedError.code, URLError.Code.badServerResponse.rawValue)
  }
  /// Verifies that an undecodable unary response is reported as an internal error
  /// wrapping a `DecodingError.dataCorrupted` with the expected debug description.
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var result: GenerateContentResponse?
    var thrownError: Error?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }

    let decodingFailure = try XCTUnwrap(underlying as? DecodingError)
    guard case .dataCorrupted(let context) = decodingFailure else {
      XCTFail("Not a data corrupted error: \(decodingFailure)")
      return
    }
    XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
  }
  /// Verifies that malformed candidate content in a unary response is reported as an
  /// internal error wrapping `InvalidCandidateError.malformedContent`, which in turn
  /// wraps a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var result: GenerateContentResponse?
    var thrownError: Error?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }

    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case .malformedContent(let innerError) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    // Only the type of the innermost error matters; its value is discarded.
    _ = try XCTUnwrap(innerError as? DecodingError, "Not a decoding error: \(innerError)")
  }
  /// Verifies that a successful response without safety ratings decodes with an empty
  /// `safetyRatings` list on the prompt feedback and the expected text.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(response.text, "This is the generated content.")
  }
  /// Verifies that a custom `RequestOptions` timeout is applied to the outgoing unary
  /// request; the mock request handler asserts the request's timeout interval.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
  }
  781. // MARK: - Generate Content (Streaming)
  /// Verifies that an invalid-API-key reply makes the stream throw an internal error
  /// wrapping a `BackendError`, and checks that error's description and its `NSError`
  /// bridging.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 400)
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")

      // The localized description must mention the message, status, and HTTP code.
      let description = backendError.localizedDescription
      XCTAssertTrue(description.contains(backendError.message))
      XCTAssertTrue(description.contains(backendError.status.rawValue))
      XCTAssertTrue(description.contains("\(backendError.httpResponseCode)"))

      let bridgedError = backendError as NSError
      XCTAssertEqual(bridgedError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridgedError.code, backendError.httpResponseCode)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that an HTTP 403 "API not enabled" reply makes the stream throw an
  /// internal error wrapping a `BackendError` that is recognized as the
  /// service-disabled error.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      let responseStream = try model.generateContentStream(testPrompt)
      for try await _ in responseStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.status, .permissionDenied)
      let hasExpectedPrefix = backendError.message
        .starts(with: "Vertex AI in Firebase API has not been used in project")
      XCTAssertTrue(hasExpectedPrefix)
      XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a streamed response with empty content throws an internal error
  /// wrapping an `InvalidCandidateError` without yielding any content.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // Matching the underlying error type is the whole check.
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a streamed response finishing for safety reasons throws
  /// `GenerateContentError.responseStoppedEarly` with a candidate that carries the
  /// same finish reason and at least one blocked safety rating.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.responseStoppedEarly(let reason, let response) {
      XCTAssertEqual(reason, .safety)
      let firstCandidate = try XCTUnwrap(response.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, reason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains(where: { rating in rating.blocked }))
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a streamed prompt-blocked reply throws
  /// `GenerateContentError.promptBlocked` with a `.safety` block reason and no
  /// block reason message.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.promptBlocked(let response) {
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a streamed prompt-blocked reply with a message throws
  /// `GenerateContentError.promptBlocked` with a `.safety` block reason and the
  /// block reason message `"Reasons"`.
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety-with-message",
      withExtension: "txt"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.promptBlocked(let response) {
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that an unrecognized finish reason in a streamed response is thrown as
  /// `GenerateContentError.responseStoppedEarly` with the raw value preserved.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )
    let expectedReason = FinishReason(rawValue: "FAKE_ENUM")

    let responseStream = try model.generateContentStream("Hi")
    do {
      for try await chunk in responseStream {
        XCTAssertNotNil(chunk.text)
      }
    } catch GenerateContentError.responseStoppedEarly(let reason, _) {
      XCTAssertEqual(reason, expectedReason)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that the long basic-reply stream fixture yields six responses, each
  /// with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    let responseStream = try model.generateContentStream("Hi")
    var chunkCount = 0
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 6)
  }
  /// Verifies that the short basic-reply stream fixture yields exactly one response
  /// with text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    let responseStream = try model.generateContentStream("Hi")
    var chunkCount = 0
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  /// Verifies that a streamed safety rating with unrecognized enum raw values is
  /// decoded and appears on the first candidate of at least one streamed response.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )
    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    var sawExpectedRating = false
    let responseStream = try model.generateContentStream("Hi")
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      // A response without candidates contributes no ratings.
      let firstCandidateRatings = chunk.candidates.first?.safetyRatings ?? []
      if firstCandidateRatings.contains(where: { $0 == expectedRating }) {
        sawExpectedRating = true
      }
    }

    XCTAssertTrue(sawExpectedRating)
  }
  /// Verifies that citations are decoded from a streamed response: six citations in
  /// total, three of them checked field by field, and no empty-string `uri`, `title`,
  /// or `license` values.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    let responseStream = try model.generateContentStream("Hi")
    var collectedCitations = [Citation]()
    var collectedResponses = [GenerateContentResponse]()
    for try await chunk in responseStream {
      collectedResponses.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      if let sources = firstCandidate.citationMetadata?.citations {
        collectedCitations += sources
      }
    }

    let finalCandidate = try XCTUnwrap(collectedResponses.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 6)

    // Citation carrying only a URI.
    XCTAssertTrue(collectedCitations.contains(where: { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil && citation.publicationDate == nil
    }))
    // Citation carrying a title and a publication date.
    XCTAssertTrue(collectedCitations.contains(where: { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
        && citation.publicationDate == expectedPublicationDate
    }))
    // Citation carrying a URI and a license.
    XCTAssertTrue(collectedCitations.contains(where: { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit" && citation.publicationDate == nil
    }))

    // None of the optional string fields may be present as an empty string.
    XCTAssertFalse(collectedCitations.contains(where: { $0.uri?.isEmpty ?? false }))
    XCTAssertFalse(collectedCitations.contains(where: { $0.title?.isEmpty ?? false }))
    XCTAssertFalse(collectedCitations.contains(where: { $0.license?.isEmpty ?? false }))
  }
  /// Verifies that a valid App Check token is sent with a streaming request; the mock
  /// request handler asserts the `X-Firebase-AppCheck` header value.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let expectedToken = "test-valid-token"
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: expectedToken),
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: expectedToken
    )

    // Drain the stream so the request is issued and the handler's assertions run.
    let responseStream = try model.generateContentStream(testPrompt)
    for try await _ in responseStream {}
  }
  /// Verifies that when the App Check fake reports an error, the streaming request
  /// still goes out carrying the fake's placeholder token; the mock request handler
  /// asserts the `X-Firebase-AppCheck` header value.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream so the request is issued and the handler's assertions run.
    let responseStream = try model.generateContentStream(testPrompt)
    for try await _ in responseStream {}
  }
  /// Verifies that only the last streamed response carries usage metadata and that
  /// its token counts match the fixture.
  ///
  /// The test fails if the stream yields no responses at all, rather than passing
  /// without executing a single assertion.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // An empty stream must be a failure, not a vacuous pass.
    let lastResponse = try XCTUnwrap(responses.last, "The stream produced no responses.")
    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)

    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }
  }
  /// Verifies that a backend error arriving partway through a stream is thrown as an
  /// internal error wrapping a `BackendError`, after two responses were delivered.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )

    var deliveredCount = 0
    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await chunk in responseStream {
        XCTAssertNotNil(chunk.text)
        deliveredCount += 1
      }
    } catch GenerateContentError.internalError(let backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)
      // Responses received before the error must have been delivered.
      XCTAssertEqual(deliveredCount, 2)
      return
    }

    XCTFail("Expected an internalError with an RPCError.")
  }
  /// Verifies that a reply which is not an `HTTPURLResponse` makes the stream throw
  /// an internal error with the expected description.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    let responseStream = try model.generateContentStream("Hi")
    do {
      for try await chunk in responseStream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
    } catch GenerateContentError.internalError(let underlyingError) {
      XCTAssertEqual(
        underlyingError.localizedDescription,
        "Response was not an HTTP response."
      )
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// Verifies that invalid JSON in a stream is thrown as an internal error wrapping a
  /// `DecodingError.dataCorrupted` with the expected debug description.
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )

    let responseStream = try model.generateContentStream(testPrompt)
    do {
      for try await chunk in responseStream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
    } catch GenerateContentError.internalError(let decodingFailure as DecodingError) {
      guard case .dataCorrupted(let context) = decodingFailure else {
        XCTFail("Not a data corrupted error: \(decodingFailure)")
        return
      }
      XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// Verifies that malformed candidate content in a stream is thrown as an internal
  /// error wrapping `InvalidCandidateError.malformedContent`, whose payload is a
  /// `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )

    let responseStream = try model.generateContentStream(testPrompt)
    do {
      for try await chunk in responseStream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
    } catch GenerateContentError.internalError(let candidateError as InvalidCandidateError) {
      guard case .malformedContent(let innerError) = candidateError else {
        XCTFail("Not a malformed content error: \(candidateError)")
        return
      }
      XCTAssert(innerError is DecodingError)
      return
    }

    XCTFail("Expected an internal decoding error.")
  }
  /// Verifies that a custom `RequestOptions` timeout is applied to the outgoing
  /// streaming request; the mock request handler asserts the request's timeout
  /// interval.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    var chunkCount = 0
    let responseStream = try model.generateContentStream(testPrompt)
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  1191. // MARK: - Count Tokens
  /// Verifies that a basic count-tokens response decodes its total token and
  /// billable character counts.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Verifies that a detailed count-tokens response decodes the totals and the
  /// per-modality prompt token details in order (image first, then text).
  func testCountTokens_succeeds_detailed() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-detailed-token-response",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 1837)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 117)

    let details = tokenCount.promptTokensDetails
    XCTAssertEqual(details.count, 2)
    XCTAssertEqual(details[0].modality, .image)
    XCTAssertEqual(details[0].tokenCount, 1806)
    XCTAssertEqual(details[1].modality, .text)
    XCTAssertEqual(details[1].tokenCount, 31)
  }
  /// Verifies that counting tokens succeeds when the model is configured with a
  /// generation config, a function-calling tool, and a system instruction.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )
    let config = GenerationConfig(
      temperature: 0.5,
      topP: 0.9,
      topK: 3,
      candidateCount: 1,
      maxOutputTokens: 1024,
      stopSequences: ["test-stop"],
      responseMIMEType: "text/plain"
    )
    let sumDeclaration = FunctionDeclaration(
      name: "sum",
      description: "Add two integers.",
      parameters: ["x": .integer(), "y": .integer()]
    )
    let instruction = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      generationConfig: config,
      tools: [Tool(functionDeclarations: [sumDeclaration])],
      systemInstruction: instruction,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Verifies that a count-tokens response without a billable character count
  /// decodes with `totalBillableCharacters` set to `nil`.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json"
    )
    let emptyImagePart = InlineDataPart(data: Data(), mimeType: "image/jpeg")

    let tokenCount = try await model.countTokens(emptyImagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertNil(tokenCount.totalBillableCharacters)
  }
  /// Verifies that an HTTP 404 reply to a count-tokens request is thrown as a
  /// `BackendError` with the `.notFound` status.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch let backendError as BackendError {
      XCTAssertEqual(backendError.httpResponseCode, 404)
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssert(backendError.message.hasPrefix("models/test-model-name is not found"))
      return
    }

    XCTFail("Expected internal RPCError.")
  }
  /// Verifies that a custom `RequestOptions` timeout is applied to the outgoing
  /// count-tokens request; the mock request handler asserts the request's timeout
  /// interval.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1301. // MARK: - Helpers
  /// Builds a mock request handler that replies with a plain `URLResponse` (not an
  /// `HTTPURLResponse`) and no body lines.
  ///
  /// - Returns: A closure suitable for `MockURLProtocol.requestHandler`.
  /// - Throws: `XCTSkip` on watchOS, where custom URL protocols are unsupported.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    let handler: (URLRequest) -> (URLResponse, AsyncLineSequence<URL.AsyncBytes>?) = { request in
      // Deliberately a base URLResponse rather than an HTTPURLResponse. The closure is
      // non-throwing, so the request URL is force-unwrapped.
      let plainResponse = URLResponse(
        url: request.url!,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (plainResponse, nil)
    }
    return handler
  }
  /// Builds a mock request handler that validates the outgoing request and replies
  /// with the lines of a bundled fixture file.
  ///
  /// The returned closure asserts that the request path contains `models/` exactly
  /// once, that the timeout matches `timeout`, that the `x-goog-api-client` header
  /// carries the language and Firebase version tags, and that the App Check and
  /// Authorization headers match the expected tokens.
  ///
  /// - Parameters:
  ///   - name: Fixture resource name, without extension.
  ///   - ext: Fixture file extension.
  ///   - statusCode: HTTP status code of the mock response.
  ///   - timeout: Timeout interval the request is expected to carry.
  ///   - appCheckToken: Expected `X-Firebase-AppCheck` header value, or `nil` if the
  ///     header must be absent.
  ///   - authToken: Expected auth token; `nil` means no `Authorization` header.
  /// - Returns: A closure suitable for `MockURLProtocol.requestHandler`.
  /// - Throws: `XCTSkip` on watchOS, or an error if the fixture cannot be located.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = RequestOptions().timeout,
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    let testBundle = BundleTestUtil.bundle()
    let resourceURL = try XCTUnwrap(testBundle.url(forResource: name, withExtension: ext))

    return { request in
      let url = try XCTUnwrap(request.url)
      XCTAssertEqual(url.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let apiClientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let apiClientTags = apiClientHeader.components(separatedBy: " ")
      XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)

      let authorizationHeader = request.value(forHTTPHeaderField: "Authorization")
      if let expectedAuthToken = authToken {
        XCTAssertEqual(authorizationHeader, "Firebase \(expectedAuthToken)")
      } else {
        XCTAssertNil(authorizationHeader)
      }

      let httpResponse = try XCTUnwrap(HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (httpResponse, resourceURL.lines)
    }
  }
  1361. }
  private extension String {
    /// Counts how many times `substring` appears in the receiver, computed as the
    /// number of pieces produced by splitting on `substring`, minus one.
    func occurrenceCount(of substring: String) -> Int {
      let pieces = components(separatedBy: substring)
      return pieces.count - 1
    }
  }
  /// Test double for `AppCheckInterop` that hands back a preconfigured token and
  /// optional error from `getToken(forcingRefresh:)`.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// Token result returned by the fake; simply stores a token and an optional error.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
      }
    }

    /// The placeholder token value returned when an error occurs
    static let placeholderTokenValue = "placeholder-token"

    var token: String
    var error: Error?

    /// Designated initializer; callers use one of the convenience initializers below.
    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
    }

    /// Creates a fake that returns `token` with no error.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake that returns the placeholder token together with `error`.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    /// Returns the stored token and error; `forcingRefresh` is ignored.
    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      AppCheckTokenResultInteropFake(token: token, error: error)
    }

    // The notification-related requirements below are deliberately unimplemented
    // and trap if called.

    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }
  }
  /// Empty error type passed to `AppCheckInteropFake(error:)` in tests.
  struct AppCheckErrorFake: Swift.Error {}
  /// Test-only ordering for `SafetyRating`: ratings compare by the raw value of their
  /// harm category alone; no other field takes part in the comparison.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
      let leftCategory = lhs.category.rawValue
      let rightCategory = rhs.category.rawValue
      return leftCategory < rightCategory
    }
  }