// GenerativeModelTests.swift
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt text passed to the model by the tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Sorted ratings with a `.negligible` probability for each of four harm categories; tests
  /// compare this against the sorted ratings of a decoded response.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(category: .sexuallyExplicit, probability: .negligible),
    SafetyRating(category: .hateSpeech, probability: .negligible),
    SafetyRating(category: .harassment, probability: .negligible),
    SafetyRating(category: .dangerousContent, probability: .negligible),
  ].sorted()

  /// Session assigned in `setUp()`; its configuration lists `MockURLProtocol`.
  var urlSession: URLSession!

  /// Model under test, assigned in `setUp()`; individual tests may replace it.
  var model: GenerativeModel!
  /// Creates a `URLSession` whose configuration lists `MockURLProtocol` as its protocol class,
  /// and a default `GenerativeModel` that uses that session.
  override func setUp() async throws {
    let configuration = URLSessionConfiguration.default
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is not failable, so there is nothing to unwrap.
    urlSession = URLSession(configuration: configuration)
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )
  }
  /// Removes the request handler a test installed on `MockURLProtocol` so that it is not
  /// still set when the next test starts.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
  }
  47. // MARK: - Generate Content
  /// Generates content against the `unary-success-basic-reply-long` resource and checks the
  /// single candidate's finish reason, safety ratings and text.
  func testGenerateContent_success_basicReplyLong() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let firstPartText = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(firstPartText.hasPrefix("You can ask me a wide range of questions"))

    // The response-level conveniences must agree with the single text part.
    XCTAssertEqual(reply.text, firstPartText)
    XCTAssertEqual(reply.functionCalls, [])
  }
  /// Generates content against the `unary-success-basic-reply-short` resource and checks the
  /// single candidate's finish reason, safety ratings and exact text.
  func testGenerateContent_success_basicReplyShort() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    XCTAssertEqual(firstPart.text, "Mountain View, California, United States")
    XCTAssertEqual(reply.text, firstPart.text)
    XCTAssertEqual(reply.functionCalls, [])
  }
  /// Generates content against the `unary-success-citations` resource and checks the URI,
  /// index range and license of each of the three citation sources.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citationSources = citationMetadata.citationSources
    XCTAssertEqual(citationSources.count, 3)
    // XCTAssertEqual records a failure but keeps running, so stop here before the
    // fixed-index subscripts below can trap on an unexpectedly short array.
    guard citationSources.count >= 3 else {
      return
    }

    let citationSource1 = citationSources[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.license)

    let citationSource2 = citationSources[1]
    XCTAssertEqual(citationSource2.uri, "https://www.example.com/some-citation-2")
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citationSources[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
  }
  /// Generates content against the `unary-success-quote-reply` resource and checks the
  /// candidate as well as the prompt feedback attached to the response.
  func testGenerateContent_success_quoteReply() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    let firstPartText = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(firstPartText.hasPrefix("Google"))
    XCTAssertEqual(reply.text, firstPart.text)

    // The prompt itself must not be blocked and must carry negligible ratings.
    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.sorted(), safetyRatingsNegligible)
  }
  /// Generates content against the `unary-success-unknown-enum-safety-ratings` resource and
  /// expects `.unknown` category and probability values in the decoded ratings.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    let expectedRatings: [SafetyRating] = [
      .init(category: .harassment, probability: .medium),
      .init(category: .dangerousContent, probability: .unknown),
      .init(category: .unknown, probability: .high),
    ]
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.text, "Some text")
    // Both the candidate and the prompt feedback carry the same three ratings.
    XCTAssertEqual(reply.candidates.first?.safetyRatings, expectedRatings)
    XCTAssertEqual(reply.promptFeedback?.safetyRatings, expectedRatings)
  }
  /// Generates content with a model whose name already starts with "models/" and expects the
  /// call to complete without throwing.
  func testGenerateContent_success_prefixedModelName() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    // Model name is prefixed with "models/".
    let prefixedName = "models/test-model"
    let prefixedModel = GenerativeModel(
      name: prefixedName,
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// Generates content against the `unary-success-function-call-empty-arguments` resource and
  /// expects a single `current_time` function call whose arguments are empty.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// Generates content against the `unary-success-function-call-no-arguments` resource and
  /// expects a single `current_time` function call whose arguments are empty.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// Generates content against the `unary-success-function-call-with-arguments` resource and
  /// expects a single `sum` function call with numeric arguments `x` = 4 and `y` = 5.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let firstPart = try XCTUnwrap(parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "sum")

    let arguments = call.args
    XCTAssertEqual(arguments.count, 2)
    let xValue = try XCTUnwrap(arguments["x"])
    XCTAssertEqual(xValue, .number(4))
    let yValue = try XCTUnwrap(arguments["y"])
    XCTAssertEqual(yValue, .number(5))
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// Generates content against the `unary-success-function-call-parallel-calls` resource and
  /// expects three parts that surface as three function calls.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(reply.functionCalls.count, 3)
  }
  /// Generates content against the `unary-success-function-call-mixed-content` resource and
  /// expects four parts, two of which are function calls, plus the response text.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(reply.functionCalls.count, 2)
    let replyText = try XCTUnwrap(reply.text)
    XCTAssertEqual(replyText, "The sum of [1, 2, 3] is")
  }
  /// Generates content with an App Check fake that holds a token, using a request handler that
  /// is built with the same token; the call is expected to complete without throwing.
  func testGenerateContent_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: appCheckFake,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: appCheckToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Generates content with an App Check fake configured with an error; the request handler is
  /// built with the fake's placeholder token value and the call is expected not to throw.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Generates content with an Auth fake that holds a token, using a request handler that is
  /// built with the same token; the call is expected to complete without throwing.
  func testGenerateContent_auth_validAuthToken() async throws {
    let authToken = "test-valid-token"
    let authFake = AuthInteropFake(token: authToken)
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: authFake,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: authToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Generates content with an Auth fake whose token is `nil`, using a request handler that is
  /// built with a `nil` auth token; the call is expected to complete without throwing.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let tokenlessAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: tokenlessAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Generates content with an Auth fake configured with an error and expects
  /// `GenerateContentError.internalError` wrapping an `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: failingAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected outcome: matching the pattern is the whole check.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Generates content against the `unary-success-basic-reply-short` resource and checks the
  /// prompt, candidates and total token counts in the usage metadata.
  func testGenerateContent_usageMetadata() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(reply.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
  }
  /// Generates content against the `unary-failure-api-key` resource served with status 400 and
  /// expects `GenerateContentError.internalError` wrapping an `RPCError` with that status code.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as RPCError) {
      // Compare against the same constant the handler was built with, not a second literal.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// Generates content against the `unary-failure-empty-content` resource and expects an
  /// `InvalidCandidateError.emptyContent` whose payload is a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      guard case .emptyContent(let wrappedError) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(
        wrappedError as? DecodingError,
        "Not a DecodingError: \(wrappedError)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Generates content against the `unary-failure-finish-reason-safety` resource and expects
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety` and text "No".
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "No")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Generates content against the `unary-failure-finish-reason-safety-no-content` resource and
  /// expects `GenerateContentError.responseStoppedEarly` with reason `.safety` and no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety-no-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Generates content against the `unary-failure-image-rejected` resource served with status
  /// 400 and expects `GenerateContentError.internalError` wrapping a matching `RPCError`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-image-rejected",
      withExtension: "json",
      // Build the handler from the same constant the assertion below compares against.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Generates content against the `unary-failure-prompt-blocked-safety` resource and expects
  /// `GenerateContentError.promptBlocked` carrying a response without text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// Generates content against the `unary-failure-unknown-enum-finish-reason` resource and
  /// expects `GenerateContentError.responseStoppedEarly` with reason `.unknown`.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-enum-finish-reason",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .unknown)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Generates content against the `unary-failure-unknown-enum-prompt-blocked` resource and
  /// expects `GenerateContentError.promptBlocked` whose block reason is `.unknown`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-enum-prompt-blocked",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, .unknown)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// Generates content against the `unary-failure-unknown-model` resource served with status
  /// 404 and expects `GenerateContentError.internalError` wrapping a `.notFound` `RPCError`.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-model",
      withExtension: "json",
      // Build the handler from the same constant the assertion below compares against.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Generates content with the handler from `nonHTTPRequestHandler()` installed and expects an
  /// internal error whose description says the response was not an HTTP response.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var thrownError: Error?
    var generated: GenerateContentResponse?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
  }
  /// Generates content against the `unary-failure-invalid-response` resource and expects an
  /// internal error wrapping a `DecodingError.dataCorrupted` with the expected description.
  func testGenerateContent_failure_invalidResponse() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    var thrownError: Error?
    var generated: GenerateContentResponse?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let decodingFailure = try XCTUnwrap(underlying as? DecodingError)
    guard case .dataCorrupted(let corruption) = decodingFailure else {
      XCTFail("Not a data corrupted error: \(decodingFailure)")
      return
    }
    XCTAssert(corruption.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
  }
  /// Generates content against the `unary-failure-malformed-content` resource and expects an
  /// internal error wrapping `InvalidCandidateError.malformedContent` with a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    var thrownError: Error?
    var generated: GenerateContentResponse?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case .malformedContent(let wrappedError) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(
      wrappedError as? DecodingError,
      "Not a decoding error: \(wrappedError)"
    )
  }
  /// Generates content against the `unary-success-missing-safety-ratings` resource and expects
  /// prompt feedback with zero safety ratings alongside the generated text.
  func testGenerateContentMissingSafetyRatings() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let reply = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(reply.text, "This is the generated content.")
  }
  /// Generates content with a model built from `RequestOptions(timeout: 150.0)`, using a request
  /// handler that is built with the same timeout, and expects one candidate.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    let handler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: expectedTimeout
    )
    MockURLProtocol.requestHandler = handler

    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
  }
  632. // MARK: - Generate Content (Streaming)
  /// Streams content against the `unary-failure-api-key` resource and expects the stream to
  /// throw `GenerateContentError.internalError` wrapping an `RPCError` before yielding anything.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    var caughtExpectedError = false
    do {
      let contentStream = model.generateContentStream("Hi")
      for try await _ in contentStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 400)
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.message, "API key not valid. Please pass a valid API key.")
      caughtExpectedError = true
    }

    // Any other error type propagates out of the test; reaching here without the flag set
    // means the stream finished without throwing.
    if !caughtExpectedError {
      XCTFail("Should have caught an error.")
    }
  }
  /// Streams content against the `streaming-failure-empty-content` resource and expects the
  /// stream to throw an internal error wrapping an `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    let handler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )
    MockURLProtocol.requestHandler = handler

    var caughtExpectedError = false
    do {
      let contentStream = model.generateContentStream("Hi")
      for try await _ in contentStream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // Underlying error is as expected, nothing else to check.
      caughtExpectedError = true
    }

    if !caughtExpectedError {
      XCTFail("Should have caught an error.")
    }
  }
  /// Streams the "streaming-failure-finish-reason-safety" fixture and expects
  /// `GenerateContentError.responseStoppedEarly` with a `.safety` reason.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )
    let responseStream = model.generateContentStream("Hi")

    do {
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.responseStoppedEarly(let finishReason, _) {
      XCTAssertEqual(finishReason, .safety)
    }
  }
  /// Streams the "streaming-failure-prompt-blocked-safety" fixture and expects
  /// `GenerateContentError.promptBlocked` whose prompt feedback reports a `.safety` block reason.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )
    let responseStream = model.generateContentStream("Hi")

    do {
      for try await _ in responseStream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.promptBlocked(let blockedResponse) {
      XCTAssertEqual(blockedResponse.promptFeedback?.blockReason, .safety)
    }
  }
  /// Streams the "streaming-failure-unknown-finish-enum" fixture; any content delivered must
  /// have text, and the stream must end with `responseStoppedEarly` carrying an `.unknown` reason.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )
    let responseStream = model.generateContentStream("Hi")

    do {
      for try await chunk in responseStream {
        XCTAssertNotNil(chunk.text)
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.responseStoppedEarly(let finishReason, _) {
      XCTAssertEqual(finishReason, .unknown)
    }
  }
  /// Streams the "streaming-success-basic-reply-long" fixture and expects eight responses,
  /// each of which has text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    var receivedCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 8)
  }
  /// Streams the "streaming-success-basic-reply-short" fixture and expects exactly one
  /// response, which has text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var receivedCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 1)
  }
  /// Streams the "streaming-success-unknown-safety-enum" fixture and expects at least one
  /// response whose first candidate has a safety rating with the `.unknown` category.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )

    var sawUnknownCategory = false
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      // A response without candidates contributes no ratings.
      let ratings = chunk.candidates.first?.safetyRatings ?? []
      if ratings.contains(where: { rating in rating.category == .unknown }) {
        sawUnknownCategory = true
      }
    }

    XCTAssertTrue(sawUnknownCategory)
  }
  /// Streams the "streaming-success-citations" fixture, gathering the citation sources of each
  /// response's first candidate, then verifies the final finish reason and the three expected
  /// citations.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )
    var collectedCitations: [Citation] = []
    var finalResponse: GenerateContentResponse?

    for try await chunk in model.generateContentStream("Hi") {
      finalResponse = chunk
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      if let sources = firstCandidate.citationMetadata?.citationSources {
        collectedCitations += sources
      }
    }

    let lastCandidate = try XCTUnwrap(finalResponse?.candidates.first)
    XCTAssertEqual(lastCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 3)
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && !citation.uri.isEmpty && citation.license == nil
    })
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265
        && !citation.uri.isEmpty && citation.license == nil
    })
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && !citation.uri.isEmpty && citation.license == "mit"
    })
  }
  /// Configures the model with an App Check fake that returns a valid token and streams a
  /// request; the request handler asserts that the token is sent in the App Check header.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let validToken = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: validToken)
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: appCheckFake,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: validToken
    )

    // Drain the stream; the checks of interest live in the request handler.
    for try await _ in model.generateContentStream(testPrompt) {}
  }
  /// Configures the model with an App Check fake that reports an error and streams a request;
  /// the request handler asserts that the placeholder token is sent in the App Check header.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream; the checks of interest live in the request handler.
    for try await _ in model.generateContentStream(testPrompt) {}
  }
  /// Verifies that, when streaming, usage metadata is reported only on the final response.
  ///
  /// The test fails if the stream yields no responses at all, so the per-response checks can
  /// never pass vacuously.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var responses = [GenerateContentResponse]()
    for try await response in model.generateContentStream(testPrompt) {
      responses.append(response)
    }

    // An empty stream must fail the test rather than skip every assertion below.
    let finalResponse = try XCTUnwrap(
      responses.last,
      "Expected at least one streamed response."
    )

    // Only the last streamed response contains usage metadata
    for earlierResponse in responses.dropLast() {
      XCTAssertNil(earlierResponse.usageMetadata)
    }

    let usageMetadata = try XCTUnwrap(finalResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Streams the "streaming-failure-error-mid-stream" fixture and expects two responses with
  /// text to arrive before the stream throws an `RPCError` with HTTP code 499 and `.cancelled`
  /// status.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )
    var receivedBeforeError = 0
    let responseStream = model.generateContentStream("Hi")

    do {
      for try await chunk in responseStream {
        XCTAssertNotNil(chunk.text)
        receivedBeforeError += 1
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internalError with an RPCError.")
    } catch GenerateContentError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 499)
      XCTAssertEqual(rpcError.status, .cancelled)
      // Check the content count is correct.
      XCTAssertEqual(receivedBeforeError, 2)
    }
  }
  /// Answers the request with a plain `URLResponse` and expects the stream to throw an
  /// internal error describing that the response was not an HTTP response.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
    let responseStream = model.generateContentStream("Hi")

    do {
      for try await content in responseStream {
        XCTFail("Unexpected content in stream: \(content)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch GenerateContentError.internalError(let underlying) {
      XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
    }
  }
  /// Streams the "streaming-failure-invalid-json" fixture and expects a
  /// `DecodingError.dataCorrupted` wrapped in `GenerateContentError.internalError`.
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )
    let responseStream = model.generateContentStream(testPrompt)

    do {
      for try await content in responseStream {
        XCTFail("Unexpected content in stream: \(content)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch GenerateContentError.internalError(let underlying as DecodingError) {
      if case .dataCorrupted(let context) = underlying {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(underlying)")
      }
    }
  }
  /// Streams the "streaming-failure-malformed-content" fixture and expects an
  /// `InvalidCandidateError.malformedContent` whose payload is a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )
    let responseStream = model.generateContentStream(testPrompt)

    do {
      for try await content in responseStream {
        XCTFail("Unexpected content in stream: \(content)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal decoding error.")
    } catch GenerateContentError.internalError(let underlyingError as InvalidCandidateError) {
      if case .malformedContent(let contentError) = underlyingError {
        XCTAssert(contentError is DecodingError)
      } else {
        XCTFail("Not a malformed content error: \(underlyingError)")
      }
    }
  }
  /// Checks that a timeout supplied through `RequestOptions` is the timeout set on the outgoing
  /// streaming request, and that the single expected response still arrives.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    // The handler asserts that the intercepted request carries this timeout.
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    var receivedCount = 0
    for try await chunk in model.generateContentStream(testPrompt) {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 1)
  }
  962. // MARK: - Count Tokens
  /// Counts tokens against the "success-total-tokens" fixture and checks both the token total
  /// and the billable character total.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "success-total-tokens",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Counts tokens for an empty JPEG data part against the "success-no-billable-characters"
  /// fixture and expects zero billable characters.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "success-no-billable-characters",
      withExtension: "json"
    )
    let emptyImagePart = ModelContent.Part.data(mimetype: "image/jpeg", Data())

    let tokenCount = try await model.countTokens(emptyImagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 0)
  }
  /// Answers `countTokens` with a 404 and the "failure-model-not-found" fixture, expecting an
  /// `RPCError` with `.notFound` status wrapped in `CountTokensError.internalError`.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "failure-model-not-found",
      withExtension: "json",
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch CountTokensError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 404)
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
      return
    }

    XCTFail("Expected internal RPCError.")
  }
  /// Checks that a timeout supplied through `RequestOptions` is the timeout set on the outgoing
  /// `countTokens` request.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    // The handler asserts that the intercepted request carries this timeout.
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "success-total-tokens",
      withExtension: "json",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1021. // MARK: - Model Resource Name
  /// A model name without a prefix is expected to gain the "models/" prefix in
  /// `modelResourceName`.
  func testModelResourceName_noPrefix() async throws {
    let modelName = "my-model"

    model = GenerativeModel(
      name: modelName,
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    let expectedResourceName = "models/\(modelName)"
    XCTAssertEqual(model.modelResourceName, expectedResourceName)
  }
  /// A model name that already starts with "models/" is expected to be used unchanged as
  /// `modelResourceName`.
  func testModelResourceName_modelsPrefix() async throws {
    let prefixedName = "models/my-model"

    model = GenerativeModel(
      name: prefixedName,
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, prefixedName)
  }
  /// A model name that starts with "tunedModels/" is expected to be used unchanged as
  /// `modelResourceName`.
  func testModelResourceName_tunedModelsPrefix() async throws {
    let tunedName = "tunedModels/my-model"

    model = GenerativeModel(
      name: tunedName,
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, tunedName)
  }
  1059. // MARK: - Helpers
  /// Builds a request handler that answers every request with a plain `URLResponse` and no
  /// body lines.
  ///
  /// - Returns: A closure yielding a non-HTTP response paired with `nil` content.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    return { request in
      // This is *not* an HTTPURLResponse
      let plainResponse = URLResponse(
        url: request.url!,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (plainResponse, nil)
    }
  }
  /// Builds a request handler that validates the intercepted request and answers with the
  /// lines of a bundled fixture file.
  ///
  /// - Parameters:
  ///   - name: The fixture's resource name in `Bundle.module`.
  ///   - ext: The fixture's file extension.
  ///   - statusCode: The HTTP status code of the returned response.
  ///   - timeout: The timeout interval the request is expected to have.
  ///   - appCheckToken: The expected "X-Firebase-AppCheck" header value; `nil` means the header
  ///     must be absent.
  ///   - authToken: The token expected in the "Authorization" header; `nil` means the header
  ///     must be absent.
  /// - Returns: A closure producing an `HTTPURLResponse` and the fixture's line sequence.
  /// - Throws: If the fixture cannot be found in the bundle.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    let fixtureURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext))

    return { request in
      // The request path must mention "models/" exactly once.
      let url = try XCTUnwrap(request.url)
      XCTAssertEqual(url.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      // The API client header is a space-separated list of tags.
      let apiClientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let apiClientTags = apiClientHeader.components(separatedBy: " ")
      XCTAssert(apiClientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(apiClientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
      if let authToken {
        XCTAssertEqual(request.value(forHTTPHeaderField: "Authorization"), "Firebase \(authToken)")
      } else {
        XCTAssertNil(request.value(forHTTPHeaderField: "Authorization"))
      }

      let httpResponse = try XCTUnwrap(HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (httpResponse, fixtureURL.lines)
    }
  }
  1108. }
  private extension String {
    /// Counts how many times `substring` appears in the receiver.
    ///
    /// Splitting on `substring` yields one more piece than there are separators, hence the
    /// subtraction.
    func occurrenceCount(of substring: String) -> Int {
      let pieces = components(separatedBy: substring)
      return pieces.count - 1
    }
  }
  private extension URLRequest {
    /// Reports the `timeoutInterval` of a `URLRequest` created without specifying a timeout.
    static func defaultTimeoutInterval() -> TimeInterval {
      // The URL itself is irrelevant; it only exists so a request can be constructed.
      let placeholderRequest = URLRequest(url: URL(string: "https://example.com")!)
      return placeholderRequest.timeoutInterval
    }
  }
  /// A test double for `AppCheckInterop` that hands back a fixed token and an optional error.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// The placeholder token value returned when an error occurs
    static let placeholderTokenValue = "placeholder-token"

    /// The token included in every result returned from `getToken(forcingRefresh:)`.
    var token: String

    /// The error included in every result returned from `getToken(forcingRefresh:)`, if any.
    var error: Error?

    /// Creates a fake that reports `token` with no error.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake that reports `error` together with the placeholder token.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
    }

    /// Returns a result carrying the stored token and error; `forcingRefresh` is not consulted.
    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      AppCheckTokenResultInteropFake(token: token, error: error)
    }

    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    /// The token result type produced by `AppCheckInteropFake`.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
      }
    }
  }
  /// An empty error type, used in tests as the failure reported by an `AppCheckInteropFake`.
  struct AppCheckErrorFake: Error {
  }
  /// Orders `SafetyRating` values by the raw value of their category.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
  extension SafetyRating: Comparable {
    public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
      let leftCategory = lhs.category.rawValue
      let rightCategory = rhs.category.rawValue
      return leftCategory < rightCategory
    }
  }