GenerativeModelTests.swift 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt passed to `generateContent` by the tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Expected safety ratings for the `unary-success-basic-reply-short` fixture.
  ///
  /// Stored sorted so it can be compared with `candidate.safetyRatings.sorted()`, making the
  /// assertion independent of the order the ratings appear in the response.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(
      category: .sexuallyExplicit,
      probability: .negligible,
      probabilityScore: 0.1431877,
      severity: .negligible,
      severityScore: 0.11027937,
      blocked: false
    ),
    SafetyRating(
      category: .hateSpeech,
      probability: .negligible,
      probabilityScore: 0.029035643,
      severity: .negligible,
      severityScore: 0.05613278,
      blocked: false
    ),
    SafetyRating(
      category: .harassment,
      probability: .negligible,
      probabilityScore: 0.087252244,
      severity: .negligible,
      severityScore: 0.04509957,
      blocked: false
    ),
    SafetyRating(
      category: .dangerousContent,
      probability: .negligible,
      probabilityScore: 0.2641685,
      severity: .negligible,
      severityScore: 0.082253955,
      blocked: false
    ),
  ].sorted()

  /// Fully-qualified resource name of the default model under test.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"
  /// Backend configuration shared by every model built in this class.
  let apiConfig = APIConfig(service: .vertexAI, version: .v1beta)
  /// Subdirectory passed to `httpRequestHandler` when locating mock response fixtures.
  let vertexSubdirectory = "vertexai"

  /// Session whose protocol classes route traffic to `MockURLProtocol`; rebuilt in `setUp()`.
  var urlSession: URLSession!
  /// Model under test; rebuilt in `setUp()` and replaced by some individual tests.
  var model: GenerativeModel!
  /// Builds a fresh mock-backed `URLSession` and a default `GenerativeModel` before each test.
  ///
  /// The session's only protocol class is `MockURLProtocol`, so each test decides the HTTP
  /// response by assigning `MockURLProtocol.requestHandler`.
  override func setUp() async throws {
    try await super.setUp()
    let configuration = URLSessionConfiguration.default
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is non-failable, so there is nothing to unwrap.
    urlSession = URLSession(configuration: configuration)
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
  }
  /// Clears the mock request handler so no test observes a handler installed by a previous one.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    super.tearDown()
  }
  78. // MARK: - Generate Content
  /// Decodes a long single-candidate text reply and checks its finish reason, ratings and text.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let textPart = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(textPart.text.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    // The convenience accessor must mirror the single text part.
    XCTAssertEqual(response.text, textPart.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Decodes a short single-candidate text reply and checks its safety ratings and text accessors.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    // `safetyRatingsNegligible` is stored sorted, so sort the decoded ratings before comparing.
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertEqual(reply.text, "Mountain View, California")
    XCTAssertEqual(response.text, reply.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Verifies that per-modality token breakdowns are decoded from a response carrying full usage
  /// metadata (two prompt modalities, one candidate modality).
  func testGenerateContent_success_basicReplyFullUsageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-response-long-usage-metadata",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    let finishReason = try XCTUnwrap(candidate.finishReason)
    XCTAssertEqual(finishReason, .stop)
    let usageMetadata = try XCTUnwrap(response.usageMetadata)

    let promptTokensDetails = usageMetadata.promptTokensDetails
    let candidatesTokensDetails = usageMetadata.candidatesTokensDetails
    XCTAssertEqual(promptTokensDetails.count, 2)
    XCTAssertEqual(candidatesTokensDetails.count, 1)
    // `XCTAssertEqual` does not stop the test; bail out before indexing so a short array is
    // reported as a failure instead of crashing the whole test process with an out-of-range trap.
    guard promptTokensDetails.count == 2, candidatesTokensDetails.count == 1 else { return }

    XCTAssertEqual(promptTokensDetails[0].modality, .image)
    XCTAssertEqual(promptTokensDetails[0].tokenCount, 1806)
    XCTAssertEqual(promptTokensDetails[1].modality, .text)
    XCTAssertEqual(promptTokensDetails[1].tokenCount, 76)
    XCTAssertEqual(candidatesTokensDetails[0].modality, .text)
    XCTAssertEqual(candidatesTokensDetails[0].tokenCount, 76)
  }
  /// Verifies that three citation sources are decoded with their optional URI, title, license and
  /// publication-date fields populated (or `nil`) exactly as the fixture specifies.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2019,
      month: 5,
      day: 10
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citations = citationMetadata.citations
    XCTAssertEqual(citations.count, 3)
    // Subscripting past the end would trap and abort the test run; report a failure instead.
    guard citations.count == 3 else { return }

    let citationSource1 = citations[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)
    XCTAssertNil(citationSource1.publicationDate)

    let citationSource2 = citations[1]
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.publicationDate, expectedPublicationDate)
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citations[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
    XCTAssertNil(citationSource3.publicationDate)
  }
  /// Decodes a quoted text reply and checks both candidate and prompt-feedback safety data.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.count, 4)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let quote = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(quote.text.hasPrefix("Google"))
    XCTAssertEqual(response.text, quote.text)

    // The prompt itself was not blocked but still carries its own ratings.
    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.count, 4)
  }
  /// Unrecognized enum raw values in safety ratings must decode to raw-value-backed cases rather
  /// than failing the whole response.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    // Every expected rating shares zero scores, an unspecified severity and `blocked: false`.
    func makeRating(_ category: HarmCategory,
                    _ probability: SafetyRating.HarmProbability) -> SafetyRating {
      SafetyRating(
        category: category,
        probability: probability,
        probabilityScore: 0.0,
        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
        severityScore: 0.0,
        blocked: false
      )
    }
    let expectedSafetyRatings = [
      makeRating(.harassment, .medium),
      makeRating(.dangerousContent,
                 SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY")),
      makeRating(HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), .high),
    ]
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedSafetyRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedSafetyRatings)
  }
  /// A model name given with a `models/` prefix must still produce a request that succeeds.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let prefixedModel = GenerativeModel(
      name: "models/test-model",
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    // Success is simply "does not throw".
    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// Decodes the `unary-success-function-call-empty-arguments` fixture and expects one
  /// `current_time` call with an empty `args` dictionary.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Same downcast-and-unwrap idiom the text tests use for `TextPart`.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// Decodes the `unary-success-function-call-no-arguments` fixture and expects one
  /// `current_time` call whose `args` dictionary is empty.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Same downcast-and-unwrap idiom the text tests use for `TextPart`.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "current_time")
    XCTAssertTrue(functionCall.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// Decodes a `sum` function call and checks that both numeric arguments are preserved.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    let part = try XCTUnwrap(candidate.content.parts.first)
    // Same downcast-and-unwrap idiom the text tests use for `TextPart`.
    let functionCall = try XCTUnwrap(part as? FunctionCallPart, "Part is not a FunctionCall.")
    XCTAssertEqual(functionCall.name, "sum")
    XCTAssertEqual(functionCall.args.count, 2)
    let argX = try XCTUnwrap(functionCall.args["x"])
    XCTAssertEqual(argX, .number(4))
    let argY = try XCTUnwrap(functionCall.args["y"])
    XCTAssertEqual(argY, .number(5))
    XCTAssertEqual(response.functionCalls, [functionCall])
  }
  /// A single candidate carrying three function-call parts exposes all three via `functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// A candidate mixing text and function-call parts exposes both through the convenience
  /// accessors: two function calls, and the text parts via `text`.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(response.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// Uses an App Check fake that vends a fixed token, and hands the same token to
  /// `httpRequestHandler` (presumably so it can validate the outgoing request — confirm in helper).
  func testGenerateContent_appCheck_validToken() async throws {
    let token = "test-valid-token"
    let firebaseInfo = testFirebaseInfo(appCheck: AppCheckInteropFake(token: token))
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Same setup as the valid App Check test, but the Firebase info is built with
  /// `privateAppID: true` and the request handler is told `dataCollection: false`.
  func testGenerateContent_dataCollectionOff() async throws {
    let token = "test-valid-token"
    let firebaseInfo = testFirebaseInfo(
      appCheck: AppCheckInteropFake(token: token),
      privateAppID: true
    )
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: token,
      dataCollection: false
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the App Check fake fails, the request is still expected to succeed, with
  /// `AppCheckInteropFake.placeholderTokenValue` given to the handler as the App Check token.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Uses an Auth fake that vends a fixed token and hands the same token to `httpRequestHandler`.
  func testGenerateContent_auth_validAuthToken() async throws {
    let token = "test-valid-token"
    let firebaseInfo = testFirebaseInfo(auth: AuthInteropFake(token: token))
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: firebaseInfo,
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An Auth fake that vends `nil` should still yield a successful request, with the handler
  /// expecting `authToken: nil`.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let signedOutAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(auth: signedOutAuth),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the Auth fake fails to produce a token, `generateContent` must throw
  /// `GenerateContentError.internalError` wrapping the `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    model = GenerativeModel(
      // Same resource name as every other test, so that if the auth error ever fails to stop the
      // request, it reaches the mock handler exactly as the other tests' requests do.
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(error: AuthErrorFake())),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected: the auth failure is surfaced to the caller.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// A response with only aggregate token counts decodes them and leaves per-modality details empty.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.generateContent(testPrompt)

    let metadata = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(metadata.promptTokenCount, 6)
    XCTAssertEqual(metadata.candidatesTokenCount, 7)
    XCTAssertEqual(metadata.totalTokenCount, 13)
    XCTAssertTrue(metadata.promptTokensDetails.isEmpty)
    XCTAssertTrue(metadata.candidatesTokensDetails.isEmpty)
  }
  /// A 400 "invalid API key" response surfaces as a `BackendError` whose description, `NSError`
  /// domain and code all reflect the HTTP status and backend message.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
      XCTAssertTrue(error.localizedDescription.contains(error.message))
      XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
      XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
      // Bridged to NSError, the domain is namespaced by type and the code is the HTTP status.
      let nsError = error as NSError
      XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(nsError.code, error.httpResponseCode)
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// A 403 "API not enabled" response surfaces as a `BackendError` recognized by
  /// `isVertexAIInFirebaseServiceDisabledError()`.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as BackendError) {
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .permissionDenied)
      XCTAssertTrue(error.message
        .starts(with: "Vertex AI in Firebase API has not been used in project"))
      XCTAssertTrue(error.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail(
        "Should throw GenerateContentError.internalError(BackendError); error thrown: \(error)"
      )
    }
  }
  /// The `unary-failure-empty-content` fixture must surface as
  /// `GenerateContentError.internalError(InvalidCandidateError.emptyContent(DecodingError))`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: invalidCandidateError as InvalidCandidateError) {
      guard case let .emptyContent(underlyingError) = invalidCandidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(invalidCandidateError)")
        return
      }
      _ = try XCTUnwrap(
        underlyingError as? DecodingError,
        "Not a DecodingError: \(underlyingError)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A candidate stopped for safety surfaces as `responseStoppedEarly(.safety, response)`, and the
  /// attached response still exposes the fixture's text.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Include the unexpected error so the failure is diagnosable, like the other failure tests.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A candidate stopped for safety with no content surfaces as
  /// `responseStoppedEarly(.safety, response)` whose `text` is `nil`.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety-no-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Include the unexpected error so the failure is diagnosable, like the other failure tests.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A 400 response rejecting an image surfaces as an `invalidArgument` `BackendError`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-image-rejected",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      // Drive the mock with the same constant the assertion checks, so the two cannot drift.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: backendError as BackendError) {
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(backendError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A prompt blocked for safety without an explanatory message must surface as
  /// `GenerateContentError.promptBlocked`. The response has no text, and its prompt feedback
  /// reports `.safety` with a `nil` block-reason message.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertNil(promptFeedback.blockReasonMessage)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// A prompt blocked for safety with an explanatory message must surface as
  /// `GenerateContentError.promptBlocked`. The feedback carries both the `.safety` reason and
  /// the message ("Reasons") from the fixture.
  func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety-with-message",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
      XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// An unrecognized finish reason in the response must still produce `responseStoppedEarly`.
  /// The reported reason must equal `FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")`, and
  /// the partial text must be kept.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// An unrecognized block reason must still produce `promptBlocked`, with the feedback's
  /// reason equal to `PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json",
        subdirectory: vertexSubdirectory
      )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
    } catch {
      // Include the unexpected error so the failure is diagnosable from the test log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Requesting an unknown model must surface as `GenerateContentError.internalError`
  /// wrapping a `.notFound` `BackendError`. The error carries the mocked HTTP status code and a
  /// message starting with "models/unknown is not found".
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        subdirectory: vertexSubdirectory,
        // Feed the mock the same constant the assertions check so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// When the transport hands back a plain (non-HTTP) `URLResponse`, `generateContent` must
  /// throw `GenerateContentError.internalError`. The underlying error must bridge to
  /// `NSURLErrorDomain` / `badServerResponse`.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var generated: GenerateContentResponse?
    var thrownError: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }

    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
    let bridged = underlying as NSError
    XCTAssertEqual(bridged.domain, NSURLErrorDomain)
    XCTAssertEqual(bridged.code, URLError.Code.badServerResponse.rawValue)
  }
  /// An undecodable response body must surface as `GenerateContentError.internalError`
  /// wrapping `DecodingError.dataCorrupted`. Its debug description must start with
  /// "Failed to decode GenerateContentResponse".
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    var thrownError: Error?
    var generated: GenerateContentResponse?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }

    let decodingFailure = try XCTUnwrap(underlying as? DecodingError)
    guard case let .dataCorrupted(corruption) = decodingFailure else {
      XCTFail("Not a data corrupted error: \(decodingFailure)")
      return
    }
    XCTAssert(corruption.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
  }
  /// A candidate with malformed content must surface as `GenerateContentError.internalError`
  /// wrapping `InvalidCandidateError.malformedContent`, whose own underlying error is a
  /// `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    var thrownError: Error?
    var generated: GenerateContentResponse?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }

    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case let .malformedContent(contentCause) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(contentCause as? DecodingError, "Not a decoding error: \(contentCause)")
  }
  /// A successful response that omits safety ratings must decode with an empty
  /// `promptFeedback.safetyRatings` list and the generated text intact.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let generated = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(generated.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(generated.text, "This is the generated content.")
  }
  /// A model configured with a custom `RequestOptions.timeout` must send requests with that
  /// timeout (checked by the mock request handler) and still return the single candidate.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
  }
  821. // MARK: - Generate Content (Streaming)
  /// An invalid API key must surface from the stream as `GenerateContentError.internalError`
  /// wrapping a `BackendError`. Its localized description must mention the message, status,
  /// and HTTP code, and its `NSError` bridge must use the `BackendError` domain and the HTTP
  /// code.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    do {
      let contentStream = try model.generateContentStream("Hi")
      for try await _ in contentStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 400)
      XCTAssertEqual(backendError.status, .invalidArgument)
      XCTAssertEqual(backendError.message, "API key not valid. Please pass a valid API key.")

      let description = backendError.localizedDescription
      XCTAssertTrue(description.contains(backendError.message))
      XCTAssertTrue(description.contains(backendError.status.rawValue))
      XCTAssertTrue(description.contains("\(backendError.httpResponseCode)"))

      let bridged = backendError as NSError
      XCTAssertEqual(bridged.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
      XCTAssertEqual(bridged.code, backendError.httpResponseCode)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A 403 "API not enabled" response must surface from the stream as a `.permissionDenied`
  /// `BackendError`, recognized by `isVertexAIInFirebaseServiceDisabledError()`.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let forbiddenStatus = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: forbiddenStatus
    )

    do {
      let contentStream = try model.generateContentStream(testPrompt)
      for try await _ in contentStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, forbiddenStatus)
      XCTAssertEqual(backendError.status, .permissionDenied)
      let expectedPrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(backendError.message.starts(with: expectedPrefix))
      XCTAssertTrue(backendError.isVertexAIInFirebaseServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A streamed candidate with empty content must fail the stream with
  /// `GenerateContentError.internalError` wrapping an `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let contentStream = try model.generateContentStream("Hi")
      for try await _ in contentStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(underlying) where underlying is InvalidCandidateError {
      // The wrapped error type is the only thing this test checks.
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A streamed SAFETY stop must throw `responseStoppedEarly(.safety, …)` before any content
  /// is yielded. The first candidate must echo that finish reason and carry at least one
  /// blocked safety rating.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let contentStream = try model.generateContentStream("Hi")
      for try await _ in contentStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.responseStoppedEarly(stopReason, stoppedResponse) {
      XCTAssertEqual(stopReason, .safety)
      let firstCandidate = try XCTUnwrap(stoppedResponse.candidates.first)
      XCTAssertEqual(firstCandidate.finishReason, stopReason)
      XCTAssertTrue(firstCandidate.safetyRatings.contains(where: \.blocked))
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A streamed prompt block for safety (no message) must throw `promptBlocked`. The
  /// feedback must report `.safety` and a `nil` message.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let contentStream = try model.generateContentStream("Hi")
      for try await _ in contentStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertNil(feedback.blockReasonMessage)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A streamed prompt block for safety with a message must throw `promptBlocked`. The
  /// feedback must report `.safety` and the message "Reasons".
  func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety-with-message",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    do {
      let contentStream = try model.generateContentStream("Hi")
      for try await _ in contentStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      let feedback = try XCTUnwrap(blockedResponse.promptFeedback)
      XCTAssertEqual(feedback.blockReason, .safety)
      XCTAssertEqual(feedback.blockReasonMessage, "Reasons")
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Chunks before the stop must carry text. An unrecognized finish reason must then end the
  /// stream with `responseStoppedEarly`, whose reason equals
  /// `FinishReason(rawValue: "FAKE_ENUM")`.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedReason = FinishReason(rawValue: "FAKE_ENUM")
    let contentStream = try model.generateContentStream("Hi")

    do {
      for try await chunk in contentStream {
        XCTAssertNotNil(chunk.text)
      }
    } catch let GenerateContentError.responseStoppedEarly(stopReason, _) {
      XCTAssertEqual(stopReason, expectedReason)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// The long-reply fixture must stream exactly six chunks, each carrying text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let contentStream = try model.generateContentStream("Hi")
    var receivedCount = 0
    for try await chunk in contentStream {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 6)
  }
  /// The short-reply fixture must stream exactly one chunk, carrying text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    let contentStream = try model.generateContentStream("Hi")
    var receivedCount = 0
    for try await chunk in contentStream {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 1)
  }
  /// Unrecognized harm-category / probability raw values must not break streaming. At least
  /// one chunk's first candidate must carry a rating equal to one built from those raw values.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
      probabilityScore: 0.0,
      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
      severityScore: 0.0,
      blocked: false
    )

    var sawExpectedRating = false
    let contentStream = try model.generateContentStream("Hi")
    for try await chunk in contentStream {
      XCTAssertNotNil(chunk.text)
      let ratings = chunk.candidates.first?.safetyRatings ?? []
      if ratings.contains(where: { $0 == expectedRating }) {
        sawExpectedRating = true
      }
    }

    XCTAssertTrue(sawExpectedRating)
  }
  /// Citations spread across streamed chunks must be collected in full (six in total), with
  /// specific index ranges, URIs, titles, licenses and publication dates. No present string
  /// field may be empty, and the final chunk must finish with `.stop`.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let expectedPublicationDate = DateComponents(
      calendar: Calendar(identifier: .gregorian),
      year: 2014,
      month: 3,
      day: 30
    )

    let contentStream = try model.generateContentStream("Hi")
    var streamed = [GenerateContentResponse]()
    var allCitations = [Citation]()
    for try await chunk in contentStream {
      streamed.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      allCitations.append(contentsOf: firstCandidate.citationMetadata?.citations ?? [])
    }

    let finalCandidate = try XCTUnwrap(streamed.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(allCitations.count, 6)

    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil && citation.publicationDate == nil
    })
    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
        && citation.publicationDate == expectedPublicationDate
    })
    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit" && citation.publicationDate == nil
    })

    // Present-but-empty strings would indicate a decoding regression.
    XCTAssertFalse(allCitations.contains { $0.uri?.isEmpty == true })
    XCTAssertFalse(allCitations.contains { $0.title?.isEmpty == true })
    XCTAssertFalse(allCitations.contains { $0.license?.isEmpty == true })
  }
  /// With an App Check provider that returns a valid token, streamed requests must carry it.
  /// The `X-Firebase-AppCheck` header is checked inside the mock request handler.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let token = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: token)
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: appCheckFake),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: token
    )

    // Drain the stream so the request (and the handler's header assertions) actually run.
    let contentStream = try model.generateContentStream(testPrompt)
    for try await _ in contentStream {}
  }
  /// When the App Check provider fails, the stream must still succeed. The mock handler
  /// expects `AppCheckInteropFake.placeholderTokenValue` in the App Check header.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(appCheck: failingAppCheck),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream so the request (and the handler's header assertions) actually run.
    let contentStream = try model.generateContentStream(testPrompt)
    for try await _ in contentStream {}
  }
  /// Only the final streamed response may carry usage metadata. Every earlier chunk must have
  /// `usageMetadata == nil`, and the last must report 6 prompt, 4 candidate and 10 total
  /// tokens.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Require a last response: with an empty stream every check below would be skipped and
    // the test would pass without verifying anything.
    let lastResponse = try XCTUnwrap(responses.last, "Stream produced no responses.")

    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }

    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// An error frame arriving after two good chunks must end the stream with
  /// `GenerateContentError.internalError` wrapping a `.cancelled` `BackendError` (code 499).
  /// Exactly two chunks must have been delivered first.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )

    var receivedCount = 0
    do {
      let contentStream = try model.generateContentStream("Hi")
      for try await chunk in contentStream {
        XCTAssertNotNil(chunk.text)
        receivedCount += 1
      }
    } catch let GenerateContentError.internalError(backendError as BackendError) {
      XCTAssertEqual(backendError.httpResponseCode, 499)
      XCTAssertEqual(backendError.status, .cancelled)
      // Chunks delivered before the error must all have been yielded.
      XCTAssertEqual(receivedCount, 2)
      return
    }

    XCTFail("Expected an internalError with an RPCError.")
  }
  /// A plain (non-HTTP) `URLResponse` must fail the stream with `internalError`, whose
  /// underlying description is "Response was not an HTTP response.".
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
    let contentStream = try model.generateContentStream("Hi")

    do {
      for try await chunk in contentStream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
    } catch let GenerateContentError.internalError(cause) {
      XCTAssertEqual(cause.localizedDescription, "Response was not an HTTP response.")
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// Invalid JSON in the stream must fail with `internalError` wrapping
  /// `DecodingError.dataCorrupted`, described as "Failed to decode GenerateContentResponse…".
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let contentStream = try model.generateContentStream(testPrompt)

    do {
      for try await chunk in contentStream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
    } catch let GenerateContentError.internalError(decodingFailure as DecodingError) {
      if case let .dataCorrupted(corruption) = decodingFailure {
        XCTAssert(corruption.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingFailure)")
      }
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// A malformed candidate in the stream must fail with `internalError` wrapping
  /// `InvalidCandidateError.malformedContent`, whose cause is a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt",
      subdirectory: vertexSubdirectory
    )
    let contentStream = try model.generateContentStream(testPrompt)

    do {
      for try await chunk in contentStream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      if case let .malformedContent(contentCause) = candidateError {
        XCTAssert(contentCause is DecodingError)
      } else {
        XCTFail("Not a malformed content error: \(candidateError)")
      }
      return
    }

    XCTFail("Expected an internal decoding error.")
  }
  /// A model configured with a custom timeout must use it for streaming requests. The mock
  /// handler verifies the timeout, and exactly one text chunk must arrive.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      subdirectory: vertexSubdirectory,
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      urlSession: urlSession
    )

    let contentStream = try model.generateContentStream(testPrompt)
    var receivedCount = 0
    for try await chunk in contentStream {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 1)
  }
  1243. // MARK: - Count Tokens
  /// A basic count-tokens call must decode the total token count (6) and the total billable
  /// characters (16).
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// A count-tokens response with per-modality details must decode the totals and every
  /// `promptTokensDetails` entry, in order.
  func testCountTokens_succeeds_detailed() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-detailed-token-response",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let response = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(response.totalTokens, 1837)
    XCTAssertEqual(response.totalBillableCharacters, 117)
    let details = response.promptTokensDetails
    XCTAssertEqual(details.count, 2)
    // Compare whole projections instead of subscripting `[0]` / `[1]`. If decoding produced
    // fewer entries, subscripting would trap and abort the entire test run instead of
    // failing just this test.
    XCTAssertEqual(details.map(\.modality), [.image, .text])
    XCTAssertEqual(details.map(\.tokenCount), [1806, 31])
  }
  /// Count-tokens must still succeed when the model carries every optional configuration:
  /// generation config, a function-calling tool, and a system instruction.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )

    let config = GenerationConfig(
      temperature: 0.5,
      topP: 0.9,
      topK: 3,
      candidateCount: 1,
      maxOutputTokens: 1024,
      stopSequences: ["test-stop"],
      responseMIMEType: "text/plain"
    )
    let sumDeclaration = FunctionDeclaration(
      name: "sum",
      description: "Add two integers.",
      parameters: ["x": .integer(), "y": .integer()]
    )
    let instruction = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )
    let calculatorTools = [Tool(functionDeclarations: [sumDeclaration])]

    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      generationConfig: config,
      tools: calculatorTools,
      systemInstruction: instruction,
      requestOptions: RequestOptions(),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Counting tokens for inline image data must report 258 tokens and omit
  /// `totalBillableCharacters`.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json",
      subdirectory: vertexSubdirectory
    )
    let imagePart = InlineDataPart(data: Data(), mimeType: "image/jpeg")

    let tokenCount = try await model.countTokens(imagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertNil(tokenCount.totalBillableCharacters)
  }
  /// Count-tokens against a missing model must throw a `BackendError` directly (not wrapped).
  /// The error must carry HTTP 404, `.notFound`, and a message starting with
  /// "models/test-model-name is not found".
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch let backendError as BackendError {
      XCTAssertEqual(backendError.httpResponseCode, 404)
      XCTAssertEqual(backendError.status, .notFound)
      XCTAssert(backendError.message.hasPrefix("models/test-model-name is not found"))
      return
    }

    XCTFail("Expected internal RPCError.")
  }
  /// A model configured with a custom timeout must use it for count-tokens requests (verified
  /// by the mock handler) and still decode the total token count.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      subdirectory: vertexSubdirectory,
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      firebaseInfo: testFirebaseInfo(),
      apiConfig: apiConfig,
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1355. // MARK: - Helpers
  /// Builds the `FirebaseInfo` used by these tests, backed by a `FirebaseApp` named "testApp".
  ///
  /// - Parameters:
  ///   - appCheck: Optional App Check provider to attach.
  ///   - auth: Optional Auth provider to attach.
  ///   - privateAppID: When `true`, the app's `isDataCollectionDefaultEnabled` is set to
  ///     `false`; otherwise it is set to `true`.
  /// - Returns: A `FirebaseInfo` with fixed placeholder project ID, API key and app ID.
  private func testFirebaseInfo(
    appCheck: AppCheckInterop? = nil,
    auth: AuthInterop? = nil,
    privateAppID: Bool = false
  ) -> FirebaseInfo {
    let options = FirebaseOptions(googleAppID: "ignore", gcmSenderID: "ignore")
    let firebaseApp = FirebaseApp(instanceWithName: "testApp", options: options)
    firebaseApp.isDataCollectionDefaultEnabled = !privateAppID

    return FirebaseInfo(
      appCheck: appCheck,
      auth: auth,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      firebaseAppID: "My app ID",
      firebaseApp: firebaseApp
    )
  }
  /// Returns a `MockURLProtocol` handler that answers every request with a plain `URLResponse`
  /// (not an `HTTPURLResponse`) and no body lines.
  ///
  /// - Throws: `XCTSkip` on watchOS, where custom URL protocols are unsupported.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Custom URL protocols (and therefore MockURLProtocol) are unsupported in watchOS 2 and
    // later; see https://developer.apple.com/documentation/foundation/urlprotocol
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    return { request in
      // Deliberately a base `URLResponse`, so the client's HTTP-response check fails.
      let plainResponse = URLResponse(
        url: request.url!,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (plainResponse, nil)
    }
  }
  1392. private func httpRequestHandler(forResource name: String,
  1393. withExtension ext: String,
  1394. subdirectory subpath: String,
  1395. statusCode: Int = 200,
  1396. timeout: TimeInterval = RequestOptions().timeout,
  1397. appCheckToken: String? = nil,
  1398. authToken: String? = nil,
  1399. dataCollection: Bool = true) throws -> ((URLRequest) throws -> (
  1400. URLResponse,
  1401. AsyncLineSequence<URL.AsyncBytes>?
  1402. )) {
  1403. // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
  1404. // https://developer.apple.com/documentation/foundation/urlprotocol for details.
  1405. #if os(watchOS)
  1406. throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
  1407. #endif // os(watchOS)
  1408. let bundle = BundleTestUtil.bundle()
  1409. let fileURL = try XCTUnwrap(
  1410. bundle.url(forResource: name, withExtension: ext, subdirectory: subpath)
  1411. )
  1412. return { request in
  1413. let requestURL = try XCTUnwrap(request.url)
  1414. XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
  1415. XCTAssertEqual(request.timeoutInterval, timeout)
  1416. let apiClientTags = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
  1417. .components(separatedBy: " ")
  1418. XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag))
  1419. XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag))
  1420. XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
  1421. let firebaseAppID = request.value(forHTTPHeaderField: "X-Firebase-AppId")
  1422. let appVersion = request.value(forHTTPHeaderField: "X-Firebase-AppVersion")
  1423. let expectedAppVersion =
  1424. try? XCTUnwrap(Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String)
  1425. XCTAssertEqual(firebaseAppID, dataCollection ? "My app ID" : nil)
  1426. XCTAssertEqual(appVersion, dataCollection ? expectedAppVersion : nil)
  1427. if let authToken {
  1428. XCTAssertEqual(request.value(forHTTPHeaderField: "Authorization"), "Firebase \(authToken)")
  1429. } else {
  1430. XCTAssertNil(request.value(forHTTPHeaderField: "Authorization"))
  1431. }
  1432. let response = try XCTUnwrap(HTTPURLResponse(
  1433. url: requestURL,
  1434. statusCode: statusCode,
  1435. httpVersion: nil,
  1436. headerFields: nil
  1437. ))
  1438. return (response, fileURL.lines)
  1439. }
  1440. }
  1441. }
  1442. private extension String {
  1443. /// Returns the number of occurrences of `substring` in the `String`.
  1444. func occurrenceCount(of substring: String) -> Int {
  1445. return components(separatedBy: substring).count - 1
  1446. }
  1447. }
  1448. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  1449. class AppCheckInteropFake: NSObject, AppCheckInterop {
  1450. /// The placeholder token value returned when an error occurs
  1451. static let placeholderTokenValue = "placeholder-token"
  1452. var token: String
  1453. var error: Error?
  1454. private init(token: String, error: Error?) {
  1455. self.token = token
  1456. self.error = error
  1457. }
  1458. convenience init(token: String) {
  1459. self.init(token: token, error: nil)
  1460. }
  1461. convenience init(error: Error) {
  1462. self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
  1463. }
  1464. func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
  1465. return AppCheckTokenResultInteropFake(token: token, error: error)
  1466. }
  1467. func tokenDidChangeNotificationName() -> String {
  1468. fatalError("\(#function) not implemented.")
  1469. }
  1470. func notificationTokenKey() -> String {
  1471. fatalError("\(#function) not implemented.")
  1472. }
  1473. func notificationAppNameKey() -> String {
  1474. fatalError("\(#function) not implemented.")
  1475. }
  1476. private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
  1477. var token: String
  1478. var error: Error?
  1479. init(token: String, error: Error?) {
  1480. self.token = token
  1481. self.error = error
  1482. }
  1483. }
  1484. }
  1485. struct AppCheckErrorFake: Error {}
  1486. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  1487. extension SafetyRating: Swift.Comparable {
  1488. public static func < (lhs: FirebaseVertexAI.SafetyRating,
  1489. rhs: FirebaseVertexAI.SafetyRating) -> Bool {
  1490. return lhs.category.rawValue < rhs.category.rawValue
  1491. }
  1492. }