GenerativeModelTests.swift 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt sent by every `generateContent` call in these tests.
  let testPrompt = "What sorts of questions can I ask you?"

  /// One `.negligible` rating per harm category, pre-sorted so tests can compare it against
  /// `sorted()` response ratings without depending on the order used in the JSON fixtures.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(category: .sexuallyExplicit, probability: .negligible),
    SafetyRating(category: .hateSpeech, probability: .negligible),
    SafetyRating(category: .harassment, probability: .negligible),
    SafetyRating(category: .dangerousContent, probability: .negligible),
  ].sorted()

  /// Fully-qualified model resource name used to build the default model under test.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  /// Session whose requests go through `MockURLProtocol`; rebuilt in `setUp()`.
  var urlSession: URLSession!
  /// Model under test; rebuilt in `setUp()` and replaced by tests that need other options.
  var model: GenerativeModel!
  override func setUp() async throws {
    // Route all traffic from this session through the mock protocol so no real network I/O
    // happens; each test installs its own `MockURLProtocol.requestHandler`.
    let sessionConfiguration = URLSessionConfiguration.default
    sessionConfiguration.protocolClasses = [MockURLProtocol.self]
    let mockedSession = try XCTUnwrap(URLSession(configuration: sessionConfiguration))
    urlSession = mockedSession

    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: mockedSession
    )
  }
  /// Clears the mock handler and releases per-test state.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    // XCTest keeps every XCTestCase instance alive until the whole run finishes, so explicitly
    // release the per-test model and session; `setUp()` rebuilds both before the next test.
    model = nil
    urlSession?.invalidateAndCancel()
    urlSession = nil
    super.tearDown()
  }
  50. // MARK: - Generate Content
  /// A long single-part text reply decodes into one `.stop` candidate with negligible ratings.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    // The lone text part is what `response.text` exposes; no function calls are present.
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let replyText = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(response.text, replyText)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// A short single-part text reply decodes into one `.stop` candidate with negligible ratings.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)

    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    XCTAssertEqual(firstPart.text, "Mountain View, California")
    XCTAssertEqual(response.text, firstPart.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Citation metadata decodes into three sources with their optional fields left `nil` when
  /// absent from the fixture.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citations = citationMetadata.citations
    XCTAssertEqual(citations.count, 3)
    // Subscripting a too-short array traps and aborts the entire test run; stop here so the count
    // assertion above is recorded as the failure instead.
    guard citations.count == 3 else { return }

    let citationSource1 = citations[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)

    let citationSource2 = citations[1]
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citations[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
  }
  /// A quoted reply decodes normally and carries unblocked prompt feedback with negligible ratings.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let replyText = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(replyText.hasPrefix("Google"))
    XCTAssertEqual(response.text, firstPart.text)

    // Prompt feedback is present but does not block the prompt.
    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.sorted(), safetyRatingsNegligible)
  }
  /// Expects the `unknown-enum-safety-ratings` fixture to decode with `.unknown` category and
  /// probability values where present, on both the candidate and the prompt feedback.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    let expected: [SafetyRating] = [
      SafetyRating(category: .harassment, probability: .medium),
      SafetyRating(category: .dangerousContent, probability: .unknown),
      SafetyRating(category: .unknown, probability: .high),
    ]
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    // Compared in fixture order (not sorted) on purpose.
    XCTAssertEqual(response.candidates.first?.safetyRatings, expected)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expected)
  }
  /// A model name given with a "models/" prefix still produces a successful request.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    let prefixedModel = GenerativeModel(
      name: "models/test-model", // Deliberately prefixed with "models/".
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    // Only the absence of an error matters; the decoded response is not inspected.
    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// The `function-call-empty-arguments` fixture decodes into one `current_time` call whose
  /// `args` dictionary is empty.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// The `function-call-no-arguments` fixture decodes into one `current_time` call whose
  /// `args` dictionary is empty.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// A `sum` function call decodes with its two numeric arguments `x` and `y`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let xValue = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(xValue, .number(4))
    let yValue = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(yValue, .number(5))
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// A candidate with three function-call parts reports all three via `functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// A four-part candidate mixing text and function calls exposes two calls via
  /// `functionCalls` and its text via `text`.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(response.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// With App Check returning a token, the request succeeds. The token is also handed to
  /// `httpRequestHandler`, which presumably checks the outgoing request for it.
  func testGenerateContent_appCheck_validToken() async throws {
    let expectedToken = "test-valid-token"
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: expectedToken),
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When App Check fails, the request still succeeds. The handler is told to expect
  /// `AppCheckInteropFake.placeholderTokenValue`.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With Auth returning a token, the request succeeds and the handler is told to expect it.
  func testGenerateContent_auth_validAuthToken() async throws {
    let expectedToken = "test-valid-token"
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: AuthInteropFake(token: expectedToken),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When Auth yields no token, the request still succeeds and the handler expects none.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let tokenlessAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: tokenlessAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An error from the Auth token provider is surfaced as `GenerateContentError.internalError`
  /// wrapping an `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    model = GenerativeModel(
      // Same resource name as the other tests so that only the Auth failure differs.
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: AuthInteropFake(error: AuthErrorFake()),
      urlSession: urlSession
    )
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-basic-reply-short",
        withExtension: "json",
        authToken: nil
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected: the Auth failure is wrapped in `internalError`.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Token counts from the response's usage metadata are decoded.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
  }
  /// A 400 "API key not valid" response surfaces as `internalError` wrapping an `RPCError`.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as RPCError) {
      // Compare against the status the mock was configured with, not a duplicated literal.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  // TODO(andrewheard): Remove this test case after the Vertex AI in Firebase API launch.
  /// A 403 "Firebase ML API has not been used" response surfaces as an `RPCError` that
  /// `isFirebaseMLServiceDisabledError()` recognizes.
  func testGenerateContent_failure_firebaseMLAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A 403 "Vertex AI in Firebase API has not been used" response surfaces as an `RPCError` that
  /// `isVertexAIInFirebaseServiceDisabledError()` recognizes.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      let expectedPrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(rpcError.message.starts(with: expectedPrefix))
      XCTAssertTrue(rpcError.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A candidate with empty content surfaces as `internalError` wrapping
  /// `InvalidCandidateError.emptyContent`, whose own underlying error is a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      guard case .emptyContent(let underlying) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(underlying as? DecodingError, "Not a DecodingError: \(underlying)")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A candidate stopped for safety surfaces as `responseStoppedEarly(.safety, response)`, with
  /// the response's text asserted to be "<redacted>".
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Include the actual error so an unexpected failure is diagnosable from the log alone.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A candidate stopped for safety with no content surfaces as
  /// `responseStoppedEarly(.safety, response)` whose `text` is `nil`.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Include the actual error so an unexpected failure is diagnosable from the log alone.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A 400 "invalid argument" response (image-rejected fixture) surfaces as `internalError`
  /// wrapping an `RPCError`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        // Configure the mock from the same constant the assertions use so they cannot drift.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A prompt blocked for safety surfaces as `promptBlocked(response)` with no response text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
    } catch {
      // Include the actual error so an unexpected failure is diagnosable from the log alone.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// An unrecognized finish reason surfaces as `responseStoppedEarly(.unknown, response)`, and the
  /// response text is still readable.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .unknown)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Include the actual error so an unexpected failure is diagnosable from the log alone.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// An unrecognized block reason surfaces as `promptBlocked(response)` whose prompt feedback
  /// reports `.unknown`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, .unknown)
    } catch {
      // Include the actual error so an unexpected failure is diagnosable from the log alone.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// A 404 "models/unknown is not found" response surfaces as `internalError` wrapping an
  /// `RPCError` with status `.notFound`.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        // Configure the mock from the same constant the assertions use so they cannot drift.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A non-HTTP URL response surfaces as `internalError` whose underlying error describes it.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var thrownError: Error?
    var decoded: GenerateContentResponse?
    do {
      decoded = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(decoded)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
  }
  /// An undecodable response body surfaces as `internalError` wrapping a
  /// `DecodingError.dataCorrupted` that names `GenerateContentResponse`.
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var thrownError: Error?
    var decoded: GenerateContentResponse?
    do {
      decoded = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(decoded)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let decodingFailure = try XCTUnwrap(underlying as? DecodingError)
    guard case .dataCorrupted(let context) = decodingFailure else {
      XCTFail("Not a data corrupted error: \(decodingFailure)")
      return
    }
    XCTAssertTrue(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
  }
  /// Malformed candidate content surfaces as `internalError` wrapping
  /// `InvalidCandidateError.malformedContent`, whose own underlying error is a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var thrownError: Error?
    var decoded: GenerateContentResponse?
    do {
      decoded = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(decoded)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case .malformedContent(let cause) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(cause as? DecodingError, "Not a decoding error: \(cause)")
  }
  /// A successful reply whose fixture (judging by its name) omits safety ratings still decodes.
  /// Prompt feedback must be present with an empty rating list, and the generated text must be
  /// readable.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(response.text, "This is the generated content.")
  }
  /// A custom `RequestOptions.timeout` is propagated to the outgoing unary request.
  /// The request handler asserts `timeoutInterval` against the expected value.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
  }
  689. // MARK: - Generate Content (Streaming)
  /// A streaming request rejected for an invalid API key surfaces as an `RPCError` (HTTP 400,
  /// `.invalidArgument`) wrapped in `GenerateContentError.internalError`.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )

    do {
      let responses = try model.generateContentStream("Hi")
      for try await _ in responses {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 400)
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.message, "API key not valid. Please pass a valid API key.")
      return
    }

    XCTFail("Should have caught an error.")
  }
  709. // TODO(andrewheard): Remove this test case after the Vertex AI in Firebase API launch.
  /// A 403 "Firebase ML API not enabled" response surfaces as an `RPCError` that
  /// `isFirebaseMLServiceDisabledError()` recognizes.
  func testGenerateContentStream_failure_firebaseMLAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      let responses = try model.generateContentStream(testPrompt)
      for try await _ in responses {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.hasPrefix("Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A 403 "Vertex AI in Firebase API not enabled" response surfaces as an `RPCError` that
  /// `isVertexAIInFirebaseServiceDisabledError()` recognizes.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      let responses = try model.generateContentStream(testPrompt)
      for try await _ in responses {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      let expectedPrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(rpcError.message.hasPrefix(expectedPrefix))
      XCTAssertTrue(rpcError.isVertexAIInFirebaseServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A streamed candidate with empty content fails with an `InvalidCandidateError` wrapped in
  /// `GenerateContentError.internalError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )

    do {
      let responses = try model.generateContentStream("Hi")
      for try await _ in responses {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(is InvalidCandidateError) {
      // Only the wrapped error's type matters here; there is nothing further to inspect.
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A stream that finishes for safety reasons throws `responseStoppedEarly` with `.safety`.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )

    do {
      let responses = try model.generateContentStream("Hi")
      for try await _ in responses {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.responseStoppedEarly(let reason, _) {
      XCTAssertEqual(reason, .safety)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A prompt blocked for safety throws `promptBlocked`, carrying a response whose prompt
  /// feedback has block reason `.safety`.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )

    do {
      let responses = try model.generateContentStream("Hi")
      for try await _ in responses {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.promptBlocked(let blockedResponse) {
      XCTAssertEqual(blockedResponse.promptFeedback?.blockReason, .safety)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// An unrecognized finish reason decodes as `.unknown` and ends the stream with
  /// `responseStoppedEarly`. Any responses yielded before that must carry text.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )

    let responses = try model.generateContentStream("Hi")
    do {
      for try await response in responses {
        XCTAssertNotNil(response.text)
      }
    } catch GenerateContentError.responseStoppedEarly(let reason, _) {
      XCTAssertEqual(reason, .unknown)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// The long basic-reply fixture streams exactly six responses, each carrying text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    let stream = try model.generateContentStream("Hi")
    let responseCount = try await stream.reduce(0) { count, response in
      XCTAssertNotNil(response.text)
      return count + 1
    }

    XCTAssertEqual(responseCount, 6)
  }
  /// The short basic-reply fixture streams exactly one response, and that response carries text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    let stream = try model.generateContentStream("Hi")
    let responseCount = try await stream.reduce(0) { count, response in
      XCTAssertNotNil(response.text)
      return count + 1
    }

    XCTAssertEqual(responseCount, 1)
  }
  /// An unrecognized safety category decodes as `.unknown` instead of failing the stream.
  /// At least one streamed first candidate must contain such a rating.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )

    var sawUnknownCategory = false
    let stream = try model.generateContentStream("Hi")
    for try await response in stream {
      XCTAssertNotNil(response.text)
      let ratings = response.candidates.first?.safetyRatings ?? []
      if ratings.contains(where: { $0.category == .unknown }) {
        sawUnknownCategory = true
      }
    }

    XCTAssertTrue(sawUnknownCategory)
  }
  /// Citations spread across streamed chunks are all decoded, with optional fields left `nil`
  /// rather than empty. The final candidate must finish with `.stop`.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )

    let stream = try model.generateContentStream("Hi")
    var citations: [Citation] = []
    var responses: [GenerateContentResponse] = []
    for try await response in stream {
      responses.append(response)
      XCTAssertNotNil(response.text)
      let firstCandidate = try XCTUnwrap(response.candidates.first)
      // Gather citations from every chunk before asserting on the full set.
      citations += firstCandidate.citationMetadata?.citations ?? []
    }

    let finalCandidate = try XCTUnwrap(responses.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(citations.count, 6)

    XCTAssertTrue(citations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1"
        && citation.title == nil
        && citation.license == nil
    })
    XCTAssertTrue(citations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265
        && citation.uri == nil
        && citation.title == "some-citation-2"
        && citation.license == nil
    })
    XCTAssertTrue(citations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3"
        && citation.title == nil
        && citation.license == "mit"
    })

    // Optional string fields must be either absent or non-empty, never "".
    XCTAssertFalse(citations.contains { $0.uri?.isEmpty ?? false })
    XCTAssertFalse(citations.contains { $0.title?.isEmpty ?? false })
    XCTAssertFalse(citations.contains { $0.license?.isEmpty ?? false })
  }
  /// A valid App Check token is sent in the `X-Firebase-AppCheck` header of streaming requests.
  /// The header is checked inside the request handler.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheck = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: appCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: appCheckToken
    )

    // Consume the whole stream; the header assertions live in the request handler.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// When App Check reports an error, the request still goes out with the fake's placeholder
  /// token in the `X-Firebase-AppCheck` header. The header is checked inside the request handler.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Consume the whole stream; the header assertions live in the request handler.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Only the final streamed response carries usage metadata, and its token counts match the
  /// fixture. Every earlier response must have `usageMetadata == nil`.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var responses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Fail explicitly on an empty stream. Otherwise none of the checks below would run, and the
    // test would pass without verifying anything.
    let lastResponse = try XCTUnwrap(
      responses.last,
      "Expected at least one streamed response."
    )
    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)

    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }
  }
  /// An error that arrives mid-stream surfaces as an `RPCError` (HTTP 499, `.cancelled`).
  /// The two responses delivered before the error must already have been yielded.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )

    var receivedCount = 0
    do {
      let responses = try model.generateContentStream("Hi")
      for try await response in responses {
        XCTAssertNotNil(response.text)
        receivedCount += 1
      }
    } catch GenerateContentError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 499)
      XCTAssertEqual(rpcError.status, .cancelled)
      // Responses received before the failure must not be lost.
      XCTAssertEqual(receivedCount, 2)
      return
    }

    XCTFail("Expected an internalError with an RPCError.")
  }
  /// A transport reply that is not an `HTTPURLResponse` fails the stream with an internal error
  /// whose description explains the problem.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    let responses = try model.generateContentStream("Hi")
    do {
      for try await unexpected in responses {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch GenerateContentError.internalError(let underlying) {
      XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// A streamed chunk that is not valid JSON fails with an internal
  /// `DecodingError.dataCorrupted` mentioning `GenerateContentResponse`.
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )

    let responses = try model.generateContentStream(testPrompt)
    do {
      for try await unexpected in responses {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch GenerateContentError.internalError(let decodingError as DecodingError) {
      guard case let .dataCorrupted(context) = decodingError else {
        return XCTFail("Not a data corrupted error: \(decodingError)")
      }
      let description = context.debugDescription
      XCTAssert(description.hasPrefix("Failed to decode GenerateContentResponse"))
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// A streamed candidate with malformed content fails with
  /// `InvalidCandidateError.malformedContent` wrapping a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )

    let responses = try model.generateContentStream(testPrompt)
    do {
      for try await unexpected in responses {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch GenerateContentError.internalError(let candidateError as InvalidCandidateError) {
      guard case let .malformedContent(contentError) = candidateError else {
        return XCTFail("Not a malformed content error: \(candidateError)")
      }
      XCTAssertTrue(contentError is DecodingError)
      return
    }

    XCTFail("Expected an internal decoding error.")
  }
  /// A custom `RequestOptions.timeout` is propagated to streaming requests.
  /// The request handler asserts `timeoutInterval`, and the short fixture yields one response.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    var receivedCount = 0
    let responses = try model.generateContentStream(testPrompt)
    for try await response in responses {
      XCTAssertNotNil(response.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 1)
  }
  1076. // MARK: - Count Tokens
  /// `countTokens` decodes both the total token count and the billable character count.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// When the response has no billable character count (here for an image-only request),
  /// `totalBillableCharacters` decodes as `nil`.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json"
    )
    let emptyImage = ModelContent.Part.inlineData(mimetype: "image/jpeg", Data())

    let tokenCount = try await model.countTokens(emptyImage)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertNil(tokenCount.totalBillableCharacters)
  }
  /// A 404 from the count-tokens endpoint is thrown directly as an `RPCError` with `.notFound`
  /// status and a message naming the missing model.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
    } catch let rpcError as RPCError {
      XCTAssertEqual(rpcError.httpResponseCode, 404)
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
      return
    }

    // Reached only when the request unexpectedly succeeds. Any non-RPCError error propagates out
    // of the test as a thrown error. Report the success once, rather than recording a second,
    // contradictory failure.
    XCTFail("Request should not have succeeded; expected an RPCError.")
  }
  /// A custom `RequestOptions.timeout` is propagated to count-tokens requests.
  /// The request handler asserts `timeoutInterval` against the expected value.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1136. // MARK: - Helpers
  /// Builds a `MockURLProtocol` request handler that replies with a plain `URLResponse` (not an
  /// `HTTPURLResponse`) and no body. It is used to exercise the "Response was not an HTTP
  /// response." error path.
  ///
  /// - Throws: `XCTSkip` on watchOS, where custom URL protocols are unsupported.
  /// - Returns: A handler producing the non-HTTP response. The handler throws (failing the test
  ///   with a diagnostic) if the intercepted request has no URL, instead of crashing.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    return { request in
      // Unwrap instead of force-unwrapping, consistent with `httpRequestHandler`.
      let requestURL = try XCTUnwrap(request.url)
      // This is *not* an HTTPURLResponse
      let response = URLResponse(
        url: requestURL,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (response, nil)
    }
  }
  /// Builds a `MockURLProtocol` request handler that validates each outgoing request and replies
  /// with the contents of a bundled fixture file.
  ///
  /// For every intercepted request the handler asserts that:
  /// - the URL path contains `models/` exactly once;
  /// - `timeoutInterval` equals `timeout`;
  /// - the `x-goog-api-client` header contains the language and Firebase version tags;
  /// - the `X-Firebase-AppCheck` header equals `appCheckToken` (and is absent when it is `nil`);
  /// - the `Authorization` header is `Firebase <authToken>` (and is absent when it is `nil`).
  ///
  /// - Parameters:
  ///   - name: Fixture resource name in the test bundle.
  ///   - ext: Fixture file extension.
  ///   - statusCode: HTTP status code of the mocked response.
  ///   - timeout: Timeout interval every request is expected to carry.
  ///   - appCheckToken: Expected App Check header value.
  ///   - authToken: Expected Firebase Auth token.
  /// - Throws: `XCTSkip` on watchOS; an unwrap failure if the fixture is not in the bundle.
  /// - Returns: A handler yielding an `HTTPURLResponse` with `statusCode` and the fixture's lines.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = RequestOptions().timeout,
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // MockURLProtocol relies on custom URL protocols, which watchOS 2+ does not support; see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)

    #if SWIFT_PACKAGE
      let resourceBundle = Bundle.module
    #else // SWIFT_PACKAGE
      let resourceBundle = Bundle(for: Self.self)
    #endif // SWIFT_PACKAGE
    let fixtureURL = try XCTUnwrap(resourceBundle.url(forResource: name, withExtension: ext))
    // `nil` means the Authorization header must be absent.
    let expectedAuthorization = authToken.map { "Firebase \($0)" }

    return { request in
      let requestURL = try XCTUnwrap(request.url)
      XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let apiClientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let apiClientTags = apiClientHeader.components(separatedBy: " ")
      XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
      XCTAssertEqual(request.value(forHTTPHeaderField: "Authorization"), expectedAuthorization)

      let httpResponse = try XCTUnwrap(HTTPURLResponse(
        url: requestURL,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (httpResponse, fixtureURL.lines)
    }
  }
  1200. }
  private extension String {
    /// Counts how many times `substring` appears in this string.
    ///
    /// Splitting on `substring` yields one more piece than the number of separators found, so the
    /// count is the number of pieces minus one.
    func occurrenceCount(of substring: String) -> Int {
      let pieces = components(separatedBy: substring)
      return pieces.count - 1
    }
  }
  /// Test double for `AppCheckInterop` that returns a preconfigured token and optional error from
  /// every `getToken(forcingRefresh:)` call.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// The placeholder token value returned when an error occurs
    static let placeholderTokenValue = "placeholder-token"

    /// Token placed in every returned token result.
    var token: String

    /// Error placed in every returned token result; `nil` for the success case.
    var error: Error?

    /// Creates a fake whose token results carry `token` and no error.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake whose token results carry `error` and `placeholderTokenValue`.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
    }

    // MARK: - AppCheckInterop

    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      // `forcingRefresh` is ignored; the configured values are always returned.
      AppCheckTokenResultInteropFake(token: token, error: error)
    }

    // The notification-related requirements are not exercised by these tests.

    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    // MARK: - Token result

    /// Plain holder for the token/error pair returned by `getToken(forcingRefresh:)`.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
      }
    }
  }
  /// Opaque error handed to `AppCheckInteropFake(error:)` to simulate a failed App Check token
  /// refresh.
  struct AppCheckErrorFake: Swift.Error {}
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    /// Orders safety ratings by the raw value of their harm category.
    ///
    /// Only `category` is compared. Two ratings with the same category are unordered relative to
    /// each other, whatever their other fields. Presumably this exists so test assertions can
    /// sort ratings into a stable order before comparing them; confirm against callers.
    public static func < (lhs: FirebaseVertexAI.SafetyRating,
                          rhs: FirebaseVertexAI.SafetyRating) -> Bool {
      let leftCategory = lhs.category.rawValue
      let rightCategory = rhs.category.rawValue
      return leftCategory < rightCategory
    }
  }