GenerativeModelTests.swift 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt text passed to `generateContent` by the tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Expected safety ratings in which every category is `.negligible` and unblocked.
  ///
  /// The array is stored sorted so it can be compared with a sorted copy of a candidate's
  /// ratings.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()

  /// Model resource name used when constructing the default `model`.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  /// API configuration passed to every `GenerativeModel` built in this class.
  let apiConfig = APIConfig(service: .vertexAI, version: .v1beta)

  /// URL session configured with `MockURLProtocol`; assigned in `setUp()`.
  var urlSession: URLSession!

  /// Model under test; assigned in `setUp()` and reassigned by tests that need different
  /// Firebase info.
  var model: GenerativeModel!
  /// Creates a `URLSession` whose protocol classes contain only `MockURLProtocol` and builds
  /// the default `model` on top of that session.
  ///
  /// - Throws: If the session cannot be unwrapped.
  override func setUp() async throws {
    let sessionConfiguration = URLSessionConfiguration.default
    sessionConfiguration.protocolClasses = [MockURLProtocol.self]

    let mockedSession = try XCTUnwrap(URLSession(configuration: sessionConfiguration))
    urlSession = mockedSession

    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: mockedSession
    )
  }
  /// Resets the static `MockURLProtocol.requestHandler` after each test.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
  }
  77. // MARK: - Generate Content
  /// Checks that the "basic reply long" fixture yields one candidate with finish reason
  /// `.stop`, four safety ratings and a single text part exposed through `response.text`.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    // Candidate-level expectations.
    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)

    // Content expectations.
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let textPart = try XCTUnwrap(firstPart as? TextPart)
    let replyText = textPart.text
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(response.text, replyText)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Checks that the "basic reply short" fixture yields one candidate with finish reason
  /// `.stop`, the negligible safety ratings and the expected reply text.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    // Candidate-level expectations.
    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    // Sort before comparing: `safetyRatingsNegligible` is stored sorted.
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    // Content expectations.
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertEqual(reply.text, "Mountain View, California")
    XCTAssertEqual(response.text, reply.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Checks the per-modality token details decoded from a response that carries full usage
  /// metadata.
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-response-long-usage-metadata",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)

    let usage = try XCTUnwrap(response.usageMetadata)

    // Prompt tokens: an image entry followed by a text entry.
    let promptDetails = usage.promptTokensDetails
    XCTAssertEqual(promptDetails.count, 2)
    XCTAssertEqual(promptDetails[0].modality, .image)
    XCTAssertEqual(promptDetails[0].tokenCount, 1806)
    XCTAssertEqual(promptDetails[1].modality, .text)
    XCTAssertEqual(promptDetails[1].tokenCount, 76)

    // Candidate tokens: a single text entry.
    let candidateDetails = usage.candidatesTokensDetails
    XCTAssertEqual(candidateDetails.count, 1)
    XCTAssertEqual(candidateDetails[0].modality, .text)
    XCTAssertEqual(candidateDetails[0].tokenCount, 76)
  }
  /// Checks that the three citations in the citations fixture decode with the expected URIs,
  /// titles, index ranges, licenses and publication dates.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json"
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let metadata = try XCTUnwrap(firstCandidate.citationMetadata)
    XCTAssertEqual(metadata.citations.count, 3)

    // First citation: URI only.
    let first = try XCTUnwrap(metadata.citations[0])
    XCTAssertEqual(first.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(first.startIndex, 0)
    XCTAssertEqual(first.endIndex, 128)
    XCTAssertNil(first.title)
    XCTAssertNil(first.license)
    XCTAssertNil(first.publicationDate)

    // Second citation: title and publication date, no URI.
    let second = try XCTUnwrap(metadata.citations[1])
    XCTAssertEqual(second.title, "some-citation-2")
    XCTAssertEqual(second.publicationDate, expectedPublicationDate)
    XCTAssertEqual(second.startIndex, 130)
    XCTAssertEqual(second.endIndex, 265)
    XCTAssertNil(second.uri)
    XCTAssertNil(second.license)

    // Third citation: URI and license.
    let third = try XCTUnwrap(metadata.citations[2])
    XCTAssertEqual(third.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(third.startIndex, 272)
    XCTAssertEqual(third.endIndex, 431)
    XCTAssertEqual(third.license, "mit")
    XCTAssertNil(third.title)
    XCTAssertNil(third.publicationDate)
  }
  /// Checks the "quote reply" fixture: one text candidate plus prompt feedback that has no
  /// block reason and four safety ratings.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    // Candidate-level expectations.
    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.count, 4)

    // Content expectations.
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(reply.text.hasPrefix("Google"))
    XCTAssertEqual(response.text, reply.text)

    // Prompt feedback expectations.
    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Checks that safety ratings containing raw values without a predefined case (category,
  /// probability) still decode and compare equal to ratings built from those raw values.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )
    let expectedSafetyRatings: [SafetyRating] = [
      .init(
        category: .harassment,
        probability: .medium,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      .init(
        category: .dangerousContent,
        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
      .init(
        category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
        probability: .high,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      ),
    ]

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    // The same ratings are expected on both the candidate and the prompt feedback.
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedSafetyRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedSafetyRatings)
  }
  /// Checks that a request succeeds when the model name is already prefixed with "models/".
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    // Use a dedicated model rather than the shared `model` property.
    let prefixedModel = GenerativeModel(
      // Model name is prefixed with "models/".
      name: "models/test-model",
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    // Only the absence of a thrown error is checked here.
    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// Checks that a function call with an empty arguments object decodes to a
  /// `FunctionCallPart` named "current_time" with no arguments.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Unwrap the downcast the same way the text-part tests do, instead of guard/XCTFail/return.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// Checks that the "no arguments" function-call fixture decodes to a `FunctionCallPart`
  /// named "current_time" whose `args` is empty.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Unwrap the downcast the same way the text-part tests do, instead of guard/XCTFail/return.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// Checks that a function call with arguments decodes to a `FunctionCallPart` named "sum"
  /// with numeric arguments `x == 4` and `y == 5`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Unwrap the downcast the same way the text-part tests do, instead of guard/XCTFail/return.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "sum")
    XCTAssertEqual(functionCall.args.count, 2)
    let argX = try XCTUnwrap(functionCall.args["x"])
    XCTAssertEqual(argX, .number(4))
    let argY = try XCTUnwrap(functionCall.args["y"])
    XCTAssertEqual(argY, .number(5))
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// Checks that a candidate with three parts exposes three function calls.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// Checks a candidate with four parts: `functionCalls` returns two calls and `text` returns
  /// the expected string.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let responseText = try XCTUnwrap(response.text)
    XCTAssertEqual(responseText, "The sum of [1, 2, 3] is")
  }
  /// Runs a request through a model whose App Check fake returns a valid token; the same
  /// token is handed to the mock request handler.
  func testGenerateContent_appCheck_validToken() async throws {
    let token = "test-valid-token"
    let appCheck = AppCheckInteropFake(token: token)
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: appCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: token
    )

    // The response itself is not inspected.
    _ = try await model.generateContent(testPrompt)
  }
  /// Runs a request with Firebase info built using `privateAppID: true`; the mock request
  /// handler is created with `dataCollection: false`.
  func testGenerateContent_dataCollectionOff() async throws {
    let token = "test-valid-token"
    let firebaseInfo = testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token),
      privateAppID: true
    )
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: token,
      dataCollection: false
    )

    // The response itself is not inspected.
    _ = try await model.generateContent(testPrompt)
  }
  /// Runs a request through a model whose App Check fake is built with an error; the mock
  /// request handler is given `AppCheckInteropFake.placeholderTokenValue` and the call is
  /// expected not to throw.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // The response itself is not inspected.
    _ = try await model.generateContent(testPrompt)
  }
  /// Runs a request through a model whose Auth fake returns a valid token; the same token is
  /// handed to the mock request handler.
  func testGenerateContent_auth_validAuthToken() async throws {
    let token = "test-valid-token"
    let auth = AuthInteropFake(token: token)
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(auth: auth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: token
    )

    // The response itself is not inspected.
    _ = try await model.generateContent(testPrompt)
  }
  /// Runs a request through a model whose Auth fake returns a `nil` token; the mock request
  /// handler is likewise created with `authToken: nil` and the call is expected not to throw.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let auth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(auth: auth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    // The response itself is not inspected.
    _ = try await model.generateContent(testPrompt)
  }
  /// Checks that an Auth fake built with an error makes `generateContent` throw
  /// `GenerateContentError.internalError` wrapping an `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      name: "my-model",
      firebaseInfo: testFirebaseInfo(auth: failingAuth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected: the underlying error is the fake auth error, so there is nothing to assert.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Checks the token counts decoded from the "basic reply short" fixture and that no
  /// per-modality token details are present.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let usageMetadata = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 7)
    XCTAssertEqual(usageMetadata.totalTokenCount, 13)
    // Assert the Boolean directly rather than comparing it with `true`.
    XCTAssertTrue(usageMetadata.promptTokensDetails.isEmpty)
    XCTAssertTrue(usageMetadata.candidatesTokensDetails.isEmpty)
  }
  /// Checks that an HTTP 400 "invalid API key" response surfaces as
  /// `GenerateContentError.internalError` wrapping a `BackendError`, and that the error's
  /// description and `NSError` bridging carry the status, message and HTTP code.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      // Compare with the code handed to the mock handler, not a repeated literal.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Checks that an HTTP 403 "API not enabled" response surfaces as
  /// `GenerateContentError.internalError` wrapping a `BackendError` that reports the service
  /// as disabled.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(
        error.message.starts(with: "Vertex AI in Firebase API has not been used in project")
      )
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// Checks that the empty-content fixture surfaces as `GenerateContentError.internalError`
  /// wrapping `InvalidCandidateError.emptyContent`, whose payload is a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: invalidCandidateError as InvalidCandidateError) {
      if case let .emptyContent(decodingError) = invalidCandidateError {
        // The associated value must itself be a `DecodingError`.
        _ = try XCTUnwrap(
          decodingError as? DecodingError,
          "Not a DecodingError: \(decodingError)"
        )
      } else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(invalidCandidateError)")
      }
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Checks that the finish-reason-safety fixture throws `responseStoppedEarly` with reason
  /// `.safety` and a response whose text is "<redacted>".
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Report the unexpected error so the failure can be diagnosed.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Checks that the finish-reason-safety fixture without content throws
  /// `responseStoppedEarly` with reason `.safety` and a response whose text is `nil`.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety-no-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so the failure can be diagnosed.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Checks that an HTTP 400 "image rejected" response surfaces as
  /// `GenerateContentError.internalError` wrapping a `BackendError` with status
  /// `.invalidArgument`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-image-rejected",
      withExtension: "json",
      // Feed the handler the same constant that is asserted below.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Checks that the prompt-blocked-safety fixture throws `promptBlocked` with a response
  /// that has no text, block reason `.safety` and no block reason message.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertNil(promptFeedback.blockReasonMessage)
    } catch {
      // Report the unexpected error so the failure can be diagnosed.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Checks that the prompt-blocked-safety-with-message fixture throws `promptBlocked` with
  /// block reason `.safety` and block reason message "Reasons".
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety-with-message",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
    } catch {
      // Report the unexpected error so the failure can be diagnosed.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Mocks the `unary-failure-unknown-enum-finish-reason` fixture and expects `generateContent`
  /// to throw `GenerateContentError.responseStoppedEarly`.
  ///
  /// The reported reason must equal a `FinishReason` built from the raw value
  /// "FAKE_NEW_FINISH_REASON", and the attached response text must be "Some text".
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-enum-finish-reason",
      withExtension: "json"
    )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report the unexpected error so the failure can be diagnosed from the test log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Mocks the `unary-failure-unknown-enum-prompt-blocked` fixture and expects
  /// `generateContent` to throw `GenerateContentError.promptBlocked`.
  ///
  /// The block reason in the prompt feedback must equal a `PromptFeedback.BlockReason` built
  /// from the raw value "FAKE_NEW_BLOCK_REASON".
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-enum-prompt-blocked",
      withExtension: "json"
    )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
    } catch {
      // Report the unexpected error so the failure can be diagnosed from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Mocks the `unary-failure-unknown-model` fixture with a 404 status and expects
  /// `generateContent` to throw `GenerateContentError.internalError` wrapping a `BackendError`.
  ///
  /// The backend error must report `.notFound`, the mocked HTTP status code, and a message
  /// starting with "models/unknown is not found".
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-model",
      withExtension: "json",
      // Feed the mock from the same constant that is asserted below so the two cannot drift.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Installs the non-HTTP mock handler and checks that `generateContent` fails with
  /// `GenerateContentError.internalError`.
  ///
  /// The underlying error must describe itself as "Response was not an HTTP response." and
  /// bridge to an `NSError` in `NSURLErrorDomain` with the `badServerResponse` code.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var result: GenerateContentResponse?
    var thrownError: Error?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(cause) = contentError else {
      return XCTFail("Not an internal error: \(contentError)")
    }

    XCTAssertEqual(cause.localizedDescription, "Response was not an HTTP response.")
    let bridgedError = cause as NSError
    XCTAssertEqual(bridgedError.domain, NSURLErrorDomain)
    XCTAssertEqual(bridgedError.code, URLError.Code.badServerResponse.rawValue)
  }
  /// Mocks the `unary-failure-invalid-response` fixture and checks that `generateContent`
  /// fails with `GenerateContentError.internalError` wrapping a `DecodingError.dataCorrupted`.
  ///
  /// The decoding context's debug description must start with
  /// "Failed to decode GenerateContentResponse".
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var result: GenerateContentResponse?
    var thrownError: Error?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(cause) = contentError else {
      return XCTFail("Not an internal error: \(contentError)")
    }

    let decodingFailure = try XCTUnwrap(cause as? DecodingError)
    guard case let .dataCorrupted(context) = decodingFailure else {
      return XCTFail("Not a data corrupted error: \(decodingFailure)")
    }

    XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
  }
  /// Mocks the `unary-failure-malformed-content` fixture and checks that `generateContent`
  /// fails with `GenerateContentError.internalError` wrapping
  /// `InvalidCandidateError.malformedContent`, whose own underlying error is a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var result: GenerateContentResponse?
    var thrownError: Error?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(cause) = contentError else {
      return XCTFail("Not an internal error: \(contentError)")
    }

    let candidateError = try XCTUnwrap(cause as? InvalidCandidateError)
    guard case let .malformedContent(innerError) = candidateError else {
      return XCTFail("Not a malformed content error: \(candidateError)")
    }

    // Only the type of the innermost error matters; the unwrapped value itself is unused.
    _ = try XCTUnwrap(innerError as? DecodingError, "Not a decoding error: \(innerError)")
  }
  /// Mocks the `unary-success-missing-safety-ratings` fixture and checks that the response
  /// still has prompt feedback, with zero safety ratings, alongside the expected text.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(response.text, "This is the generated content.")
  }
  /// Rebuilds the model with a 150 second request timeout and checks, through the mock
  /// handler's timeout assertion, that a unary request carries it. The reply must contain one
  /// candidate.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
  }
  786. // MARK: - Generate Content (Streaming)
  /// Streams against the `unary-failure-api-key` fixture and expects
  /// `GenerateContentError.internalError` wrapping a `BackendError`.
  ///
  /// Checks the error's code, status and message, that its localized description mentions all
  /// three, and that it bridges to an `NSError` with the expected domain and code.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 400)
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")

      XCTAssertTrue(backendError.localizedDescription.contains(backendError.message))
      XCTAssertTrue(backendError.localizedDescription.contains(backendError.status.rawValue))
      XCTAssertTrue(
        backendError.localizedDescription.contains("\(backendError.httpResponseCode)")
      )

      let bridgedError = backendError as NSError
      XCTAssertEqual(bridgedError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridgedError.code, backendError.httpResponseCode)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Streams against the `unary-failure-firebasevertexai-api-not-enabled` fixture with a 403
  /// status and expects `GenerateContentError.internalError` wrapping a `BackendError`.
  ///
  /// The backend error must report the mocked status code, `.permissionDenied`, the expected
  /// message prefix, and `true` from `isVertexAIInFirebaseServiceDisabledError()`.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      let responseStream = try model.generateContentStream(testPrompt)
      for try await _ in responseStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.status, .permissionDenied)
      XCTAssertTrue(
        backendError.message
          .starts(with: "Vertex AI in Firebase API has not been used in project")
      )
      XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Streams against the `streaming-failure-empty-content` fixture and expects
  /// `GenerateContentError.internalError` wrapping an `InvalidCandidateError`, with no content
  /// delivered beforehand.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // Matching the pattern is the whole check; there is nothing further to inspect.
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Streams against the `streaming-failure-finish-reason-safety` fixture and expects
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety`.
  ///
  /// The first candidate of the attached response must share that finish reason and have at
  /// least one blocked safety rating.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      let firstCandidate = try XCTUnwrap(response.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, reason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains(where: \.blocked))
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Streams against the `streaming-failure-prompt-blocked-safety` fixture and expects
  /// `GenerateContentError.promptBlocked` whose feedback has a `.safety` block reason and no
  /// block reason message.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(response) {
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Streams against the `streaming-failure-prompt-blocked-safety-with-message` fixture and
  /// expects `GenerateContentError.promptBlocked` whose feedback has a `.safety` block reason
  /// and the block reason message "Reasons".
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety-with-message",
      withExtension: "txt"
    )

    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(response) {
      let feedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Streams against the `streaming-failure-unknown-finish-enum` fixture and expects
  /// `GenerateContentError.responseStoppedEarly` whose reason equals a `FinishReason` built
  /// from the raw value "FAKE_ENUM". Any content delivered before the error must have text.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_ENUM")

    // Created outside the `do` so a failure to build the stream propagates as a test error.
    let responseStream = try model.generateContentStream("Hi")
    do {
      for try await chunk in responseStream {
        XCTAssertNotNil(chunk.text)
      }
    } catch let GenerateContentError.responseStoppedEarly(reason, _) {
      XCTAssertEqual(reason, unknownFinishReason)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Streams against the `streaming-success-basic-reply-long` fixture and expects six
  /// responses, each with non-nil text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    var responseCount = 0
    let responseStream = try model.generateContentStream("Hi")
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      responseCount += 1
    }

    XCTAssertEqual(responseCount, 6)
  }
  /// Streams against the `streaming-success-basic-reply-short` fixture and expects exactly one
  /// response with non-nil text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var responseCount = 0
    let responseStream = try model.generateContentStream("Hi")
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      responseCount += 1
    }

    XCTAssertEqual(responseCount, 1)
  }
  /// Streams against the `streaming-success-unknown-safety-enum` fixture and expects at least
  /// one response whose first candidate carries a safety rating equal to one built from
  /// unrecognized category and probability raw values.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )
    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    var sawExpectedRating = false
    let responseStream = try model.generateContentStream("Hi")
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      // Only the first candidate of each response is inspected.
      if let ratings = chunk.candidates.first?.safetyRatings,
         ratings.contains(expectedRating) {
        sawExpectedRating = true
      }
    }

    XCTAssertTrue(sawExpectedRating)
  }
  /// Streams against the `streaming-success-citations` fixture, gathers the citations of the
  /// first candidate of every response, and checks them.
  ///
  /// Expects six citations in total, three specific citations to be present, the final
  /// candidate to finish with `.stop`, and no citation with an empty URI, title or license.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    var collectedCitations = [Citation]()
    var streamedResponses = [GenerateContentResponse]()
    let responseStream = try model.generateContentStream("Hi")
    for try await chunk in responseStream {
      streamedResponses.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      if let sources = firstCandidate.citationMetadata?.citations {
        collectedCitations += sources
      }
    }

    let finalCandidate = try XCTUnwrap(streamedResponses.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 6)

    // A citation identified by URI only.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil && citation.publicationDate == nil
    })
    // A citation identified by title, with a publication date.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
        && citation.publicationDate == expectedPublicationDate
    })
    // A citation with both a URI and a license.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit" && citation.publicationDate == nil
    })

    // Optional string fields must be either nil or non-empty.
    XCTAssertFalse(collectedCitations.contains { $0.uri?.isEmpty == true })
    XCTAssertFalse(collectedCitations.contains { $0.title?.isEmpty == true })
    XCTAssertFalse(collectedCitations.contains { $0.license?.isEmpty == true })
  }
  /// Rebuilds the model with an `AppCheckInteropFake` holding a valid token and streams a
  /// request. The mock handler asserts that the `X-Firebase-AppCheck` header equals that
  /// token; the stream is drained without inspecting its content.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: appCheckFake),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: appCheckToken
    )

    let responseStream = try model.generateContentStream(testPrompt)
    for try await _ in responseStream {}
  }
  /// Rebuilds the model with an `AppCheckInteropFake` configured with an error and streams a
  /// request. The mock handler asserts that the `X-Firebase-AppCheck` header equals
  /// `AppCheckInteropFake.placeholderTokenValue`; the stream is drained without inspecting
  /// its content.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let appCheckFake = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: appCheckFake),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    let responseStream = try model.generateContentStream(testPrompt)
    for try await _ in responseStream {}
  }
  /// Streams against the `streaming-success-basic-reply-short` fixture and checks where usage
  /// metadata appears.
  ///
  /// Every response except the last must have nil usage metadata; the last must report 6
  /// prompt tokens, 4 candidate tokens and 10 total tokens. The test fails if the stream
  /// yields no responses at all.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Unwrapping the last element fails the test when nothing was streamed; looping over an
    // empty array would otherwise pass without running a single assertion.
    let lastResponse = try XCTUnwrap(responses.last, "Stream produced no responses.")

    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }

    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Streams against the `streaming-failure-error-mid-stream` fixture and expects two
  /// responses with text before `GenerateContentError.internalError` is thrown, wrapping a
  /// `BackendError` with code 499 and status `.cancelled`.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )

    var receivedCount = 0
    do {
      let responseStream = try model.generateContentStream("Hi")
      for try await chunk in responseStream {
        XCTAssertNotNil(chunk.text)
        receivedCount += 1
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)

      // Check the content count is correct.
      XCTAssertEqual(receivedCount, 2)
      return
    }

    XCTFail("Expected an internalError with an RPCError.")
  }
  /// Installs the non-HTTP mock handler and expects the stream to throw
  /// `GenerateContentError.internalError` whose underlying error describes itself as
  /// "Response was not an HTTP response.", without delivering any content.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    let responseStream = try model.generateContentStream("Hi")
    do {
      for try await chunk in responseStream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
    } catch let GenerateContentError.internalError(cause) {
      XCTAssertEqual(cause.localizedDescription, "Response was not an HTTP response.")
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// Streams against the `streaming-failure-invalid-json` fixture and expects
  /// `GenerateContentError.internalError` wrapping `DecodingError.dataCorrupted`, whose
  /// debug description starts with "Failed to decode GenerateContentResponse".
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )

    let responseStream = try model.generateContentStream(testPrompt)
    do {
      for try await chunk in responseStream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
    } catch let GenerateContentError.internalError(decodingFailure as DecodingError) {
      guard case let .dataCorrupted(context) = decodingFailure else {
        return XCTFail("Not a data corrupted error: \(decodingFailure)")
      }
      XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// Streams against the `streaming-failure-malformed-content` fixture and expects
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.malformedContent`,
  /// whose own underlying error is a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )

    let responseStream = try model.generateContentStream(testPrompt)
    do {
      for try await chunk in responseStream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      guard case let .malformedContent(innerError) = candidateError else {
        return XCTFail("Not a malformed content error: \(candidateError)")
      }
      XCTAssert(innerError is DecodingError)
      return
    }

    XCTFail("Expected an internal decoding error.")
  }
  /// Rebuilds the model with a 150 second request timeout and checks, through the mock
  /// handler's timeout assertion, that a streaming request carries it. Exactly one response
  /// with text is expected.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    var responseCount = 0
    let responseStream = try model.generateContentStream(testPrompt)
    for try await chunk in responseStream {
      XCTAssertNotNil(chunk.text)
      responseCount += 1
    }

    XCTAssertEqual(responseCount, 1)
  }
  1190. // MARK: - Count Tokens
  /// Mocks the `unary-success-total-tokens` fixture and expects `countTokens` to report
  /// 6 total tokens and 16 billable characters.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Mocks the `unary-success-detailed-token-response` fixture and expects `countTokens` to
  /// report the totals plus two per-modality entries: image (1806 tokens), then text (31).
  func testCountTokens_succeeds_detailed() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-detailed-token-response",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 1837)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 117)

    let details = tokenCount.promptTokensDetails
    XCTAssertEqual(details.count, 2)
    XCTAssertEqual(details[0].modality, .image)
    XCTAssertEqual(details[0].tokenCount, 1806)
    XCTAssertEqual(details[1].modality, .text)
    XCTAssertEqual(details[1].tokenCount, 31)
  }
  /// Rebuilds the model with a generation config, a function-declaration tool and a system
  /// instruction, then expects `countTokens` against the `unary-success-total-tokens`
  /// fixture to report 6 total tokens and 16 billable characters.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )

    let config = GenerationConfig(
      temperature: 0.5,
      topP: 0.9,
      topK: 3,
      candidateCount: 1,
      maxOutputTokens: 1024,
      stopSequences: ["test-stop"],
      responseMIMEType: "text/plain"
    )
    let sumDeclaration = FunctionDeclaration(
      name: "sum",
      description: "Add two integers.",
      parameters: ["x": .integer(), "y": .integer()]
    )
    let instruction = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      generationConfig: config,
      tools: [Tool(functionDeclarations: [sumDeclaration])],
      systemInstruction: instruction,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Mocks the `unary-success-no-billable-characters` fixture and expects `countTokens` for
  /// an empty JPEG inline-data part to report 258 total tokens and nil billable characters.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json"
    )
    let imagePart = InlineDataPart(data: Data(), mimeType: "image/jpeg")

    let tokenCount = try await model.countTokens(imagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertNil(tokenCount.totalBillableCharacters)
  }
  /// Mocks the `unary-failure-model-not-found` fixture with a 404 status and expects
  /// `countTokens` to throw a `BackendError` with code 404, status `.notFound`, and a message
  /// starting with "models/test-model-name is not found".
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch let backendError as BackendError {
      XCTAssertEqual(backendError.httpResponseCode, 404)
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssert(backendError.message.hasPrefix("models/test-model-name is not found"))
      return
    }

    XCTFail("Expected internal RPCError.")
  }
  /// Rebuilds the model with a 150 second request timeout and checks, through the mock
  /// handler's timeout assertion, that a `countTokens` request carries it. The reply must
  /// report 6 total tokens.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1296. // MARK: - Helpers
  /// Builds a `FirebaseInfo` backed by a throwaway `FirebaseApp` named "testApp".
  ///
  /// - Parameters:
  ///   - appCheck: App Check provider passed through to `FirebaseInfo`; defaults to `nil`.
  ///   - auth: Auth provider passed through to `FirebaseInfo`; defaults to `nil`.
  ///   - privateAppID: When `true`, the app's default data collection is disabled.
  /// - Returns: A `FirebaseInfo` with fixed project ID, API key and Google app ID values.
  private func testFirebaseInfo(appCheck: AppCheckInterop? = nil,
                                auth: AuthInterop? = nil,
                                privateAppID: Bool = false) -> FirebaseInfo {
    let placeholderOptions = FirebaseOptions(googleAppID: "ignore", gcmSenderID: "ignore")
    let testApp = FirebaseApp(instanceWithName: "testApp", options: placeholderOptions)
    testApp.isDataCollectionDefaultEnabled = !privateAppID

    return FirebaseInfo(
      appCheck: appCheck,
      auth: auth,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      googleAppID: "My app ID",
      firebaseApp: testApp
    )
  }
  /// Builds a mock request handler that answers every request with a plain `URLResponse`
  /// (deliberately not an `HTTPURLResponse`) and no body lines.
  ///
  /// - Throws: `XCTSkip` on watchOS.
  /// - Returns: A closure suitable for assigning to `MockURLProtocol.requestHandler`.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    return { request in
      // This is *not* an HTTPURLResponse
      (
        URLResponse(
          url: request.url!,
          mimeType: nil,
          expectedContentLength: 0,
          textEncodingName: nil
        ),
        nil
      )
    }
  }
  /// Builds a mock request handler that validates the outgoing request and replies with the
  /// lines of a bundled fixture file.
  ///
  /// The returned closure asserts that the request path contains "models/" exactly once, that
  /// the timeout matches, that the `x-goog-api-client` header carries the language and
  /// Firebase version tags, and that the App Check and Authorization headers match the
  /// expected tokens.
  ///
  /// - Parameters:
  ///   - name: Resource name of the fixture in the test bundle.
  ///   - ext: File extension of the fixture.
  ///   - statusCode: HTTP status code of the mocked response; defaults to 200.
  ///   - timeout: Expected `timeoutInterval` of the request.
  ///   - appCheckToken: Expected `X-Firebase-AppCheck` header value, or `nil` for none.
  ///   - authToken: Expected token in the `Authorization` header, or `nil` for no header.
  ///   - dataCollection: Currently unused; the assertion that would consume it is commented
  ///     out pending the TODO in the body.
  /// - Throws: `XCTSkip` on watchOS, or an error if the fixture cannot be found in the bundle.
  /// - Returns: A closure suitable for assigning to `MockURLProtocol.requestHandler`.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = RequestOptions().timeout,
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil,
                                  dataCollection: Bool = true) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    let fixtureURL = try XCTUnwrap(
      BundleTestUtil.bundle().url(forResource: name, withExtension: ext)
    )

    return { request in
      let url = try XCTUnwrap(request.url)
      XCTAssertEqual(url.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let apiClientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let clientTags = apiClientHeader.components(separatedBy: " ")
      XCTAssert(clientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(clientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)

      // TODO: Wait for release approval
      // let googleAppID = request.value(forHTTPHeaderField: "X-Firebase-AppId")
      // XCTAssertEqual(googleAppID, dataCollection ? "My app ID" : nil)

      let authorizationHeader = request.value(forHTTPHeaderField: "Authorization")
      if let authToken {
        XCTAssertEqual(authorizationHeader, "Firebase \(authToken)")
      } else {
        XCTAssertNil(authorizationHeader)
      }

      let httpResponse = try XCTUnwrap(HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (httpResponse, fixtureURL.lines)
    }
  }
  1376. }
  private extension String {
    /// Counts how many times `substring` appears in the receiver.
    ///
    /// The count is one less than the number of pieces produced by splitting the receiver on
    /// `substring`.
    func occurrenceCount(of substring: String) -> Int {
      let pieces = components(separatedBy: substring)
      return pieces.count - 1
    }
  }
  /// Test double for `AppCheckInterop` that answers `getToken` with a fixed token and an
  /// optional error. The notification-related members are not implemented and trap if called.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// The placeholder token value returned when an error occurs
    static let placeholderTokenValue = "placeholder-token"

    /// Token handed back by `getToken(forcingRefresh:)`.
    var token: String
    /// Error handed back by `getToken(forcingRefresh:)`, if any.
    var error: Error?

    /// Shared designated initializer behind the two public convenience initializers.
    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
      super.init()
    }

    /// Creates a fake that reports `token` and no error.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake that reports `error` together with the placeholder token.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    /// Returns a result carrying the stored token and error; `forcingRefresh` is ignored.
    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      AppCheckTokenResultInteropFake(token: token, error: error)
    }

    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    /// Plain value holder conforming to `FIRAppCheckTokenResultInterop`.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
        super.init()
      }
    }
  }
  /// Empty error type used by tests to construct an `AppCheckInteropFake` in its error state.
  struct AppCheckErrorFake: Swift.Error {}
  /// Test-only ordering for `SafetyRating`: ratings compare by the raw string value of their
  /// harm category, and no other field takes part in the comparison.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    public static func < (lhs: FirebaseVertexAI.SafetyRating,
                          rhs: FirebaseVertexAI.SafetyRating) -> Bool {
      let leftCategory = lhs.category.rawValue
      let rightCategory = rhs.category.rawValue
      return leftCategory < rightCategory
    }
  }