// GenerativeModelTests.swift
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt text passed to `generateContent` by most tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Sorted list of `.negligible` ratings for the four safety categories the success tests
  /// compare against.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(category: .sexuallyExplicit, probability: .negligible),
    SafetyRating(category: .hateSpeech, probability: .negligible),
    SafetyRating(category: .harassment, probability: .negligible),
    SafetyRating(category: .dangerousContent, probability: .negligible),
  ].sorted()

  /// Session configured in `setUp()` to route requests through `MockURLProtocol`.
  var urlSession: URLSession!

  /// Model under test; created in `setUp()` and replaced by tests that need custom options.
  var model: GenerativeModel!
  /// Creates a `URLSession` whose protocol classes contain only `MockURLProtocol`, then builds
  /// the default `GenerativeModel` on top of that session.
  override func setUp() async throws {
    let sessionConfiguration = URLSessionConfiguration.default
    sessionConfiguration.protocolClasses = [MockURLProtocol.self]
    urlSession = try XCTUnwrap(URLSession(configuration: sessionConfiguration))

    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )
  }
  /// Resets `MockURLProtocol.requestHandler` to `nil` once a test has finished.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
  }
  48. // MARK: - Generate Content
  /// Checks that the `unary-success-basic-reply-long` fixture yields one candidate that finished
  /// with `.stop`, has negligible safety ratings, and carries a single text part.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(result.candidates.first)
    let stopReason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let replyText = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(result.text, replyText)
    XCTAssertEqual(result.functionCalls, [])
  }
  /// Checks that the `unary-success-basic-reply-short` fixture yields one candidate whose only
  /// part is the text "Mountain View, California".
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(result.candidates.first)
    let stopReason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    XCTAssertEqual(firstPart.text, "Mountain View, California")
    XCTAssertEqual(result.text, firstPart.text)
    XCTAssertEqual(result.functionCalls, [])
  }
  /// Checks that the `unary-success-citations` fixture decodes into one candidate with three
  /// citation sources, and verifies each source's URI, index range and license.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citationSources = citationMetadata.citationSources
    XCTAssertEqual(citationSources.count, 3)
    // XCTAssertEqual only records a failure and lets the test continue. Stop here on a count
    // mismatch so the subscripts below cannot trap on an out-of-range index.
    guard citationSources.count == 3 else { return }

    let citationSource1 = try XCTUnwrap(citationSources[0])
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.license)

    let citationSource2 = try XCTUnwrap(citationSources[1])
    XCTAssertEqual(citationSource2.uri, "https://www.example.com/some-citation-2")
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = try XCTUnwrap(citationSources[2])
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
  }
  /// Checks that the `unary-success-quote-reply` fixture yields one `.stop` candidate whose text
  /// starts with "Google", plus prompt feedback with no block reason and negligible ratings.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(result.candidates.first)
    let stopReason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let replyText = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(replyText.hasPrefix("Google"))
    XCTAssertEqual(result.text, firstPart.text)

    let feedback = try XCTUnwrap(result.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.sorted(), safetyRatingsNegligible)
  }
  /// Checks that the `unary-success-unknown-enum-safety-ratings` fixture decodes into ratings
  /// that use `.unknown` for the category or probability, on both the first candidate and the
  /// prompt feedback.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    let expectedRatings: [SafetyRating] = [
      .init(category: .harassment, probability: .medium),
      .init(category: .dangerousContent, probability: .unknown),
      .init(category: .unknown, probability: .high),
    ]
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.text, "Some text")
    XCTAssertEqual(result.candidates.first?.safetyRatings, expectedRatings)
    XCTAssertEqual(result.promptFeedback?.safetyRatings, expectedRatings)
  }
  /// Checks that `generateContent` does not throw for a model whose name already carries the
  /// "models/" prefix.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    // The name deliberately includes the "models/" prefix.
    let prefixedModel = GenerativeModel(
      name: "models/test-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// Checks that the `unary-success-function-call-empty-arguments` fixture decodes into a single
  /// `current_time` function call with no arguments.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(result.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(result.functionCalls, [call])
  }
  /// Checks that the `unary-success-function-call-no-arguments` fixture decodes into a single
  /// `current_time` function call whose `args` is empty.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(result.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(result.functionCalls, [call])
  }
  /// Checks that the `unary-success-function-call-with-arguments` fixture decodes into a `sum`
  /// function call with the numeric arguments `x = 4` and `y = 5`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(result.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let xArgument = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(xArgument, .number(4))
    let yArgument = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(yArgument, .number(5))
    XCTAssertEqual(result.functionCalls, [call])
  }
  /// Checks that the `unary-success-function-call-parallel-calls` fixture yields one candidate
  /// with three parts and that the response exposes three function calls.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(result.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 3)
    let calls = result.functionCalls
    XCTAssertEqual(calls.count, 3)
  }
  /// Checks that the `unary-success-function-call-mixed-content` fixture yields four parts, of
  /// which two are function calls, and that the response text is "The sum of [1, 2, 3] is".
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(result.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 4)
    let calls = result.functionCalls
    XCTAssertEqual(calls.count, 2)
    let replyText = try XCTUnwrap(result.text)
    XCTAssertEqual(replyText, "The sum of [1, 2, 3] is")
  }
  /// Runs `generateContent` with an App Check fake that supplies a token; the same token is
  /// handed to the mock request handler as `appCheckToken`.
  func testGenerateContent_appCheck_validToken() async throws {
    let expectedToken = "test-valid-token"
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: expectedToken),
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Runs `generateContent` with an App Check fake that reports an error; the mock request
  /// handler is given `AppCheckInteropFake.placeholderTokenValue` as `appCheckToken`.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Runs `generateContent` with an auth fake that supplies a token; the same token is handed
  /// to the mock request handler as `authToken`.
  func testGenerateContent_auth_validAuthToken() async throws {
    let expectedToken = "test-valid-token"
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: AuthInteropFake(token: expectedToken),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: expectedToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Runs `generateContent` with an auth fake whose token is `nil`; the mock request handler is
  /// likewise given `nil` as `authToken`.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let tokenlessAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: tokenlessAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Checks that an auth fake reporting an error makes `generateContent` throw
  /// `GenerateContentError.internalError` wrapping an `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: failingAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected error; nothing further to verify.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Checks the token counts in the usage metadata decoded from the
  /// `unary-success-basic-reply-short` fixture.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(result.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
  }
  /// Checks that an HTTP 400 response built from the `unary-failure-api-key` fixture surfaces as
  /// `GenerateContentError.internalError` wrapping an `RPCError` with the expected code, status
  /// and message.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as RPCError) {
      // Compare against the same constant that configures the mock response, so the two values
      // cannot drift apart.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// Checks that an HTTP 403 response built from the `unary-failure-firebaseml-api-not-enabled`
  /// fixture surfaces as an `RPCError` with `.permissionDenied` status that is recognised by
  /// `isFirebaseMLServiceDisabledError()`.
  func testGenerateContent_failure_firebaseMLAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch GenerateContentError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// Checks that the `unary-failure-empty-content` fixture makes `generateContent` throw
  /// `internalError` wrapping `InvalidCandidateError.emptyContent`, whose payload is a
  /// `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch GenerateContentError.internalError(
      let invalidCandidateError as InvalidCandidateError
    ) {
      guard case .emptyContent(let decodingError) = invalidCandidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(invalidCandidateError)")
        return
      }
      _ = try XCTUnwrap(
        decodingError as? DecodingError,
        "Not a DecodingError: \(decodingError)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Checks that the `unary-failure-finish-reason-safety` fixture makes `generateContent` throw
  /// `responseStoppedEarly` with reason `.safety` and the text "<redacted>".
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Checks that the `unary-failure-finish-reason-safety-no-content` fixture makes
  /// `generateContent` throw `responseStoppedEarly` with reason `.safety` and no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-finish-reason-safety-no-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Checks that an HTTP 400 response built from the `unary-failure-image-rejected` fixture
  /// surfaces as `internalError` wrapping an `RPCError` with `.invalidArgument` status.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-image-rejected",
      withExtension: "json",
      // Use the shared constant instead of a duplicate literal so that the mocked status and
      // the asserted status always agree.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Checks that the `unary-failure-prompt-blocked-safety` fixture makes `generateContent` throw
  /// `promptBlocked` with a response that has no text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-prompt-blocked-safety",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Checks that the `unary-failure-unknown-enum-finish-reason` fixture makes `generateContent`
  /// throw `responseStoppedEarly` with reason `.unknown` and the text "Some text".
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-enum-finish-reason",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .unknown)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Checks that the `unary-failure-unknown-enum-prompt-blocked` fixture makes `generateContent`
  /// throw `promptBlocked` with prompt feedback whose block reason is `.unknown`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-enum-prompt-blocked",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, .unknown)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the test log.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Checks that an HTTP 404 response built from the `unary-failure-unknown-model` fixture
  /// surfaces as `internalError` wrapping an `RPCError` with `.notFound` status.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-unknown-model",
      withExtension: "json",
      // Use the shared constant instead of a duplicate literal so that the mocked status and
      // the asserted status always agree.
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Checks that a handler created by `nonHTTPRequestHandler()` makes `generateContent` throw
  /// `internalError` whose underlying error is described as
  /// "Response was not an HTTP response.".
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var caughtError: Error?
    var receivedResponse: GenerateContentResponse?
    do {
      receivedResponse = try await model.generateContent(testPrompt)
    } catch {
      caughtError = error
    }

    XCTAssertNil(receivedResponse)
    XCTAssertNotNil(caughtError)
    let generateContentError = try XCTUnwrap(caughtError as? GenerateContentError)
    guard case .internalError(let underlying) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }
    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
  }
  /// Checks that the `unary-failure-invalid-response` fixture makes `generateContent` throw
  /// `internalError` wrapping a `DecodingError.dataCorrupted` whose debug description starts
  /// with "Failed to decode GenerateContentResponse".
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var caughtError: Error?
    var receivedResponse: GenerateContentResponse?
    do {
      receivedResponse = try await model.generateContent(testPrompt)
    } catch {
      caughtError = error
    }

    XCTAssertNil(receivedResponse)
    XCTAssertNotNil(caughtError)
    let generateContentError = try XCTUnwrap(caughtError as? GenerateContentError)
    guard case .internalError(let underlying) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }
    let decodingError = try XCTUnwrap(underlying as? DecodingError)
    guard case .dataCorrupted(let corruptionContext) = decodingError else {
      XCTFail("Not a data corrupted error: \(decodingError)")
      return
    }
    XCTAssert(
      corruptionContext.debugDescription.hasPrefix("Failed to decode GenerateContentResponse")
    )
  }
  /// Checks that the `unary-failure-malformed-content` fixture makes `generateContent` throw
  /// `internalError` wrapping `InvalidCandidateError.malformedContent`, whose payload is a
  /// `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var caughtError: Error?
    var receivedResponse: GenerateContentResponse?
    do {
      receivedResponse = try await model.generateContent(testPrompt)
    } catch {
      caughtError = error
    }

    XCTAssertNil(receivedResponse)
    XCTAssertNotNil(caughtError)
    let generateContentError = try XCTUnwrap(caughtError as? GenerateContentError)
    guard case .internalError(let underlying) = generateContentError else {
      XCTFail("Not an internal error: \(generateContentError)")
      return
    }
    let invalidCandidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case .malformedContent(let malformedContentUnderlyingError) = invalidCandidateError
    else {
      XCTFail("Not a malformed content error: \(invalidCandidateError)")
      return
    }
    _ = try XCTUnwrap(
      malformedContentUnderlyingError as? DecodingError,
      "Not a decoding error: \(malformedContentUnderlyingError)"
    )
  }
  /// Checks that the `unary-success-missing-safety-ratings` fixture decodes into prompt feedback
  /// with zero safety ratings and the text "This is the generated content.".
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(result.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(result.text, "This is the generated content.")
  }
  /// Runs `generateContent` on a model built with a 150-second `RequestOptions` timeout; the
  /// same timeout value is handed to the mock request handler.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: expectedTimeout
    )
    let customOptions = RequestOptions(timeout: expectedTimeout)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: customOptions,
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
  }
  661. // MARK: - Generate Content (Streaming)
  /// Checks that streaming against the `unary-failure-api-key` fixture yields no content and
  /// throws an internal `RPCError` with code 400, status `.invalidArgument` and the expected
  /// message.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )

    do {
      for try await _ in model.generateContentStream("Hi") {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 400)
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.message, "API key not valid. Please pass a valid API key.")
    }
  }
  /// Checks that a 403 response from the `unary-failure-firebaseml-api-not-enabled` fixture
  /// surfaces as an internal `RPCError` with status `.permissionDenied` that is recognised by
  /// `isFirebaseMLServiceDisabledError()`.
  func testGenerateContentStream_failure_firebaseMLAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      for try await _ in model.generateContentStream(testPrompt) {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
    }
  }
  /// Checks that the `streaming-failure-empty-content` fixture yields no content and throws
  /// an internal error whose underlying error is an `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )

    do {
      for try await _ in model.generateContentStream("Hi") {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // Underlying error is as expected, nothing else to check.
    }
  }
  /// Checks that the `streaming-failure-finish-reason-safety` fixture yields no content and
  /// throws `responseStoppedEarly` with reason `.safety`.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )

    do {
      for try await _ in model.generateContentStream("Hi") {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(finishReason, _) {
      XCTAssertEqual(finishReason, .safety)
    }
  }
  /// Checks that the `streaming-failure-prompt-blocked-safety` fixture yields no content and
  /// throws `promptBlocked` with a response whose block reason is `.safety`.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )

    do {
      for try await _ in model.generateContentStream("Hi") {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      XCTAssertEqual(blockedResponse.promptFeedback?.blockReason, .safety)
    }
  }
  /// Checks that the `streaming-failure-unknown-finish-enum` fixture ends with
  /// `responseStoppedEarly` carrying reason `.unknown`. Any content streamed before the
  /// error must have text.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )

    do {
      for try await chunk in model.generateContentStream("Hi") {
        XCTAssertNotNil(chunk.text)
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(finishReason, _) {
      XCTAssertEqual(finishReason, .unknown)
    }
  }
  /// Checks that the `streaming-success-basic-reply-long` fixture streams exactly six
  /// responses, each with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    var responseCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      responseCount += 1
    }

    XCTAssertEqual(responseCount, 6)
  }
  /// Checks that the `streaming-success-basic-reply-short` fixture streams exactly one
  /// response, and that it has text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var responseCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      responseCount += 1
    }

    XCTAssertEqual(responseCount, 1)
  }
  /// Checks that the `streaming-success-unknown-safety-enum` fixture streams successfully and
  /// that at least one response's first candidate carries a safety rating whose category is
  /// `.unknown`.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )

    var sawUnknownCategory = false
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      // A response without candidates simply contributes no ratings.
      let ratings = chunk.candidates.first?.safetyRatings ?? []
      if ratings.contains(where: { $0.category == .unknown }) {
        sawUnknownCategory = true
      }
    }

    XCTAssertTrue(sawUnknownCategory)
  }
  /// Checks citation handling for the `streaming-success-citations` fixture: six citation
  /// sources are collected across the stream, two specific sources are present, and the last
  /// response's first candidate finished with `.stop`.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )

    var collectedCitations = [Citation]()
    var streamedResponses = [GenerateContentResponse]()
    for try await chunk in model.generateContentStream("Hi") {
      streamedResponses.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      if let sources = firstCandidate.citationMetadata?.citationSources {
        collectedCitations.append(contentsOf: sources)
      }
    }

    /// Returns whether a collected citation covers exactly `start`...`end`, has a non-empty
    /// URI and an empty license.
    func hasCitation(from start: Int, to end: Int) -> Bool {
      return collectedCitations.contains(where: {
        $0.startIndex == start && $0.endIndex == end && !$0.uri.isEmpty && $0.license == ""
      })
    }

    let finalCandidate = try XCTUnwrap(streamedResponses.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 6)
    XCTAssertTrue(hasCitation(from: 574, to: 705))
    XCTAssertTrue(hasCitation(from: 899, to: 1026))
  }
  /// Checks that a token supplied by the App Check fake is sent with the streaming request.
  /// The request handler asserts the `X-Firebase-AppCheck` header against `appCheckToken`.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: appCheckFake,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: appCheckToken
    )

    // Drain the stream; the assertions of interest run inside the request handler.
    for try await _ in model.generateContentStream(testPrompt) {}
  }
  /// Checks that when the App Check fake is created with an error, the streaming request is
  /// still sent and carries `AppCheckInteropFake.placeholderTokenValue` in the
  /// `X-Firebase-AppCheck` header (asserted by the request handler).
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream; the assertions of interest run inside the request handler.
    for try await _ in model.generateContentStream(testPrompt) {}
  }
  /// Checks that only the last streamed response carries usage metadata and that its token
  /// counts are 6 (prompt), 4 (candidates) and 10 (total) for the
  /// `streaming-success-basic-reply-short` fixture.
  ///
  /// The test fails if the stream yields no responses at all; without that check an empty
  /// stream would execute no assertions and pass vacuously.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var responses = [GenerateContentResponse]()
    for try await response in model.generateContentStream(testPrompt) {
      responses.append(response)
    }

    // Require at least one response so the metadata assertions below always run.
    let lastResponse = try XCTUnwrap(
      responses.last,
      "Expected at least one streamed response."
    )

    // Only the last streamed response contains usage metadata
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }

    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Checks that the `streaming-failure-error-mid-stream` fixture delivers two responses and
  /// then throws an internal `RPCError` with code 499 and status `.cancelled`.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )

    var receivedCount = 0
    do {
      for try await chunk in model.generateContentStream("Hi") {
        XCTAssertNotNil(chunk.text)
        receivedCount += 1
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internalError with an RPCError.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 499)
      XCTAssertEqual(rpcError.status, .cancelled)
      // Check the content count is correct.
      XCTAssertEqual(receivedCount, 2)
    }
  }
  /// Checks that a plain `URLResponse` (not an `HTTPURLResponse`) makes the stream throw an
  /// internal error described as "Response was not an HTTP response."
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    do {
      for try await unexpected in model.generateContentStream("Hi") {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(underlyingError) {
      XCTAssertEqual(
        underlyingError.localizedDescription,
        "Response was not an HTTP response."
      )
    }
  }
  /// Checks that the `streaming-failure-invalid-json` fixture yields no content and throws an
  /// internal `DecodingError.dataCorrupted` whose debug description starts with
  /// "Failed to decode GenerateContentResponse".
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )

    do {
      for try await unexpected in model.generateContentStream(testPrompt) {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(context) = decodingError {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
    }
  }
  /// Checks that the `streaming-failure-malformed-content` fixture yields no content and
  /// throws an internal `InvalidCandidateError.malformedContent` wrapping a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )

    do {
      for try await unexpected in model.generateContentStream(testPrompt) {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      // Reaching this point means the stream finished without throwing.
      XCTFail("Expected an internal decoding error.")
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      if case let .malformedContent(rootCause) = candidateError {
        XCTAssert(rootCause is DecodingError)
      } else {
        XCTFail("Not a malformed content error: \(candidateError)")
      }
    }
  }
  /// Checks that a custom `RequestOptions` timeout is applied to the outgoing streaming
  /// request (asserted by the request handler) and that one response with text is streamed.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: expectedTimeout
    )

    var responseCount = 0
    for try await chunk in model.generateContentStream(testPrompt) {
      XCTAssertNotNil(chunk.text)
      responseCount += 1
    }

    XCTAssertEqual(responseCount, 1)
  }
  1012. // MARK: - Count Tokens
  /// Checks that the `unary-success-total-tokens` fixture decodes to 6 total tokens and
  /// 16 billable characters.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Checks that the `unary-success-no-billable-characters` fixture decodes to 258 total
  /// tokens and 0 billable characters when counting tokens for an empty JPEG data part.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json"
    )
    let imagePart = ModelContent.Part.data(mimetype: "image/jpeg", Data())

    let tokenCount = try await model.countTokens(imagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 0)
  }
  /// Checks that a 404 response from the `unary-failure-model-not-found` fixture makes
  /// `countTokens` throw an internal `RPCError` with status `.notFound` and the expected
  /// message prefix.
  func testCountTokens_modelNotFound() async throws {
    let expectedStatusCode = 404
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch let CountTokensError.internalError(notFoundError as RPCError) {
      XCTAssertEqual(notFoundError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(notFoundError.status, .notFound)
      XCTAssert(notFoundError.message.hasPrefix("models/test-model-name is not found"))
      return
    }

    XCTFail("Expected internal RPCError.")
  }
  /// Checks that a custom `RequestOptions` timeout is applied to the outgoing `countTokens`
  /// request. The request handler asserts the request's `timeoutInterval`.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      timeout: expectedTimeout
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1072. // MARK: - Model Resource Name
  /// Checks that a model name without a prefix gets `models/` prepended in
  /// `modelResourceName`.
  func testModelResourceName_noPrefix() async throws {
    let bareName = "my-model"
    let expectedResourceName = "models/" + bareName

    model = GenerativeModel(
      name: bareName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, expectedResourceName)
  }
  /// Checks that a name already starting with `models/` is used unchanged as
  /// `modelResourceName`.
  func testModelResourceName_modelsPrefix() async throws {
    let prefixedName = "models/my-model"

    model = GenerativeModel(
      name: prefixedName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, prefixedName)
  }
  /// Checks that a name starting with `tunedModels/` is used unchanged as
  /// `modelResourceName`.
  func testModelResourceName_tunedModelsPrefix() async throws {
    let tunedName = "tunedModels/my-model"

    model = GenerativeModel(
      name: tunedName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, tunedName)
  }
  1113. // MARK: - Helpers
  /// Builds a request handler that answers every request with a plain `URLResponse` (not an
  /// `HTTPURLResponse`) and no body lines.
  ///
  /// - Returns: A closure mapping a request to the non-HTTP response and a `nil` line
  ///   sequence.
  /// - Throws: `XCTSkip` when the watchOS availability guard below does not pass.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    guard #unavailable(watchOS 2) else {
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    }

    return { request in
      // Deliberately a base `URLResponse`, so there is no HTTP status for the client to read.
      let plainResponse = URLResponse(
        url: request.url!,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (plainResponse, nil)
    }
  }
  /// Builds a request handler that validates the outgoing request and replies with the
  /// contents of a bundled resource file.
  ///
  /// The returned closure asserts that the request path contains `models/` exactly once, that
  /// the timeout matches `timeout`, that the `x-goog-api-client` header contains the language
  /// and Firebase version tags, and that the `X-Firebase-AppCheck` and `Authorization`
  /// headers match the expected tokens.
  ///
  /// - Parameters:
  ///   - name: Name of the resource file in `Bundle.module`.
  ///   - ext: Extension of the resource file.
  ///   - statusCode: HTTP status code of the reply; defaults to 200.
  ///   - timeout: Expected `timeoutInterval` of the request; defaults to the `URLRequest`
  ///     default.
  ///   - appCheckToken: Expected `X-Firebase-AppCheck` header value, or `nil` if none.
  ///   - authToken: Expected auth token; when `nil` no `Authorization` header may be present.
  /// - Returns: A closure mapping a request to an `HTTPURLResponse` and the file's lines.
  /// - Throws: `XCTSkip` when the watchOS availability guard below does not pass, or an
  ///   unwrap failure if the resource is not found.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    guard #unavailable(watchOS 2) else {
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    }

    // Resolve the fixture up front so a missing resource fails before any request is made.
    let resourceURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext))

    return { request in
      let url = try XCTUnwrap(request.url)
      XCTAssertEqual(url.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let apiClientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let clientTags = apiClientHeader.components(separatedBy: " ")
      XCTAssert(clientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(clientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)

      let authorization = request.value(forHTTPHeaderField: "Authorization")
      if let authToken {
        XCTAssertEqual(authorization, "Firebase \(authToken)")
      } else {
        XCTAssertNil(authorization)
      }

      let httpResponse = try XCTUnwrap(HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (httpResponse, resourceURL.lines)
    }
  }
  1172. }
  private extension String {
    /// Counts how many times `substring` occurs in the receiver.
    ///
    /// The receiver is split on `substring`; every occurrence adds one piece beyond the
    /// first, so the count is the number of pieces after the first one.
    func occurrenceCount(of substring: String) -> Int {
      let pieces = components(separatedBy: substring)
      return pieces.dropFirst().count
    }
  }
  private extension URLRequest {
    /// Returns the `timeoutInterval` that a freshly created `URLRequest` has when no timeout
    /// is specified, read from a throwaway request to a placeholder URL.
    static func defaultTimeoutInterval() -> TimeInterval {
      return URLRequest(url: URL(string: "https://example.com")!).timeoutInterval
    }
  }
  /// Test double for `AppCheckInterop` that returns a fixed token and optional error from
  /// `getToken(forcingRefresh:)`. The notification-related methods are not implemented and
  /// trap if called.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// The placeholder token value returned when an error occurs
    static let placeholderTokenValue = "placeholder-token"

    /// Token handed out with every token result.
    var token: String
    /// Error handed out with every token result, if any.
    var error: Error?

    /// Designated initializer; callers use one of the convenience initializers below.
    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
      super.init()
    }

    /// Creates a fake that returns `token` and no error.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake that returns `error` together with the placeholder token.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    /// Returns a token result carrying the stored token and error; `forcingRefresh` is
    /// ignored.
    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      let result = AppCheckTokenResultInteropFake(token: token, error: error)
      return result
    }

    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    /// Simple value holder conforming to `FIRAppCheckTokenResultInterop`.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
        super.init()
      }
    }
  }
  /// Empty error type; the tests in this file pass it to `AppCheckInteropFake(error:)`.
  struct AppCheckErrorFake: Error {
  }
  /// Orders safety ratings by the raw value of their category.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Comparable {
    public static func < (lhs: SafetyRating, rhs: SafetyRating) -> Bool {
      let leftKey = lhs.category.rawValue
      let rightKey = rhs.category.rawValue
      return leftKey < rightKey
    }
  }