GenerativeModelTests.swift 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt passed to `generateContent` by the tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"

  /// One `.negligible` rating for each of the four listed harm categories, stored sorted so it
  /// can be compared against a sorted list of decoded ratings.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(category: .sexuallyExplicit, probability: .negligible),
    SafetyRating(category: .hateSpeech, probability: .negligible),
    SafetyRating(category: .harassment, probability: .negligible),
    SafetyRating(category: .dangerousContent, probability: .negligible),
  ].sorted()

  /// Model name handed to `GenerativeModel` by `setUp` and most tests.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  /// Session configured in `setUp` to use `MockURLProtocol`.
  var urlSession: URLSession!

  /// Model under test; created in `setUp` and replaced by tests that need other dependencies.
  var model: GenerativeModel!
  /// Creates a `URLSession` whose protocol classes are `[MockURLProtocol.self]` and a default
  /// `GenerativeModel` (no tools, App Check or Auth) that uses that session.
  override func setUp() async throws {
    let mockedConfiguration = URLSessionConfiguration.default
    mockedConfiguration.protocolClasses = [MockURLProtocol.self]
    let mockedSession = try XCTUnwrap(URLSession(configuration: mockedConfiguration))
    urlSession = mockedSession

    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: mockedSession
    )
  }
  /// Clears the static `MockURLProtocol.requestHandler` that the test installed.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
  }
  50. // MARK: - Generate Content
  /// Uses the `unary-success-basic-reply-long` JSON resource and checks the single candidate:
  /// `.stop` finish reason, negligible safety ratings, one text part with the expected prefix,
  /// and no function calls.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let text = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(text.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(reply.text, text)
    XCTAssertEqual(reply.functionCalls, [])
  }
  /// Uses the `unary-success-basic-reply-short` JSON resource and checks the single candidate:
  /// `.stop` finish reason, negligible safety ratings, the exact reply text, and no function
  /// calls.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    XCTAssertEqual(firstPart.text, "Mountain View, California")
    XCTAssertEqual(reply.text, firstPart.text)
    XCTAssertEqual(reply.functionCalls, [])
  }
  /// Uses the `unary-success-citations` JSON resource and checks the three decoded citations:
  /// their URI or title, start/end indices, and license.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-citations",
        withExtension: "json"
      )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citations = citationMetadata.citations
    // A failed XCTAssertEqual does not stop the test, and an out-of-range subscript traps and
    // takes the whole test process down. Stop here with an ordinary failure instead.
    guard citations.count == 3 else {
      XCTFail("Expected 3 citations; got \(citations.count).")
      return
    }

    let citationSource1 = try XCTUnwrap(citations[0])
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)

    let citationSource2 = try XCTUnwrap(citations[1])
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = try XCTUnwrap(citations[2])
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
  }
  /// Uses the `unary-success-quote-reply` JSON resource and checks the candidate text prefix
  /// as well as the prompt feedback (no block reason, negligible safety ratings).
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let text = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(text.hasPrefix("Google"))
    XCTAssertEqual(reply.text, firstPart.text)

    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.sorted(), safetyRatingsNegligible)
  }
  /// Uses the `unary-success-unknown-enum-safety-ratings` JSON resource and expects both the
  /// candidate and the prompt feedback to carry ratings that include `.unknown` values.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )
    let ratingsWithUnknowns = [
      SafetyRating(category: .harassment, probability: .medium),
      SafetyRating(category: .dangerousContent, probability: .unknown),
      SafetyRating(category: .unknown, probability: .high),
    ]

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.text, "Some text")
    XCTAssertEqual(reply.candidates.first?.safetyRatings, ratingsWithUnknowns)
    XCTAssertEqual(reply.promptFeedback?.safetyRatings, ratingsWithUnknowns)
  }
  /// Checks that `generateContent` does not throw for a model created with a name that is
  /// already prefixed with "models/".
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    // Separate instance so the default model from `setUp` is left untouched.
    let prefixedModel = GenerativeModel(
      // Model name is prefixed with "models/".
      name: "models/test-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// Uses the `unary-success-function-call-empty-arguments` JSON resource and expects a single
  /// `current_time` function call whose arguments are empty.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// Uses the `unary-success-function-call-no-arguments` JSON resource and expects a single
  /// `current_time` function call whose `args` is empty.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// Uses the `unary-success-function-call-with-arguments` JSON resource and expects a single
  /// `sum` function call with numeric arguments `x == 4` and `y == 5`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let xValue = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(xValue, .number(4))
    let yValue = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(yValue, .number(5))
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// Uses the `unary-success-function-call-parallel-calls` JSON resource and expects one
  /// candidate with three parts and three function calls on the response.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 3)
    XCTAssertEqual(reply.functionCalls.count, 3)
  }
  /// Uses the `unary-success-function-call-mixed-content` JSON resource and expects four
  /// parts, two function calls, and the expected response text.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 4)
    XCTAssertEqual(reply.functionCalls.count, 2)
    let replyText = try XCTUnwrap(reply.text)
    XCTAssertEqual(replyText, "The sum of [1, 2, 3] is")
  }
  /// Gives the model an `AppCheckInteropFake` with a token, passes the same token to
  /// `httpRequestHandler`, and expects `generateContent` not to throw.
  func testGenerateContent_appCheck_validToken() async throws {
    let validToken = "test-valid-token"
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: validToken
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: validToken),
      auth: nil,
      urlSession: urlSession
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Gives the model an `AppCheckInteropFake` configured with an error, passes
  /// `AppCheckInteropFake.placeholderTokenValue` to `httpRequestHandler`, and expects
  /// `generateContent` not to throw.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Gives the model an `AuthInteropFake` with a token, passes the same token to
  /// `httpRequestHandler`, and expects `generateContent` not to throw.
  func testGenerateContent_auth_validAuthToken() async throws {
    let validToken = "test-valid-token"
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: validToken
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: AuthInteropFake(token: validToken),
      urlSession: urlSession
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Gives the model an `AuthInteropFake` whose token is `nil`, passes `authToken: nil` to
  /// `httpRequestHandler`, and expects `generateContent` not to throw.
  func testGenerateContent_auth_nilAuthToken() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )
    let tokenlessAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: tokenlessAuth,
      urlSession: urlSession
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// Gives the model an `AuthInteropFake` configured with an error and expects
  /// `generateContent` to throw `GenerateContentError.internalError` wrapping `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: failingAuth,
      urlSession: urlSession
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected outcome; nothing further to verify.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Uses the `unary-success-basic-reply-short` JSON resource and checks the decoded token
  /// counts in `usageMetadata`.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(reply.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
  }
  /// Uses the `unary-failure-api-key` JSON resource with HTTP status 400 and expects
  /// `GenerateContentError.internalError` wrapping an `RPCError` with status
  /// `.invalidArgument` and the "API key not valid" message.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as RPCError) {
      // Compare against the same constant that was handed to the request handler so the two
      // values cannot drift apart.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// Uses the `unary-failure-firebasevertexai-api-not-enabled` JSON resource with HTTP status
  /// 403 and expects an `RPCError` with status `.permissionDenied` for which
  /// `isVertexAIInFirebaseServiceDisabledError()` returns true.
  func testGenerateContent_failure_firebaseVertexAIAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      let hasExpectedPrefix = rpcError.message
        .starts(with: "Vertex AI in Firebase API has not been used in project")
      XCTAssertTrue(hasExpectedPrefix)
      XCTAssertTrue(rpcError.isVertexAIInFirebaseServiceDisabledError())
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// Uses the `unary-failure-empty-content` JSON resource and expects
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.emptyContent`, whose
  /// payload is a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      guard case .emptyContent(let wrappedError) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(
        wrappedError as? DecodingError,
        "Not a DecodingError: \(wrappedError)"
      )
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Uses the `unary-failure-finish-reason-safety` JSON resource and expects
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety` and the redacted text.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Uses the `unary-failure-finish-reason-safety-no-content` JSON resource and expects
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety` and no response text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Uses the `unary-failure-image-rejected` JSON resource with HTTP status 400 and expects
  /// `GenerateContentError.internalError` wrapping an `RPCError` with status
  /// `.invalidArgument`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        // Same constant as the assertion below, so the two values cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Uses the `unary-failure-prompt-blocked-safety` JSON resource and expects
  /// `GenerateContentError.promptBlocked` whose response has no text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Uses the `unary-failure-unknown-enum-finish-reason` JSON resource and expects
  /// `GenerateContentError.responseStoppedEarly` with reason `.unknown` and the response text
  /// still available.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .unknown)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// Uses the `unary-failure-unknown-enum-prompt-blocked` JSON resource and expects
  /// `GenerateContentError.promptBlocked` whose prompt feedback has block reason `.unknown`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, .unknown)
    } catch {
      // Report the unexpected error so a failure can be diagnosed from the log alone.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// Uses the `unary-failure-unknown-model` JSON resource with HTTP status 404 and expects
  /// `GenerateContentError.internalError` wrapping an `RPCError` with status `.notFound`.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        // Same constant as the assertion below, so the two values cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// Installs the handler from `nonHTTPRequestHandler()` and expects `generateContent` to
  /// throw `GenerateContentError.internalError` whose underlying error describes a non-HTTP
  /// response.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var thrownError: Error?
    var result: GenerateContentResponse?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
  }
  /// Uses the `unary-failure-invalid-response` JSON resource and expects
  /// `GenerateContentError.internalError` wrapping `DecodingError.dataCorrupted` with the
  /// expected debug description prefix.
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var thrownError: Error?
    var result: GenerateContentResponse?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let decodeFailure = try XCTUnwrap(underlying as? DecodingError)
    guard case .dataCorrupted(let corruptionContext) = decodeFailure else {
      XCTFail("Not a data corrupted error: \(decodeFailure)")
      return
    }
    XCTAssert(
      corruptionContext.debugDescription.hasPrefix("Failed to decode GenerateContentResponse")
    )
  }
  /// Uses the `unary-failure-malformed-content` JSON resource and expects
  /// `GenerateContentError.internalError` wrapping `InvalidCandidateError.malformedContent`,
  /// whose payload is a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var thrownError: Error?
    var result: GenerateContentResponse?
    do {
      result = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(result)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case .malformedContent(let wrappedError) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(
      wrappedError as? DecodingError,
      "Not a decoding error: \(wrappedError)"
    )
  }
  /// Uses the `unary-success-missing-safety-ratings` JSON resource and expects prompt feedback
  /// with zero safety ratings alongside the generated text.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-missing-safety-ratings",
        withExtension: "json"
      )

    let reply = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(reply.text, "This is the generated content.")
  }
  /// Creates the model with `RequestOptions(timeout: 150.0)`, passes the same timeout to
  /// `httpRequestHandler`, and expects a response with one candidate.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
  }
  667. // MARK: - Generate Content (Streaming)
  /// Streams against the `unary-failure-api-key.json` fixture and expects an `internalError`
  /// wrapping an `RPCError` that reports an invalid API key.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 400)
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.message, "API key not valid. Please pass a valid API key.")
    }
  }
  /// Streams against the `unary-failure-firebasevertexai-api-not-enabled.json` fixture served
  /// with HTTP 403 and expects a permission-denied `RPCError` that is recognised as the
  /// "service disabled" error.
  func testGenerateContentStream_failure_vertexAIInFirebaseAPINotEnabled() async throws {
    let forbiddenStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebasevertexai-api-not-enabled",
      withExtension: "json",
      statusCode: forbiddenStatusCode
    )

    do {
      let stream = try model.generateContentStream(testPrompt)
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, forbiddenStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      let expectedPrefix = "Vertex AI in Firebase API has not been used in project"
      XCTAssertTrue(rpcError.message.starts(with: expectedPrefix))
      XCTAssertTrue(rpcError.isVertexAIInFirebaseServiceDisabledError())
    }
  }
  /// Streams the `streaming-failure-empty-content.txt` fixture and expects an `internalError`
  /// wrapping an `InvalidCandidateError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // Matching the pattern is the whole assertion; there is nothing further to inspect.
    }
  }
  /// Streams the `streaming-failure-finish-reason-safety.txt` fixture and expects
  /// `responseStoppedEarly` with the `.safety` reason.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(finishReason, _) {
      XCTAssertEqual(finishReason, .safety)
    }
  }
  /// Streams the `streaming-failure-prompt-blocked-safety.txt` fixture and expects
  /// `promptBlocked` whose prompt feedback carries the `.safety` block reason.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )

    do {
      let stream = try model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      XCTAssertEqual(blockedResponse.promptFeedback?.blockReason, .safety)
    }
  }
  /// Streams the `streaming-failure-unknown-finish-enum.txt` fixture and expects
  /// `responseStoppedEarly` with the `.unknown` reason; any chunks yielded first must have text.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )

    // Created outside the `do` so an error thrown here propagates instead of being matched.
    let stream = try model.generateContentStream("Hi")
    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(finishReason, _) {
      XCTAssertEqual(finishReason, .unknown)
    }
  }
  /// Streams the `streaming-success-basic-reply-long.txt` fixture and expects six chunks, each
  /// with non-nil text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    var chunkCount = 0
    let stream = try model.generateContentStream("Hi")
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 6)
  }
  /// Streams the `streaming-success-basic-reply-short.txt` fixture and expects exactly one
  /// chunk with non-nil text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var chunkCount = 0
    let stream = try model.generateContentStream("Hi")
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  /// Streams the `streaming-success-unknown-safety-enum.txt` fixture and expects at least one
  /// chunk whose first candidate has a safety rating in the `.unknown` category.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )

    var sawUnknownCategory = false
    let stream = try model.generateContentStream("Hi")
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      guard let ratings = chunk.candidates.first?.safetyRatings else { continue }
      if ratings.contains(where: { $0.category == .unknown }) {
        sawUnknownCategory = true
      }
    }

    XCTAssertTrue(sawUnknownCategory)
  }
  /// Streams the `streaming-success-citations.txt` fixture, gathers the citations from the first
  /// candidate of every chunk, and checks the finish reason, the citation count, three specific
  /// citations, and that no citation has an empty (as opposed to nil) string field.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )

    let stream = try model.generateContentStream("Hi")
    var collectedCitations = [Citation]()
    var chunks = [GenerateContentResponse]()
    for try await chunk in stream {
      chunks.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      if let chunkCitations = firstCandidate.citationMetadata?.citations {
        collectedCitations.append(contentsOf: chunkCitations)
      }
    }

    let finalCandidate = try XCTUnwrap(chunks.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 6)

    // A citation with only a URI.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil
    })
    // A citation with only a title.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
    })
    // A citation with a URI and a license.
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit"
    })

    // String fields must be nil or non-empty, never "".
    XCTAssertFalse(collectedCitations.contains { $0.uri?.isEmpty ?? false })
    XCTAssertFalse(collectedCitations.contains { $0.title?.isEmpty ?? false })
    XCTAssertFalse(collectedCitations.contains { $0.license?.isEmpty ?? false })
  }
  /// Rebuilds the model with an App Check fake that returns a valid token and checks, via the
  /// request handler's header assertion, that the streaming request carries that token.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let validToken = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: validToken)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: appCheckFake,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: validToken
    )

    // Drain the stream; the assertions of interest run inside the request handler.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Rebuilds the model with an App Check fake that reports an error and checks, via the request
  /// handler's header assertion, that the streaming request carries the fake's placeholder token.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream; the assertions of interest run inside the request handler.
    let stream = try model.generateContentStream(testPrompt)
    for try await _ in stream {}
  }
  /// Streams the `streaming-success-basic-reply-short.txt` fixture and expects usage metadata on
  /// the final streamed response only.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var streamedResponses = [GenerateContentResponse]()
    let stream = try model.generateContentStream(testPrompt)
    for try await chunk in stream {
      streamedResponses.append(chunk)
    }

    // Only the last streamed response contains usage metadata
    for earlierResponse in streamedResponses.dropLast() {
      XCTAssertNil(earlierResponse.usageMetadata)
    }
    if let finalResponse = streamedResponses.last {
      let usage = try XCTUnwrap(finalResponse.usageMetadata)
      XCTAssertEqual(usage.promptTokenCount, 6)
      XCTAssertEqual(usage.candidatesTokenCount, 4)
      XCTAssertEqual(usage.totalTokenCount, 10)
    }
  }
  /// Streams the `streaming-failure-error-mid-stream.txt` fixture and expects two content chunks
  /// followed by an `internalError` wrapping a cancelled (499) `RPCError`.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )

    var chunksBeforeError = 0
    do {
      let stream = try model.generateContentStream("Hi")
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
        chunksBeforeError += 1
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Expected an internalError with an RPCError.")
    } catch let GenerateContentError.internalError(streamError as RPCError) {
      XCTAssertEqual(streamError.httpResponseCode, 499)
      XCTAssertEqual(streamError.status, .cancelled)

      // Check the content count is correct.
      XCTAssertEqual(chunksBeforeError, 2)
    }
  }
  /// Serves a plain `URLResponse` (not an `HTTPURLResponse`) and expects the stream to fail with
  /// an `internalError` whose description says the response was not an HTTP response.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    // Created outside the `do` so an error thrown here propagates instead of being matched.
    let stream = try model.generateContentStream("Hi")
    do {
      for try await chunk in stream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(cause) {
      XCTAssertEqual(cause.localizedDescription, "Response was not an HTTP response.")
    }
  }
  /// Streams the `streaming-failure-invalid-json.txt` fixture and expects an `internalError`
  /// wrapping a `DecodingError.dataCorrupted` whose description names `GenerateContentResponse`.
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )

    // Created outside the `do` so an error thrown here propagates instead of being matched.
    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await chunk in stream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(context) = decodingError {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
    }
  }
  /// Streams the `streaming-failure-malformed-content.txt` fixture and expects an `internalError`
  /// wrapping `InvalidCandidateError.malformedContent` caused by a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )

    // Created outside the `do` so an error thrown here propagates instead of being matched.
    let stream = try model.generateContentStream(testPrompt)
    do {
      for try await chunk in stream {
        XCTFail("Unexpected content in stream: \(chunk)")
      }
      // Only reached when the stream finished without throwing.
      XCTFail("Expected an internal decoding error.")
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      if case let .malformedContent(rootCause) = candidateError {
        XCTAssert(rootCause is DecodingError)
      } else {
        XCTFail("Not a malformed content error: \(candidateError)")
      }
    }
  }
  /// Rebuilds the model with a custom `RequestOptions` timeout and checks, via the request
  /// handler's timeout assertion, that the streaming request carries it; one chunk is expected.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    var chunkCount = 0
    let stream = try model.generateContentStream(testPrompt)
    for try await chunk in stream {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  1031. // MARK: - Count Tokens
  /// Serves the `unary-success-total-tokens.json` fixture and checks the decoded token and
  /// billable-character totals.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Serves the `unary-success-no-billable-characters.json` fixture for an inline-data request
  /// and expects a token total with no billable-character count.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json"
    )
    let emptyImagePart = ModelContent.Part.inlineData(mimetype: "image/jpeg", Data())

    let tokenCount = try await model.countTokens(emptyImagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertNil(tokenCount.totalBillableCharacters)
  }
  /// Serves the `unary-failure-model-not-found.json` fixture with HTTP 404 and expects
  /// `countTokens` to throw an `RPCError` with the `.notFound` status.
  ///
  /// Errors of any other type are not caught here; they propagate and fail the test through
  /// its `throws` signature.
  func testCountTokens_modelNotFound() async throws {
    let expectedStatusCode = 404
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found", withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      // Single failure point for an unexpected success. A second `XCTFail` after the
      // `do`/`catch` would be reachable only on this same path and would report the
      // same problem twice.
      XCTFail("Request should not have succeeded.")
    } catch let rpcError as RPCError {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
    }
  }
  /// Rebuilds the model with a custom `RequestOptions` timeout and checks, via the request
  /// handler's timeout assertion, that the `countTokens` request carries it.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1091. // MARK: - Helpers
  /// Builds a mock request handler that answers every request with a plain `URLResponse`
  /// (deliberately not an `HTTPURLResponse`) and no body lines.
  ///
  /// - Throws: `XCTSkip` on watchOS.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    return { incomingRequest in
      // This is *not* an HTTPURLResponse
      let plainResponse = URLResponse(
        url: incomingRequest.url!,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (plainResponse, nil)
    }
  }
  /// Builds a mock request handler that validates the outgoing request and replies with the
  /// lines of a bundled fixture file.
  ///
  /// The returned closure asserts that the request path contains `models/` exactly once, that
  /// the timeout equals `timeout`, that the `x-goog-api-client` header includes the language and
  /// Firebase version tags, and that the App Check and Authorization headers match the
  /// expected tokens.
  ///
  /// - Parameters:
  ///   - name: Resource name of the fixture file in the test bundle.
  ///   - ext: File extension of the fixture file.
  ///   - statusCode: HTTP status code for the mocked response.
  ///   - timeout: Timeout the request is expected to carry.
  ///   - appCheckToken: Expected `X-Firebase-AppCheck` header value, or `nil` if absent.
  ///   - authToken: Expected auth token; when `nil`, no `Authorization` header is expected.
  /// - Throws: `XCTSkip` on watchOS, or an unwrap failure when the fixture is not in the bundle.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = RequestOptions().timeout,
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
    // https://developer.apple.com/documentation/foundation/urlprotocol for details.
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    #if SWIFT_PACKAGE
      let resourceBundle = Bundle.module
    #else // SWIFT_PACKAGE
      let resourceBundle = Bundle(for: Self.self)
    #endif // SWIFT_PACKAGE
    let fixtureURL = try XCTUnwrap(resourceBundle.url(forResource: name, withExtension: ext))

    return { request in
      let url = try XCTUnwrap(request.url)
      XCTAssertEqual(url.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let clientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let clientTags = clientHeader.components(separatedBy: " ")
      XCTAssert(clientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(clientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)

      let authorizationHeader = request.value(forHTTPHeaderField: "Authorization")
      if let authToken {
        XCTAssertEqual(authorizationHeader, "Firebase \(authToken)")
      } else {
        XCTAssertNil(authorizationHeader)
      }

      let httpResponse = try XCTUnwrap(HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (httpResponse, fixtureURL.lines)
    }
  }
  1155. }
  private extension String {
    /// Counts how many times `substring` appears in the receiver.
    ///
    /// The receiver is split on `substring`; the number of separators is one less than the
    /// number of resulting pieces.
    func occurrenceCount(of substring: String) -> Int {
      let pieces = components(separatedBy: substring)
      return pieces.count - 1
    }
  }
  /// Test double for `AppCheckInterop` that hands back a fixed token and, optionally, an error.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// The placeholder token value returned when an error occurs
    static let placeholderTokenValue = "placeholder-token"

    /// Token reported by every `getToken(forcingRefresh:)` call.
    var token: String
    /// Error reported alongside the token, if any.
    var error: Error?

    // MARK: Initializers

    /// Designated initializer; callers use one of the convenience initializers below.
    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
    }

    /// Creates a fake that reports `token` with no error.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake that reports `error` together with the placeholder token.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    // MARK: AppCheckInterop

    /// Returns a result built from the stored token and error; `forcingRefresh` is not used.
    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      let result = AppCheckTokenResultInteropFake(token: token, error: error)
      return result
    }

    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    // MARK: Token result

    /// Plain holder for the token/error pair returned from `getToken(forcingRefresh:)`.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
      }
    }
  }
  /// Empty error type used to construct an `AppCheckInteropFake` in its failing configuration.
  struct AppCheckErrorFake: Error {
  }
  /// Test-only ordering for `SafetyRating`: ratings are ordered by the raw value of their
  /// category and nothing else.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Swift.Comparable {
    public static func < (lhs: FirebaseVertexAI.SafetyRating,
                          rhs: FirebaseVertexAI.SafetyRating) -> Bool {
      let leftKey = lhs.category.rawValue
      let rightKey = rhs.category.rawValue
      return leftKey < rightKey
    }
  }