GenerativeModelTests.swift 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt text passed to every `generateContent` call in this test case.
  let testPrompt = "What sorts of questions can I ask you?"

  /// One `.negligible` rating for each of the four harm categories, pre-sorted so it can be
  /// compared directly against `safetyRatings.sorted()` from a decoded response.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(category: .sexuallyExplicit, probability: .negligible),
    SafetyRating(category: .hateSpeech, probability: .negligible),
    SafetyRating(category: .harassment, probability: .negligible),
    SafetyRating(category: .dangerousContent, probability: .negligible),
  ].sorted()

  /// Fully-qualified resource name used for the default `model` built in `setUp()`.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  /// Session whose requests are served by `MockURLProtocol`; assigned in `setUp()`.
  var urlSession: URLSession!
  /// Model under test; assigned in `setUp()` and replaced by some individual tests.
  var model: GenerativeModel!
  /// Routes every request through `MockURLProtocol` and builds the default model under test.
  override func setUp() async throws {
    let mockedConfiguration = URLSessionConfiguration.default
    mockedConfiguration.protocolClasses = [MockURLProtocol.self]
    let mockedSession = try XCTUnwrap(URLSession(configuration: mockedConfiguration))
    urlSession = mockedSession

    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: mockedSession
    )
  }
  /// Clears the stubbed request handler so one test's mock cannot leak into the next test.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    // XCTestCase's default tearDown does nothing; calling it keeps the override conventional.
    super.tearDown()
  }
  50. // MARK: - Generate Content
  /// A long, single-candidate text reply decodes with a `.stop` finish reason, negligible safety
  /// ratings, one text part mirrored by `response.text`, and no function calls.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let stopReason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let replyText = try XCTUnwrap(firstPart as? TextPart).text
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(response.text, replyText)
    // A plain text reply must not surface any function calls.
    XCTAssertEqual(response.functionCalls, [])
  }
  /// A short, single-candidate text reply decodes to exactly "Mountain View, California".
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let stopReason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let reply = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertEqual(reply.text, "Mountain View, California")
    XCTAssertEqual(response.text, reply.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Verifies that citation metadata (URI, title, license, and start/end indices) is decoded for
  /// each of the three cited sources in the response.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-citations",
        withExtension: "json"
      )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citations = citationMetadata.citations
    // Subscripting past the end traps and crashes the whole test run rather than failing this
    // one test, so stop here (recording a failure) when the count is not what we index below.
    guard citations.count == 3 else {
      XCTFail("Expected 3 citations, got \(citations.count)")
      return
    }

    let citationSource1 = citations[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)

    let citationSource2 = citations[1]
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citations[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
  }
  /// A quoted reply decodes normally, and its prompt feedback carries no block reason and
  /// negligible safety ratings.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let stopReason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let quote = try XCTUnwrap(firstPart as? TextPart)
    XCTAssertTrue(quote.text.hasPrefix("Google"))
    XCTAssertEqual(response.text, quote.text)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.sorted(), safetyRatingsNegligible)
  }
  /// Unrecognized harm categories and probabilities in both the candidate and the prompt feedback
  /// decode into values equal to ones built from their raw strings, without failing the response.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    let unknownProbability = SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY")
    let unknownCategory = HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY")
    let expectedSafetyRatings = [
      SafetyRating(category: .harassment, probability: .medium),
      SafetyRating(category: .dangerousContent, probability: unknownProbability),
      SafetyRating(category: unknownCategory, probability: .high),
    ]
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedSafetyRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedSafetyRatings)
  }
  /// A model whose name already carries the "models/" prefix can still complete a request.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    let prefixedModel = GenerativeModel(
      // Model name is prefixed with "models/".
      name: "models/test-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    // Passing means the call does not throw; response contents are covered by other tests.
    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// A function call whose arguments object is empty decodes with an empty `args` dictionary and
  /// is the only entry in `response.functionCalls`.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let candidateParts = onlyCandidate.content.parts
    XCTAssertEqual(candidateParts.count, 1)
    let firstPart = try XCTUnwrap(candidateParts.first)
    guard case let call as FunctionCallPart = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// A function call with no arguments field decodes with an empty `args` dictionary and is the
  /// only entry in `response.functionCalls`.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let candidateParts = onlyCandidate.content.parts
    XCTAssertEqual(candidateParts.count, 1)
    let firstPart = try XCTUnwrap(candidateParts.first)
    guard case let call as FunctionCallPart = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// A `sum` function call decodes both numeric arguments, `x` = 4 and `y` = 5.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let candidateParts = onlyCandidate.content.parts
    XCTAssertEqual(candidateParts.count, 1)
    let firstPart = try XCTUnwrap(candidateParts.first)
    guard case let call as FunctionCallPart = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let xValue = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(xValue, .number(4))
    let yValue = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(yValue, .number(5))
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// Three parallel function calls in one candidate are all exposed via `response.functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// A candidate mixing text and function-call parts (four parts total) exposes the two function
  /// calls via `functionCalls` and the text parts via `text`.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(response.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// With an App Check fake that returns a valid token, the request succeeds against a stub
  /// handler configured to expect that same token.
  func testGenerateContent_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: appCheckFake,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: appCheckToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the App Check fake fails to produce a token, the request still succeeds against a stub
  /// handler that expects `AppCheckInteropFake.placeholderTokenValue` instead.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With an Auth fake that returns a token, the request succeeds against a stub handler
  /// configured to expect that same token.
  func testGenerateContent_auth_validAuthToken() async throws {
    let authToken = "test-valid-token"
    let authFake = AuthInteropFake(token: authToken)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: authFake,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: authToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With an Auth fake that yields no token, the request succeeds against a stub handler
  /// configured with `authToken: nil`.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let tokenlessAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: tokenlessAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An error from the Auth token provider reaches the caller as
  /// `GenerateContentError.internalError` wrapping that `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: failingAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected: the Auth failure is surfaced as the underlying error.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Token usage counts (prompt 6, candidates 7, total 13) are decoded from the response.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
  }
  /// A 400 response for an invalid API key surfaces as `GenerateContentError.internalError`
  /// wrapping an `RPCError` with status `.invalidArgument` and the server's message.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as RPCError) {
      // Compare against the same constant the stub was configured with so the two cannot drift.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  // TODO(andrewheard): Remove this test case after the Vertex AI in Firebase API launch.
  /// A 403 "Firebase ML API has not been used" response surfaces as an `RPCError` with status
  /// `.permissionDenied` that `isFirebaseMLServiceDisabledError()` recognizes.
  func testGenerateContent_failure_firebaseMLAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A 403 "Vertex AI in Firebase API has not been used" response surfaces as an `RPCError` with
  /// status `.permissionDenied` that `isVertexAIInFirebaseServiceDisabledError()` recognizes.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      let expectedPrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(rpcError.message.starts(with: expectedPrefix))
      XCTAssertTrue(rpcError.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A candidate with empty content is rejected as `InvalidCandidateError.emptyContent` whose
  /// cause is a `DecodingError`, wrapped in `GenerateContentError.internalError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      guard case let .emptyContent(emptyContentCause) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(
        emptyContentCause as? DecodingError,
        "Not a DecodingError: \(emptyContentCause)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A response that stops with finish reason `.safety` throws `responseStoppedEarly`, which still
  /// carries the partial response (text "<redacted>" in this fixture).
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Report the unexpected error, as the other failure tests in this file do.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A `.safety` stop with no content throws `responseStoppedEarly` whose response has no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error, as the other failure tests in this file do.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A 400 response rejecting an image surfaces as `GenerateContentError.internalError` wrapping
  /// an `RPCError` with status `.invalidArgument`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        // Use the same constant the assertion below checks so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A prompt blocked for safety throws `GenerateContentError.promptBlocked` whose response has no
  /// text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error, as the other failure tests in this file do.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// An unrecognized finish reason is kept by its raw value and reported via
  /// `responseStoppedEarly`, with the partial response text still available.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json"
      )
    let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, unknownFinishReason)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report the unexpected error, as the other failure tests in this file do.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// An unrecognized block reason is kept by its raw value in the prompt feedback carried by
  /// `GenerateContentError.promptBlocked`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json"
      )
    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
    } catch {
      // Report the unexpected error, as the other failure tests in this file do.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// A 404 response for an unknown model surfaces as `GenerateContentError.internalError`
  /// wrapping an `RPCError` with status `.notFound`.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        // Use the same constant the assertion below checks so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A non-HTTP URL response is reported as `GenerateContentError.internalError` whose underlying
  /// error describes itself as "Response was not an HTTP response.".
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var caughtError: Error?
    var generatedResponse: GenerateContentResponse?
    do {
      generatedResponse = try await model.generateContent(testPrompt)
    } catch {
      caughtError = error
    }

    XCTAssertNil(generatedResponse)
    XCTAssertNotNil(caughtError)
    let wrappedError = try XCTUnwrap(caughtError as? GenerateContentError)
    guard case let .internalError(cause) = wrappedError else {
      XCTFail("Not an internal error: \(wrappedError)")
      return
    }
    XCTAssertEqual(cause.localizedDescription, "Response was not an HTTP response.")
  }
  /// An undecodable response body is reported as `GenerateContentError.internalError` wrapping a
  /// `DecodingError.dataCorrupted` whose description starts with
  /// "Failed to decode GenerateContentResponse".
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var caughtError: Error?
    var decodedResponse: GenerateContentResponse?
    do {
      decodedResponse = try await model.generateContent(testPrompt)
    } catch {
      caughtError = error
    }

    XCTAssertNil(decodedResponse)
    XCTAssertNotNil(caughtError)
    let wrappedError = try XCTUnwrap(caughtError as? GenerateContentError)
    guard case let .internalError(cause) = wrappedError else {
      XCTFail("Not an internal error: \(wrappedError)")
      return
    }
    let decodeFailure = try XCTUnwrap(cause as? DecodingError)
    guard case let .dataCorrupted(corruptionContext) = decodeFailure else {
      XCTFail("Not a data corrupted error: \(decodeFailure)")
      return
    }
    XCTAssertTrue(
      corruptionContext.debugDescription.hasPrefix("Failed to decode GenerateContentResponse")
    )
  }
  /// Verifies that malformed candidate content in a unary response is reported as
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.malformedContent`,
  /// whose own underlying error is a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var caughtError: Error?
    var unexpectedResponse: GenerateContentResponse?
    do {
      unexpectedResponse = try await model.generateContent(testPrompt)
    } catch {
      caughtError = error
    }

    XCTAssertNil(unexpectedResponse)
    XCTAssertNotNil(caughtError)
    let contentError = try XCTUnwrap(caughtError as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case let .malformedContent(decodeFailure) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(
      decodeFailure as? DecodingError,
      "Not a decoding error: \(decodeFailure)"
    )
  }
  /// Verifies that the `unary-success-missing-safety-ratings` fixture decodes with prompt
  /// feedback containing zero safety ratings while still exposing the generated text.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(response.text, "This is the generated content.")
  }
  /// Verifies that a custom `RequestOptions.timeout` reaches the unary request; the mock handler
  /// asserts `URLRequest.timeoutInterval` against the expected value.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
  }
  695. // MARK: - Generate Content (Streaming)
  /// Verifies that an invalid API key surfaces from the stream as
  /// `GenerateContentError.internalError` wrapping an HTTP 400 `invalidArgument` `RPCError`.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )

    do {
      let responses = try model.generateContentStream("Hi")
      for try await _ in responses {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 400)
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.message, "API key not valid. Please pass a valid API key.")
      return
    }

    XCTFail("Should have caught an error.")
  }
  715. // TODO(andrewheard): Remove this test case after the Vertex AI in Firebase API launch.
  /// Verifies that a 403 "Firebase ML API has not been used" response surfaces from the stream as
  /// a `permissionDenied` `RPCError` recognised by `isFirebaseMLServiceDisabledError()`.
  func testGenerateContentStream_failure_firebaseMLAPINotEnabled() async throws {
    let forbiddenStatus = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: forbiddenStatus
    )

    do {
      let responses = try model.generateContentStream(testPrompt)
      for try await _ in responses {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, forbiddenStatus)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a 403 "Vertex AI in Firebase API has not been used" response surfaces from the
  /// stream as a `permissionDenied` `RPCError` recognised by
  /// `isVertexAIInFirebaseServiceDisabledError()`.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let forbiddenStatus = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: forbiddenStatus
    )

    do {
      let responses = try model.generateContentStream(testPrompt)
      for try await _ in responses {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, forbiddenStatus)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      let expectedPrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(rpcError.message.starts(with: expectedPrefix))
      XCTAssertTrue(rpcError.isVertexAIInFirebaseServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a streamed candidate with empty content raises
  /// `GenerateContentError.internalError` wrapping an `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )

    do {
      let responses = try model.generateContentStream("Hi")
      for try await _ in responses {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // Matching the underlying error's type is the whole check.
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a stream stopped for safety reasons throws
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety` before yielding content.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )

    do {
      let responses = try model.generateContentStream("Hi")
      for try await _ in responses {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.responseStoppedEarly(finishReason, _) {
      XCTAssertEqual(finishReason, .safety)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a prompt blocked for safety throws `GenerateContentError.promptBlocked` whose
  /// response carries `promptFeedback.blockReason == .safety`.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )

    do {
      let responses = try model.generateContentStream("Hi")
      for try await _ in responses {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      XCTAssertEqual(blockedResponse.promptFeedback?.blockReason, .safety)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that an unrecognised finish reason ("FAKE_ENUM") is preserved and reported through
  /// `GenerateContentError.responseStoppedEarly` after any text-bearing responses.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )
    let expectedReason = FinishReason(rawValue: "FAKE_ENUM")
    let responses = try model.generateContentStream("Hi")

    do {
      for try await response in responses {
        XCTAssertNotNil(response.text)
      }
    } catch let GenerateContentError.responseStoppedEarly(finishReason, _) {
      XCTAssertEqual(finishReason, expectedReason)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that the long-reply fixture streams six responses, each with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )
    let stream = try model.generateContentStream("Hi")

    var responseCount = 0
    for try await response in stream {
      XCTAssertNotNil(response.text)
      responseCount += 1
    }

    XCTAssertEqual(responseCount, 6)
  }
  /// Verifies that the short-reply fixture streams exactly one response with text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )
    let stream = try model.generateContentStream("Hi")

    var responseCount = 0
    for try await response in stream {
      XCTAssertNotNil(response.text)
      responseCount += 1
    }

    XCTAssertEqual(responseCount, 1)
  }
  /// Verifies that unrecognised harm category / probability raw values decode into a
  /// `SafetyRating` that appears on the first candidate of at least one streamed response.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )
    let expectedRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM")
    )

    var sawExpectedRating = false
    let stream = try model.generateContentStream("Hi")
    for try await response in stream {
      XCTAssertNotNil(response.text)
      guard let ratings = response.candidates.first?.safetyRatings else { continue }
      if ratings.contains(where: { $0 == expectedRating }) {
        sawExpectedRating = true
      }
    }

    XCTAssertTrue(sawExpectedRating)
  }
  /// Verifies that citations are collected across streamed responses: six in total, three with
  /// specific index/uri/title/license combinations, and no empty-string optional fields. Also
  /// checks that the final candidate finishes with `.stop`.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )
    let stream = try model.generateContentStream("Hi")

    var collectedCitations = [Citation]()
    var collectedResponses = [GenerateContentResponse]()
    for try await response in stream {
      collectedResponses.append(response)
      XCTAssertNotNil(response.text)
      let firstCandidate = try XCTUnwrap(response.candidates.first)
      guard let newCitations = firstCandidate.citationMetadata?.citations else { continue }
      collectedCitations.append(contentsOf: newCitations)
    }

    let finalCandidate = try XCTUnwrap(collectedResponses.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 6)
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil
    })
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
    })
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit"
    })
    // Optional string fields must be nil rather than present-but-empty.
    XCTAssertFalse(collectedCitations.contains { $0.uri?.isEmpty ?? false })
    XCTAssertFalse(collectedCitations.contains { $0.title?.isEmpty ?? false })
    XCTAssertFalse(collectedCitations.contains { $0.license?.isEmpty ?? false })
  }
  /// Verifies that a valid App Check token is attached to the streaming request; the mock
  /// handler asserts the `X-Firebase-AppCheck` header equals the configured token.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let validToken = "test-valid-token"
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: validToken),
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: validToken
    )

    let stream = try model.generateContentStream(testPrompt)
    // Draining the stream is enough: the header check happens inside the request handler.
    for try await _ in stream {}
  }
  /// Verifies that when the App Check fake reports an error, the streaming request is still sent
  /// and carries `AppCheckInteropFake.placeholderTokenValue` in `X-Firebase-AppCheck`.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    let stream = try model.generateContentStream(testPrompt)
    // Draining the stream is enough: the header check happens inside the request handler.
    for try await _ in stream {}
  }
  /// Verifies that only the final streamed response carries `usageMetadata` and that its token
  /// counts match the fixture (6 prompt + 4 candidate = 10 total).
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Require at least one response; otherwise every check below would be skipped and the test
    // would pass without verifying anything.
    let lastResponse = try XCTUnwrap(responses.last, "Expected at least one streamed response.")

    // Only the last streamed response contains usage metadata.
    for earlierResponse in responses.dropLast() {
      XCTAssertNil(earlierResponse.usageMetadata)
    }
    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Verifies that an error arriving mid-stream surfaces as an HTTP 499 `cancelled` `RPCError`
  /// wrapped in `GenerateContentError.internalError`, after the responses that preceded it.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )

    var receivedCount = 0
    do {
      let stream = try model.generateContentStream("Hi")
      for try await response in stream {
        XCTAssertNotNil(response.text)
        receivedCount += 1
      }
    } catch let GenerateContentError.internalError(underlying as RPCError) {
      XCTAssertEqual(underlying.httpResponseCode, 499)
      XCTAssertEqual(underlying.status, .cancelled)
      // The fixture yields two responses before the error.
      XCTAssertEqual(receivedCount, 2)
      return
    }

    XCTFail("Expected an internalError with an RPCError.")
  }
  /// Verifies that a plain (non-HTTP) `URLResponse` makes the stream throw
  /// `GenerateContentError.internalError` with the "not an HTTP response" description.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
    let responses = try model.generateContentStream("Hi")

    do {
      for try await response in responses {
        XCTFail("Unexpected content in stream: \(response)")
      }
    } catch let GenerateContentError.internalError(underlyingError) {
      XCTAssertEqual(underlyingError.localizedDescription, "Response was not an HTTP response.")
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// Verifies that invalid JSON in a streamed chunk surfaces as
  /// `GenerateContentError.internalError` wrapping a `DecodingError.dataCorrupted`.
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )
    let responses = try model.generateContentStream(testPrompt)

    do {
      for try await response in responses {
        XCTFail("Unexpected content in stream: \(response)")
      }
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(corruptionContext) = decodingError {
        XCTAssert(
          corruptionContext.debugDescription.hasPrefix("Failed to decode GenerateContentResponse")
        )
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// Verifies that malformed candidate content in a streamed chunk surfaces as
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.malformedContent`,
  /// whose own underlying error is a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )
    let responses = try model.generateContentStream(testPrompt)

    do {
      for try await response in responses {
        XCTFail("Unexpected content in stream: \(response)")
      }
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      if case let .malformedContent(contentError) = candidateError {
        XCTAssert(contentError is DecodingError)
      } else {
        XCTFail("Not a malformed content error: \(candidateError)")
      }
      return
    }

    XCTFail("Expected an internal decoding error.")
  }
  /// Verifies that a custom `RequestOptions.timeout` reaches the streaming request; the mock
  /// handler asserts `URLRequest.timeoutInterval` against the expected value.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    var responseCount = 0
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      XCTAssertNotNil(response.text)
      responseCount += 1
    }

    XCTAssertEqual(responseCount, 1)
  }
  1087. // MARK: - Count Tokens
  /// Verifies that `countTokens` decodes both the token total and the billable character count.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Verifies that `countTokens` succeeds on a model configured with a generation config, a
  /// function-calling tool and a system instruction. Only the decoded response is checked; the
  /// request body is not inspected by the mock handler.
  func testCountTokens_succeeds_allOptions() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )
    let config = GenerationConfig(
      temperature: 0.5,
      topP: 0.9,
      topK: 3,
      candidateCount: 1,
      maxOutputTokens: 1024,
      stopSequences: ["test-stop"],
      responseMIMEType: "text/plain"
    )
    let sumDeclaration = FunctionDeclaration(
      name: "sum",
      description: "Add two integers.",
      parameters: ["x": .integer(), "y": .integer()]
    )
    let instruction = ModelContent(
      role: "system",
      parts: "You are a calculator. Use the provided tools."
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      generationConfig: config,
      tools: [Tool(functionDeclarations: [sumDeclaration])],
      systemInstruction: instruction,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Verifies that a response without billable characters decodes `totalBillableCharacters` as
  /// `nil` while still reporting the token total.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json"
    )
    let emptyImage = InlineDataPart(data: Data(), mimeType: "image/jpeg")

    let tokenCount = try await model.countTokens(emptyImage)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertNil(tokenCount.totalBillableCharacters)
  }
  /// Verifies that `countTokens` throws an `RPCError` (matched directly, not wrapped in
  /// `GenerateContentError`) with HTTP 404 / `notFound` when the model does not exist.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found", withExtension: "json",
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
    } catch let rpcError as RPCError {
      XCTAssertEqual(rpcError.httpResponseCode, 404)
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
      return
    }

    // Reached only when the request unexpectedly succeeds. Errors of any other type propagate out
    // of this throwing test and fail it on their own, so a single failure is reported per outcome.
    XCTFail("Request should not have succeeded.")
  }
  /// Verifies that a custom `RequestOptions.timeout` reaches the `countTokens` request; the mock
  /// handler asserts `URLRequest.timeoutInterval` against the expected value.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1183. // MARK: - Helpers
  /// Builds a mock request handler that answers every request with a plain `URLResponse`
  /// (deliberately *not* an `HTTPURLResponse`) and no body.
  ///
  /// - Throws: `XCTSkip` on watchOS, where custom URL protocols are unsupported.
  /// - Returns: A handler for `MockURLProtocol.requestHandler`. The handler throws (failing the
  ///   test) instead of crashing the test process if a request has no URL.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    return { request in
      // Same unwrap strategy as `httpRequestHandler`, rather than a force-unwrap.
      let requestURL = try XCTUnwrap(request.url)
      // This is *not* an HTTPURLResponse
      let response = URLResponse(
        url: requestURL,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (response, nil)
    }
  }
  /// Builds a mock request handler that validates each outgoing request and replies with an
  /// `HTTPURLResponse` whose body is the lines of a bundled fixture file.
  ///
  /// For every request, the handler asserts that:
  /// - the URL path contains `models/` exactly once;
  /// - the timeout matches `timeout`;
  /// - `x-goog-api-client` includes the language and Firebase version tags;
  /// - `X-Firebase-AppCheck` equals `appCheckToken` (or is absent when `nil`);
  /// - `Authorization` is `Firebase <authToken>`, or absent when `authToken` is `nil`.
  ///
  /// - Parameters:
  ///   - name: Fixture resource name in the test bundle.
  ///   - ext: Fixture file extension.
  ///   - statusCode: HTTP status code of the mock response.
  ///   - timeout: Timeout interval the request is expected to carry.
  ///   - appCheckToken: Expected App Check header value.
  ///   - authToken: Expected Firebase Auth token.
  /// - Throws: `XCTSkip` on watchOS; an unwrap failure if the fixture is missing from the bundle.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = RequestOptions().timeout,
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)

    #if SWIFT_PACKAGE
      let resourceBundle = Bundle.module
    #else // SWIFT_PACKAGE
      let resourceBundle = Bundle(for: Self.self)
    #endif // SWIFT_PACKAGE
    let fixtureURL = try XCTUnwrap(resourceBundle.url(forResource: name, withExtension: ext))

    return { request in
      let requestURL = try XCTUnwrap(request.url)
      XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let clientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let apiClientTags = clientHeader.components(separatedBy: " ")
      XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
      switch authToken {
      case let .some(expectedToken):
        XCTAssertEqual(request.value(forHTTPHeaderField: "Authorization"), "Firebase \(expectedToken)")
      case .none:
        XCTAssertNil(request.value(forHTTPHeaderField: "Authorization"))
      }

      let response = try XCTUnwrap(HTTPURLResponse(
        url: requestURL,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (response, fixtureURL.lines)
    }
  }
  1247. }
  private extension String {
    /// Counts how many times `substring` appears in this string.
    ///
    /// The string is split on `substring`; N separators produce N + 1 pieces, so the count is
    /// the number of pieces minus one.
    func occurrenceCount(of substring: String) -> Int {
      let pieces = components(separatedBy: substring)
      return pieces.count - 1
    }
  }
  /// Test double for `AppCheckInterop` that returns a preconfigured token and optional error from
  /// `getToken(forcingRefresh:)`. The notification-related requirements are not implemented.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// Token value vended by a fake created with `init(error:)`.
    static let placeholderTokenValue = "placeholder-token"

    var token: String
    var error: Error?

    /// Creates a fake that returns `token` with no error.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake that returns `placeholderTokenValue` together with `error`.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
    }

    // MARK: AppCheckInterop

    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      // `forcingRefresh` is ignored; the configured values are returned every time.
      let result = AppCheckTokenResultInteropFake(token: token, error: error)
      return result
    }

    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    /// Plain holder for the token/error pair returned by `getToken(forcingRefresh:)`.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
      }
    }
  }
  /// Empty error value passed to `AppCheckInteropFake(error:)` so the fake reports an error
  /// alongside its placeholder token.
  struct AppCheckErrorFake: Error {}
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    /// Orders ratings by category raw value, breaking ties by probability raw value.
    ///
    /// Comparing categories alone left two ratings with the same category but different
    /// probabilities neither less, greater, nor equal. That violates `Comparable`'s strict total
    /// order and makes `sorted()` results depend on input order. NOTE(review): assumes `==` on
    /// `SafetyRating` compares category and probability; confirm if more fields are added.
    public static func < (lhs: FirebaseVertexAI.SafetyRating,
                          rhs: FirebaseVertexAI.SafetyRating) -> Bool {
      return (lhs.category.rawValue, lhs.probability.rawValue)
        < (rhs.category.rawValue, rhs.probability.rawValue)
    }
  }
  1298. }