GenerativeModelTests.swift 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt text sent by the tests that call `generateContent`.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Expected safety ratings for the basic "success" fixtures: every category is `.negligible`.
  ///
  /// Pre-sorted so tests can compare it against `safetyRatings.sorted()` without depending on
  /// the order the fixture lists the ratings in.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(category: .sexuallyExplicit, probability: .negligible),
    SafetyRating(category: .hateSpeech, probability: .negligible),
    SafetyRating(category: .harassment, probability: .negligible),
    SafetyRating(category: .dangerousContent, probability: .negligible),
  ].sorted()

  /// Session whose requests are served by `MockURLProtocol`; assigned in `setUp()`.
  var urlSession: URLSession!
  /// Model under test; rebuilt in `setUp()` and replaced by tests needing other options.
  var model: GenerativeModel!
  /// Routes all traffic through `MockURLProtocol` and builds a default model (no tools,
  /// App Check, or Auth) that uses the mocked session.
  override func setUp() async throws {
    let configuration = URLSessionConfiguration.default
    // Intercept every request so each test can stub responses via
    // `MockURLProtocol.requestHandler`.
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is non-failable, so there is nothing to unwrap.
    urlSession = URLSession(configuration: configuration)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )
  }
  /// Clears the stubbed response handler so it cannot leak into the next test.
  override func tearDown() {
    MockURLProtocol.requestHandler = .none
  }
  48. // MARK: - Generate Content
  /// Decodes the long basic-reply fixture.
  ///
  /// Expects one candidate that stopped normally, negligible safety ratings, a single text
  /// part, and no function calls.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let textPart = try XCTUnwrap(parts.first)
    let replyText = try XCTUnwrap(textPart.text)
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    // The response-level convenience accessor must surface the same text as the part.
    XCTAssertEqual(response.text, replyText)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Decodes the short basic-reply fixture and checks the single text part verbatim.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let textPart = try XCTUnwrap(parts.first)
    XCTAssertEqual(textPart.text, "Mountain View, California")
    XCTAssertEqual(response.text, textPart.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Decodes a response carrying three citation sources.
  ///
  /// Checks each source's URI, title, license, and start/end indices.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-citations",
        withExtension: "json"
      )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citationSources = citationMetadata.citationSources
    XCTAssertEqual(citationSources.count, 3)
    // `XCTAssertEqual` does not stop the test. Stop explicitly before indexing, because an
    // out-of-range subscript traps and would abort the whole test run, not just fail this test.
    guard citationSources.count == 3 else { return }

    let citationSource1 = citationSources[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)

    let citationSource2 = citationSources[1]
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citationSources[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
  }
  /// Decodes the quote-reply fixture, including prompt feedback that has no block reason.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let textPart = try XCTUnwrap(parts.first)
    let replyText = try XCTUnwrap(textPart.text)
    XCTAssertTrue(replyText.hasPrefix("Google"))
    XCTAssertEqual(response.text, textPart.text)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.sorted(), safetyRatingsNegligible)
  }
  /// Decodes safety ratings from a fixture that is expected to yield `.unknown` enum values.
  ///
  /// Both the candidate ratings and the prompt-feedback ratings are checked.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )
    let expectedRatings = [
      SafetyRating(category: .harassment, probability: .medium),
      SafetyRating(category: .dangerousContent, probability: .unknown),
      SafetyRating(category: .unknown, probability: .high),
    ]

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedRatings)
  }
  /// Verifies that `generateContent` succeeds when the model name already carries the
  /// `models/` prefix.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    let prefixedModel = GenerativeModel(
      name: "models/test-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// The "empty arguments" fixture decodes to a single `current_time` call with no arguments.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let callPart = try XCTUnwrap(parts.first)
    guard case let .functionCall(call) = callPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// The "no arguments" fixture decodes to a single `current_time` call with an empty `args`.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let callPart = try XCTUnwrap(parts.first)
    guard case let .functionCall(call) = callPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// Decodes a `sum` function call with two numeric arguments, `x` = 4 and `y` = 5.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let parts = onlyCandidate.content.parts
    XCTAssertEqual(parts.count, 1)
    let callPart = try XCTUnwrap(parts.first)
    guard case let .functionCall(call) = callPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let xValue = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(xValue, .number(4))
    let yValue = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(yValue, .number(5))
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// A single candidate with three parts must expose three function calls.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// Decodes a candidate that mixes function calls with text.
  ///
  /// Expects four parts, two of them function calls, and the aggregated text to match.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(response.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// Configures an App Check fake that returns a token and hands the same token to the stubbed
  /// handler.
  ///
  /// The handler presumably verifies that the request carries that token — see
  /// `httpRequestHandler`.
  func testGenerateContent_appCheck_validToken() async throws {
    let token = "test-valid-token"
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: token),
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the App Check fake fails, the request must still succeed.
  ///
  /// The stubbed handler is told to expect the fake's placeholder token value.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Configures an Auth fake that returns a token and hands the same token to the stubbed
  /// handler, which presumably checks it on the request.
  func testGenerateContent_auth_validAuthToken() async throws {
    let token = "test-valid-token"
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: AuthInteropFake(token: token),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An Auth fake that yields no token must not prevent the request from succeeding.
  ///
  /// The handler is told to expect no auth token.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let tokenlessAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: tokenlessAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An Auth fake that fails must make `generateContent` throw `internalError` wrapping the
  /// fake's `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: failingAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected outcome; the wrapped error's type is all this test checks.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Decodes token-count usage metadata from the short basic-reply fixture.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
  }
  /// An HTTP 400 "invalid API key" response must surface as `internalError` wrapping an
  /// `RPCError` with `invalidArgument` status and the server's message.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as RPCError) {
      // Compare against the stubbed status code, not a duplicated literal, so they cannot drift.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// An HTTP 403 "Firebase ML API not enabled" response must surface as `internalError`.
  ///
  /// The wrapped `RPCError` must report `permissionDenied` and be recognized by
  /// `isFirebaseMLServiceDisabledError()`.
  func testGenerateContent_failure_firebaseMLAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A candidate with empty content must fail with `internalError`.
  ///
  /// The wrapped error must be `InvalidCandidateError.emptyContent` carrying a
  /// `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(
      underlying: candidateError as InvalidCandidateError
    ) {
      guard case let .emptyContent(underlyingError) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(
        underlyingError as? DecodingError,
        "Not a DecodingError: \(underlyingError)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A safety finish reason must surface as `responseStoppedEarly`.
  ///
  /// The attached response is expected to carry the redacted text.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Report the actual error so an unexpected failure is diagnosable from the test log.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A safety finish reason with no content must surface as `responseStoppedEarly`, whose
  /// response has no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report the actual error so an unexpected failure is diagnosable from the test log.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// An HTTP 400 "image rejected" response must surface as `internalError` wrapping an
  /// `RPCError` with `invalidArgument` status.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        // Stub with the same constant the assertion checks, so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A prompt blocked for safety must surface as `promptBlocked`, whose response has no text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
    } catch {
      // Report the actual error so an unexpected failure is diagnosable from the test log.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// An unrecognized finish reason must decode as `.unknown` and surface as
  /// `responseStoppedEarly`, keeping the response text.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .unknown)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report the actual error so an unexpected failure is diagnosable from the test log.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// An unrecognized block reason must decode as `.unknown` and surface as `promptBlocked`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, .unknown)
    } catch {
      // Report the actual error so an unexpected failure is diagnosable from the test log.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// An HTTP 404 "unknown model" response must surface as `internalError` wrapping an
  /// `RPCError` with `notFound` status.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        // Stub with the same constant the assertion checks, so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A stubbed non-HTTP response must make `generateContent` throw `internalError`.
  ///
  /// The underlying error's description must read "Response was not an HTTP response.".
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var generated: GenerateContentResponse?
    var thrown: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)
    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
  }
  /// An undecodable body must fail with `internalError` wrapping a
  /// `DecodingError.dataCorrupted`.
  ///
  /// The error's debug description must start with "Failed to decode GenerateContentResponse".
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var generated: GenerateContentResponse?
    var thrown: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)
    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let decodingFailure = try XCTUnwrap(underlying as? DecodingError)
    guard case let .dataCorrupted(corruptionContext) = decodingFailure else {
      XCTFail("Not a data corrupted error: \(decodingFailure)")
      return
    }
    XCTAssertTrue(
      corruptionContext.debugDescription.hasPrefix("Failed to decode GenerateContentResponse")
    )
  }
  /// Malformed candidate content must fail with `internalError`.
  ///
  /// The wrapped error must be `InvalidCandidateError.malformedContent` carrying a
  /// `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var generated: GenerateContentResponse?
    var thrown: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)
    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case let .malformedContent(contentDecodeError) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(
      contentDecodeError as? DecodingError,
      "Not a decoding error: \(contentDecodeError)"
    )
  }
  /// Prompt feedback without safety ratings must decode with an empty rating list.
  ///
  /// The generated text must still be returned.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(response.text, "This is the generated content.")
  }
  /// Builds a model with a custom `RequestOptions` timeout.
  ///
  /// The stubbed handler receives the same timeout, presumably to compare it against the
  /// outgoing request's timeout.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let timeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: timeout
    )
    let options = RequestOptions(timeout: timeout)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: options,
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let response = try await model.generateContent(testPrompt)
    XCTAssertEqual(response.candidates.count, 1)
  }
  664. // MARK: - Generate Content (Streaming)
  /// Streaming with a rejected API key must surface an `RPCError` (code 400,
  /// `.invalidArgument`) wrapped in `GenerateContentError.internalError`.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 400)
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.message, "API key not valid. Please pass a valid API key.")
      return
    }

    // Reaching this point means the stream finished without throwing.
    XCTFail("Should have caught an error.")
  }
  /// With the Firebase ML API disabled (HTTP 403), streaming must fail with a
  /// `.permissionDenied` `RPCError` that identifies itself as a service-disabled error.
  func testGenerateContentStream_failure_firebaseMLAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )
    let stream = model.generateContentStream(testPrompt)

    do {
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.hasPrefix("Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// The `streaming-failure-empty-content` fixture must end the stream with an
  /// `InvalidCandidateError` wrapped in `GenerateContentError.internalError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(is InvalidCandidateError) {
      // Only the underlying error's type matters here.
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A stream that finishes for safety reasons must throw
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety` before yielding content.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.responseStoppedEarly(finishReason, _) {
      XCTAssertEqual(finishReason, .safety)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// A prompt blocked for safety must throw `GenerateContentError.promptBlocked`, and the
  /// attached response's prompt feedback must report `.safety` as the block reason.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      let blockReason = blockedResponse.promptFeedback?.blockReason
      XCTAssertEqual(blockReason, .safety)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// The `streaming-failure-unknown-finish-enum` fixture must end with
  /// `GenerateContentError.responseStoppedEarly` whose reason is `.unknown`; any chunks
  /// delivered before that must carry text.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
      }
    } catch let GenerateContentError.responseStoppedEarly(finishReason, _) {
      XCTAssertEqual(finishReason, .unknown)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// The long basic-reply fixture must stream exactly six chunks, each containing text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    var chunkCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 6)
  }
  /// The short basic-reply fixture must stream exactly one chunk containing text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var chunkCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  /// Every streamed chunk must carry text. At least one chunk's first candidate must report a
  /// safety rating whose category decoded as `.unknown`.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )

    var sawUnknownCategory = false
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      // A missing first candidate contributes no ratings.
      let ratings = chunk.candidates.first?.safetyRatings ?? []
      if ratings.contains(where: { $0.category == .unknown }) {
        sawUnknownCategory = true
      }
    }

    XCTAssertTrue(sawUnknownCategory)
  }
  /// Citation sources spread across streamed chunks must accumulate to six entries with the
  /// expected spans, URIs, titles and licenses. No URI, title or license may be an empty
  /// string, and the final candidate must finish with `.stop`.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    var allCitations: [Citation] = []
    var chunks: [GenerateContentResponse] = []
    for try await chunk in stream {
      chunks.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      // Chunks without citation metadata contribute nothing.
      allCitations += firstCandidate.citationMetadata?.citationSources ?? []
    }

    let finalCandidate = try XCTUnwrap(chunks.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(allCitations.count, 6)

    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1"
        && citation.title == nil
        && citation.license == nil
    })
    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265
        && citation.uri == nil
        && citation.title == "some-citation-2"
        && citation.license == nil
    })
    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3"
        && citation.title == nil
        && citation.license == "mit"
    })

    // Optional string fields must be absent (nil) rather than present-but-empty.
    XCTAssertFalse(allCitations.contains { $0.uri?.isEmpty ?? false })
    XCTAssertFalse(allCitations.contains { $0.title?.isEmpty ?? false })
    XCTAssertFalse(allCitations.contains { $0.license?.isEmpty ?? false })
  }
  /// A valid App Check token must be sent with the streaming request. The mock handler is
  /// configured to expect it in the `X-Firebase-AppCheck` header.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: appCheckFake,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: appCheckToken
    )

    // Consume the stream; the header assertions live in the mock request handler.
    for try await _ in model.generateContentStream(testPrompt) {}
  }
  /// When App Check reports an error, the request must still go out and carry the fake's
  /// placeholder token in the `X-Firebase-AppCheck` header.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Consume the stream; the header assertions live in the mock request handler.
    for try await _ in model.generateContentStream(testPrompt) {}
  }
  /// Only the final streamed response may carry `usageMetadata`, and it must report
  /// 6 prompt tokens, 4 candidate tokens and 10 total tokens.
  ///
  /// The test fails (rather than passing vacuously) if the stream yields no responses.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-success-basic-reply-short",
        withExtension: "txt"
      )
    var responses = [GenerateContentResponse]()
    let stream = model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Without this, an empty stream would skip every assertion below and the test would pass.
    let lastResponse = try XCTUnwrap(
      responses.last,
      "Expected at least one streamed response."
    )

    // Only the last streamed response contains usage metadata
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }
    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// An error arriving mid-stream must surface as an `RPCError` (code 499, `.cancelled`)
  /// after exactly two text-bearing chunks have been delivered.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    var deliveredChunks = 0
    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
        deliveredChunks += 1
      }
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 499)
      XCTAssertEqual(rpcError.status, .cancelled)
      // The chunks preceding the error must all have been yielded.
      XCTAssertEqual(deliveredChunks, 2)
      return
    }

    XCTFail("Expected an internalError with an RPCError.")
  }
  /// A transport response that is not an `HTTPURLResponse` must be reported as
  /// `GenerateContentError.internalError` with the message "Response was not an HTTP response.".
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    do {
      for try await content in model.generateContentStream("Hi") {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch let GenerateContentError.internalError(underlying) {
      let description = underlying.localizedDescription
      XCTAssertEqual(description, "Response was not an HTTP response.")
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// The `streaming-failure-invalid-json` fixture must fail with a wrapped
  /// `DecodingError.dataCorrupted` whose description begins with
  /// "Failed to decode GenerateContentResponse".
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )
    let stream = model.generateContentStream(testPrompt)

    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch let GenerateContentError.internalError(underlying as DecodingError) {
      if case let .dataCorrupted(context) = underlying {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(underlying)")
      }
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// The `streaming-failure-malformed-content` fixture must fail with
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.malformedContent`,
  /// whose underlying error is a `DecodingError`. This mirrors the unary
  /// `testGenerateContent_failure_malformedContent` test.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "streaming-failure-malformed-content",
        withExtension: "txt"
      )
    let stream = model.generateContentStream(testPrompt)
    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) {
      guard case let .malformedContent(contentError) = underlyingError else {
        XCTFail("Not a malformed content error: \(underlyingError)")
        return
      }
      // Report the actual error on failure (as the unary test does) instead of a bare
      // "XCTAssertTrue failed".
      XCTAssert(contentError is DecodingError, "Not a decoding error: \(contentError)")
      return
    }
    XCTFail("Expected an internal decoding error.")
  }
  /// A custom `RequestOptions.timeout` must reach the streaming request (checked by the mock
  /// handler), and the short fixture must still stream exactly one text-bearing chunk.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    var chunkCount = 0
    for try await chunk in model.generateContentStream(testPrompt) {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  1027. // MARK: - Count Tokens
  /// Counting tokens for a short text prompt must decode `totalTokens` (6) and
  /// `totalBillableCharacters` (16).
  func testCountTokens_succeeds() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-total-tokens",
        withExtension: "json"
      )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Counting tokens for an (empty) JPEG data part must report 258 tokens and zero billable
  /// characters.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-no-billable-characters",
        withExtension: "json"
      )
    let imagePart = ModelContent.Part.data(mimetype: "image/jpeg", Data())

    let tokenCount = try await model.countTokens(imagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 0)
  }
  /// A 404 from the count-tokens endpoint must surface as an `RPCError` with `.notFound`
  /// status wrapped in `CountTokensError.internalError`.
  func testCountTokens_modelNotFound() async throws {
    let notFoundStatusCode = 404
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      statusCode: notFoundStatusCode
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch let CountTokensError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, notFoundStatusCode)
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertTrue(rpcError.message.hasPrefix("models/test-model-name is not found"))
      return
    }

    XCTFail("Expected internal RPCError.")
  }
  /// A custom `RequestOptions.timeout` must reach the count-tokens request; the mock handler
  /// is configured to expect that exact `timeoutInterval`.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let totalTokens = try await model.countTokens(testPrompt).totalTokens

    XCTAssertEqual(totalTokens, 6)
  }
  1087. // MARK: - Model Resource Name
  /// A bare model name must be expanded to a `models/`-prefixed resource name.
  func testModelResourceName_noPrefix() async throws {
    let modelName = "my-model"

    model = GenerativeModel(
      name: modelName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, "models/\(modelName)")
  }
  /// A name that already starts with `models/` must be used as the resource name unchanged.
  func testModelResourceName_modelsPrefix() async throws {
    let prefixedName = "models/my-model"

    model = GenerativeModel(
      name: prefixedName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, prefixedName)
  }
  /// A `tunedModels/` name must be used as the resource name unchanged (no `models/` added).
  func testModelResourceName_tunedModelsPrefix() async throws {
    let tunedName = "tunedModels/my-model"

    model = GenerativeModel(
      name: tunedName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, tunedName)
  }
  1128. // MARK: - Helpers
  /// Builds a `MockURLProtocol` request handler that answers every request with a plain
  /// `URLResponse` (deliberately *not* an `HTTPURLResponse`) and no body.
  ///
  /// - Throws: `XCTSkip` on watchOS, where custom URL protocols are unsupported.
  /// - Returns: A handler closure. The closure throws (failing the test) instead of crashing
  ///   when the intercepted request has no URL, matching `httpRequestHandler`.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    return { request in
      // Previously `request.url!`, which would crash the whole test process on a nil URL.
      let requestURL = try XCTUnwrap(request.url)
      // This is *not* an HTTPURLResponse
      let response = URLResponse(
        url: requestURL,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (response, nil)
    }
  }
  /// Builds a `MockURLProtocol` request handler that validates each outgoing request and
  /// answers it with the lines of a bundled fixture file.
  ///
  /// The returned closure checks that:
  /// - the URL path contains `models/` exactly once;
  /// - the request timeout equals `timeout`;
  /// - the `x-goog-api-client` header includes the language and Firebase version tags;
  /// - the `X-Firebase-AppCheck` header equals `appCheckToken`;
  /// - `Authorization` is `"Firebase <authToken>"`, or absent when `authToken` is `nil`.
  ///
  /// - Parameters:
  ///   - name: Fixture resource name looked up in `Bundle.module`.
  ///   - ext: Fixture file extension.
  ///   - statusCode: HTTP status code of the mocked response.
  ///   - timeout: Expected `URLRequest.timeoutInterval`.
  ///   - appCheckToken: Expected App Check header value (`nil` means the header must be absent).
  ///   - authToken: Expected Firebase Auth token (`nil` means no `Authorization` header).
  /// - Throws: `XCTSkip` on watchOS, or an `XCTUnwrap` failure if the fixture is not bundled.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    let fixtureURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext))

    return { request in
      let url = try XCTUnwrap(request.url)
      XCTAssertEqual(url.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let clientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let clientTags = clientHeader.components(separatedBy: " ")
      XCTAssert(clientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(clientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
      let authorization = request.value(forHTTPHeaderField: "Authorization")
      if let authToken {
        XCTAssertEqual(authorization, "Firebase \(authToken)")
      } else {
        XCTAssertNil(authorization)
      }

      let httpResponse = try XCTUnwrap(HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (httpResponse, fixtureURL.lines)
    }
  }
  1187. }
  private extension String {
    /// Counts the non-overlapping occurrences of `substring` in this string.
    ///
    /// Splitting on `substring` yields one more piece than there are separators, so the count
    /// is the number of pieces minus one.
    func occurrenceCount(of substring: String) -> Int {
      let pieces = self.components(separatedBy: substring)
      return pieces.count - 1
    }
  }
  private extension URLRequest {
    /// The `timeoutInterval` a freshly created `URLRequest` gets when none is specified.
    ///
    /// Obtained by building a throwaway request for a placeholder URL and reading its timeout.
    static func defaultTimeoutInterval() -> TimeInterval {
      let throwawayRequest = URLRequest(url: URL(string: "https://example.com")!)
      return throwawayRequest.timeoutInterval
    }
  }
  /// Test double for `AppCheckInterop` that always returns a fixed token and, optionally, a
  /// fixed error from `getToken(forcingRefresh:)`.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// The placeholder token value returned when an error occurs
    static let placeholderTokenValue = "placeholder-token"

    /// Token handed back by every `getToken(forcingRefresh:)` call.
    var token: String
    /// Error handed back alongside `token`; `nil` for the success case.
    var error: Error?

    /// Creates a fake that succeeds with `token`.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake that reports `error` together with `placeholderTokenValue`.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
    }

    /// Returns the configured token/error pair; `forcingRefresh` is ignored.
    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      AppCheckTokenResultInteropFake(token: token, error: error)
    }

    // MARK: - AppCheckInterop requirements not exercised by these tests

    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    /// Plain value holder for the token result returned by `getToken(forcingRefresh:)`.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
      }
    }
  }
  /// Opaque error passed to `AppCheckInteropFake(error:)` to simulate a failed token fetch.
  struct AppCheckErrorFake: Error {
  }
  /// Orders `SafetyRating` values by the raw value of their `category`.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Comparable {
    public static func < (lhs: FirebaseVertexAI.SafetyRating,
                          rhs: FirebaseVertexAI.SafetyRating) -> Bool {
      let leftKey = lhs.category.rawValue
      let rightKey = rhs.category.rawValue
      return leftKey < rightKey
    }
  }