GenerativeModelTests.swift 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Prompt text passed to `generateContent` by most tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"
  /// Expected safety ratings (already sorted) for responses where every category is negligible.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(category: .sexuallyExplicit, probability: .negligible),
    SafetyRating(category: .hateSpeech, probability: .negligible),
    SafetyRating(category: .harassment, probability: .negligible),
    SafetyRating(category: .dangerousContent, probability: .negligible),
  ].sorted()
  /// Fully qualified model resource name used when building `model`.
  let testModelResourceName =
    "projects/test-project-id/locations/test-location/publishers/google/models/test-model"

  /// Session configured with `MockURLProtocol`; assigned in `setUp()`.
  var urlSession: URLSession!
  /// Model under test; assigned in `setUp()` and replaced by some individual tests.
  var model: GenerativeModel!
  /// Builds a `URLSession` whose only protocol class is `MockURLProtocol` and a default
  /// `GenerativeModel` that sends its requests through that session.
  override func setUp() async throws {
    try await super.setUp()

    let configuration = URLSessionConfiguration.default
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is non-failable, so there is nothing to unwrap here.
    urlSession = URLSession(configuration: configuration)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )
  }
  /// Clears the type-level mock request handler so one test's stub cannot leak into the next.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    super.tearDown()
  }
  50. // MARK: - Generate Content
  /// A long single-candidate text reply decodes with a `.stop` finish reason.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    // One candidate that stopped normally with negligible safety ratings.
    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let stopReason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    // Its single text part is also exposed through `response.text`.
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let onlyPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let replyText = try XCTUnwrap(onlyPart.text)
    XCTAssertTrue(replyText.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    XCTAssertEqual(response.text, replyText)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// A short single-candidate text reply decodes with a `.stop` finish reason.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let stopReason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)

    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let onlyPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    XCTAssertEqual(onlyPart.text, "Mountain View, California")
    XCTAssertEqual(response.text, onlyPart.text)
    XCTAssertEqual(response.functionCalls, [])
  }
  /// Citation metadata decodes with per-source URI, title, license and index ranges.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-citations",
        withExtension: "json"
      )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citationSources = citationMetadata.citationSources
    XCTAssertEqual(citationSources.count, 3)
    // Array subscripting traps on out-of-bounds access (wrapping it in `XCTUnwrap` does not
    // help), so stop here instead of crashing the whole test run on a short array.
    guard citationSources.count == 3 else { return }

    let citationSource1 = citationSources[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.title)
    XCTAssertNil(citationSource1.license)

    let citationSource2 = citationSources[1]
    XCTAssertEqual(citationSource2.title, "some-citation-2")
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.uri)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citationSources[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
    XCTAssertNil(citationSource3.title)
  }
  /// A quote-style reply decodes, including non-blocking prompt feedback.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    let stopReason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(stopReason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let onlyPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let quotedText = try XCTUnwrap(onlyPart.text)
    XCTAssertTrue(quotedText.hasPrefix("Google"))
    XCTAssertEqual(response.text, onlyPart.text)

    // Prompt feedback is present but carries no block reason.
    let feedback = try XCTUnwrap(response.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.sorted(), safetyRatingsNegligible)
  }
  /// Safety ratings with values the fixture presumably uses to represent unrecognized enums
  /// are expected to decode to `.unknown`, in both candidates and prompt feedback.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    let expectedRatings: [SafetyRating] = [
      SafetyRating(category: .harassment, probability: .medium),
      SafetyRating(category: .dangerousContent, probability: .unknown),
      SafetyRating(category: .unknown, probability: .high),
    ]
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.text, "Some text")
    XCTAssertEqual(response.candidates.first?.safetyRatings, expectedRatings)
    XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedRatings)
  }
  /// A model name that already starts with "models/" can still generate content; this test
  /// only checks that the call completes without throwing.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    let prefixedModel = GenerativeModel(
      name: "models/test-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// A function call with an empty arguments object decodes with empty `args`.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let onlyPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    if case let .functionCall(call) = onlyPart {
      XCTAssertEqual(call.name, "current_time")
      XCTAssertTrue(call.args.isEmpty)
      XCTAssertEqual(response.functionCalls, [call])
    } else {
      XCTFail("Part is not a FunctionCall.")
    }
  }
  /// A function call with no arguments field decodes with empty `args`.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let onlyPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    guard case let .functionCall(call) = onlyPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }

    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// A function call with numeric arguments exposes them through `args`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let onlyPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    guard case let .functionCall(call) = onlyPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }

    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let xValue = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(xValue, .number(4))
    let yValue = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(yValue, .number(5))
    XCTAssertEqual(response.functionCalls, [call])
  }
  /// Several function calls in one candidate are all surfaced by `functionCalls`.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 3)
    XCTAssertEqual(response.functionCalls.count, 3)
  }
  /// A candidate mixing text and function-call parts exposes both kinds separately.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 4)
    XCTAssertEqual(response.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(response.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// A model configured with App Check sends the token that the fake provider returns.
  func testGenerateContent_appCheck_validToken() async throws {
    let validToken = "test-valid-token"
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: validToken),
      auth: nil,
      urlSession: urlSession
    )
    // The handler is given the token the request is expected to carry.
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: validToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When the App Check fake reports an error, the request still succeeds. The handler is
  /// given `AppCheckInteropFake.placeholderTokenValue` as the expected token (presumably
  /// checked by `httpRequestHandler`).
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// A model configured with Auth sends the token that the fake provider returns.
  func testGenerateContent_auth_validAuthToken() async throws {
    let validAuthToken = "test-valid-token"
    let authProvider = AuthInteropFake(token: validAuthToken)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: authProvider,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: validAuthToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An Auth provider that yields no token still lets the request succeed; the handler is
  /// told to expect a `nil` auth token.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let tokenlessAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: tokenlessAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An Auth token refresh failure surfaces as `GenerateContentError.internalError`
  /// wrapping the provider's `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    model = GenerativeModel(
      // Same resource name as every other test, so only the auth configuration differs.
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: AuthInteropFake(error: AuthErrorFake()),
      urlSession: urlSession
    )
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-basic-reply-short",
        withExtension: "json",
        authToken: nil
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error thrown.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected: the auth provider's error is wrapped in `internalError`.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Token usage counts in the response are decoded into `usageMetadata`.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let response = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(response.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
  }
  /// An invalid API key response surfaces as an `RPCError` wrapped in
  /// `GenerateContentError.internalError`, carrying the HTTP status and server message.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as RPCError) {
      // Compare against the same constant handed to the mock, not a duplicated literal.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A 403 "Firebase ML API not enabled" response surfaces as an `RPCError` that reports
  /// `isFirebaseMLServiceDisabledError()`.
  func testGenerateContent_failure_firebaseMLAPINotEnabled() async throws {
    let forbiddenStatus = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: forbiddenStatus
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, forbiddenStatus)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A candidate with empty content fails with `InvalidCandidateError.emptyContent`
  /// wrapping a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      if case let .emptyContent(underlyingError) = candidateError {
        _ = try XCTUnwrap(
          underlyingError as? DecodingError,
          "Not a DecodingError: \(underlyingError)"
        )
      } else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
      }
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A candidate stopped for safety throws `responseStoppedEarly(.safety, response)`;
  /// the partial response's text is still available.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Include the unexpected error so a failure here is diagnosable.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A safety stop with no content throws `responseStoppedEarly(.safety, response)`
  /// whose `text` is `nil`.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Include the unexpected error so a failure here is diagnosable.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A rejected image surfaces as an `.invalidArgument` `RPCError` with the mocked status.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        // Use the same constant the assertion checks, so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A prompt blocked for safety throws `promptBlocked(response)` with no response text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
    } catch {
      // Include the unexpected error so a failure here is diagnosable.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// A finish reason the test expects to decode as `.unknown` throws
  /// `responseStoppedEarly(.unknown, response)`.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .unknown)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Include the unexpected error so a failure here is diagnosable.
      XCTFail("Should throw a responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A block reason the test expects to decode as `.unknown` throws `promptBlocked`
  /// carrying that block reason in its prompt feedback.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, .unknown)
    } catch {
      // Include the unexpected error so a failure here is diagnosable.
      XCTFail("Should throw a promptBlocked; error thrown: \(error)")
    }
  }
  /// An unknown model surfaces as a `.notFound` `RPCError` with the mocked status.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        // Use the same constant the assertion checks, so the two cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A non-HTTP URL response is reported as an `internalError` with a descriptive message.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    // Record both possible outcomes so each can be asserted on independently.
    var thrownError: Error?
    var generatedResponse: GenerateContentResponse?
    do {
      generatedResponse = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generatedResponse)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
  }
  /// An undecodable response body is reported as an `internalError` wrapping a
  /// `DecodingError.dataCorrupted`.
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var thrownError: Error?
    var generatedResponse: GenerateContentResponse?
    do {
      generatedResponse = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generatedResponse)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let decodingFailure = try XCTUnwrap(underlying as? DecodingError)
    guard case let .dataCorrupted(corruptionContext) = decodingFailure else {
      XCTFail("Not a data corrupted error: \(decodingFailure)")
      return
    }
    XCTAssertTrue(
      corruptionContext.debugDescription.hasPrefix("Failed to decode GenerateContentResponse")
    )
  }
  /// Malformed candidate content is reported as `InvalidCandidateError.malformedContent`
  /// wrapping a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var thrownError: Error?
    var generatedResponse: GenerateContentResponse?
    do {
      generatedResponse = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generatedResponse)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case let .malformedContent(causeError) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(causeError as? DecodingError, "Not a decoding error: \(causeError)")
  }
  /// A response whose fixture (by its name) lacks safety ratings still decodes, with an
  /// empty prompt-feedback rating list.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )

    let generated = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(generated.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(generated.text, "This is the generated content.")
  }
  /// A custom `RequestOptions.timeout` is passed to the handler as the expected timeout,
  /// and the request still succeeds.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let customTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: customTimeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: customTimeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let generated = try await model.generateContent(testPrompt)

    XCTAssertEqual(generated.candidates.count, 1)
  }
  666. // MARK: - Generate Content (Streaming)
  /// An invalid-API-key error body surfaces from the stream as an `RPCError` carrying the
  /// server's code, status and message. No responses are yielded.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )

    do {
      let stream = try await model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this line means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 400)
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.message, "API key not valid. Please pass a valid API key.")
    }
  }
  /// A 403 "Firebase ML API not enabled" reply surfaces from the stream as an `RPCError` that
  /// `isFirebaseMLServiceDisabledError()` recognises.
  func testGenerateContentStream_failure_firebaseMLAPINotEnabled() async throws {
    let statusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: statusCode
    )

    do {
      let stream = try await model.generateContentStream(testPrompt)
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this line means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, statusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
    }
  }
  /// A streamed candidate with empty content fails with an `InvalidCandidateError` wrapped in
  /// `GenerateContentError.internalError`. No responses are yielded.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )

    do {
      let stream = try await model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
      // Reaching this line means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch GenerateContentError.internalError(_ as InvalidCandidateError) {
      // Matching the wrapped error type is the entire check.
    }
  }
  /// A stream whose candidate finishes for safety reasons throws
  /// `GenerateContentError.responseStoppedEarly` with reason `.safety` before yielding anything.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )

    do {
      let stream = try await model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this line means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(stopReason, _) {
      XCTAssertEqual(stopReason, .safety)
    }
  }
  /// A prompt blocked for safety makes the stream throw `GenerateContentError.promptBlocked`.
  /// The attached response's prompt feedback reports block reason `.safety`.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )

    do {
      let stream = try await model.generateContentStream("Hi")
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
      // Reaching this line means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      XCTAssertEqual(blockedResponse.promptFeedback?.blockReason, .safety)
    }
  }
  /// An unrecognised finish reason decodes as `.unknown`. It ends the stream with
  /// `GenerateContentError.responseStoppedEarly`; any responses yielded before that carry text.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )
    let stream = try await model.generateContentStream("Hi")

    do {
      for try await response in stream {
        XCTAssertNotNil(response.text)
      }
      // Reaching this line means the stream finished without throwing.
      XCTFail("Should have caught an error.")
    } catch let GenerateContentError.responseStoppedEarly(stopReason, _) {
      XCTAssertEqual(stopReason, .unknown)
    }
  }
  /// The long streaming fixture yields six responses, each with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    let stream = try await model.generateContentStream("Hi")
    // Count the responses while checking each one has text.
    let responseCount = try await stream.reduce(0) { count, response in
      XCTAssertNotNil(response.text)
      return count + 1
    }

    XCTAssertEqual(responseCount, 6)
  }
  /// The short streaming fixture yields exactly one response, and it has text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    let stream = try await model.generateContentStream("Hi")
    // Count the responses while checking each one has text.
    let responseCount = try await stream.reduce(0) { count, response in
      XCTAssertNotNil(response.text)
      return count + 1
    }

    XCTAssertEqual(responseCount, 1)
  }
  /// Unrecognised safety categories in a stream decode as `.unknown` rather than failing. At
  /// least one streamed candidate must contain such a rating.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )

    var sawUnknownCategory = false
    let stream = try await model.generateContentStream("Hi")
    for try await response in stream {
      XCTAssertNotNil(response.text)
      let ratings = response.candidates.first?.safetyRatings ?? []
      if ratings.contains(where: { $0.category == .unknown }) {
        sawUnknownCategory = true
      }
    }

    XCTAssertTrue(sawUnknownCategory)
  }
  /// Citation sources spread across streamed chunks are all decoded. Empty citation fields in
  /// the payload come through as `nil`, never as empty strings.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )
    let stream = try await model.generateContentStream("Hi")

    var allCitations = [Citation]()
    var collected = [GenerateContentResponse]()
    for try await response in stream {
      collected.append(response)
      XCTAssertNotNil(response.text)
      let firstCandidate = try XCTUnwrap(response.candidates.first)
      allCitations += firstCandidate.citationMetadata?.citationSources ?? []
    }

    let finalCandidate = try XCTUnwrap(collected.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(allCitations.count, 6)

    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 0 && citation.endIndex == 128
        && citation.uri == "https://www.example.com/some-citation-1" && citation.title == nil
        && citation.license == nil
    })
    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 130 && citation.endIndex == 265 && citation.uri == nil
        && citation.title == "some-citation-2" && citation.license == nil
    })
    XCTAssertTrue(allCitations.contains { citation in
      citation.startIndex == 272 && citation.endIndex == 431
        && citation.uri == "https://www.example.com/some-citation-3" && citation.title == nil
        && citation.license == "mit"
    })

    // Optional fields must be `nil` when absent, never an empty string.
    XCTAssertFalse(allCitations.contains { $0.uri == "" })
    XCTAssertFalse(allCitations.contains { $0.title == "" })
    XCTAssertFalse(allCitations.contains { $0.license == "" })
  }
  /// A valid App Check token is sent in the `X-Firebase-AppCheck` header of streaming requests.
  /// The header check itself lives in the request handler.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let token = "test-valid-token"
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: token),
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: token
    )

    // Drain the stream so the request is actually issued and the handler's assertions run.
    let stream = try await model.generateContentStream(testPrompt)
    var iterator = stream.makeAsyncIterator()
    while try await iterator.next() != nil {}
  }
  /// When the App Check fake reports an error, the streaming request is still sent. It carries
  /// the fake's placeholder token in the `X-Firebase-AppCheck` header, which the request handler
  /// asserts.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream so the request is actually issued and the handler's assertions run.
    let stream = try await model.generateContentStream(testPrompt)
    var iterator = stream.makeAsyncIterator()
    while try await iterator.next() != nil {}
  }
  /// Only the final streamed response carries usage metadata, with the expected token counts.
  ///
  /// Fails if the stream yields no responses at all. Previously an empty stream skipped every
  /// assertion and passed vacuously.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )
    var responses = [GenerateContentResponse]()

    let stream = try await model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // Require at least one response so the metadata checks below always run.
    let lastResponse = try XCTUnwrap(responses.last, "Expected at least one streamed response.")
    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }
    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// An RPC error arriving part-way through a stream is thrown as an `RPCError`. The two
  /// responses streamed before the error must already have been delivered.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )

    var deliveredCount = 0
    do {
      let stream = try await model.generateContentStream("Hi")
      for try await response in stream {
        XCTAssertNotNil(response.text)
        deliveredCount += 1
      }
      // Reaching this line means the stream finished without throwing.
      XCTFail("Expected an internalError with an RPCError.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 499)
      XCTAssertEqual(rpcError.status, .cancelled)
      // Responses received before the error must still have been yielded.
      XCTAssertEqual(deliveredCount, 2)
    }
  }
  /// A transport response that is not an `HTTPURLResponse` ends the stream with an internal
  /// error describing exactly that.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
    let stream = try await model.generateContentStream("Hi")

    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
      // Reaching this line means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(underlying) {
      XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
    }
  }
  /// Invalid JSON in a stream ends it with an internal `DecodingError.dataCorrupted` whose
  /// description names `GenerateContentResponse`.
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )
    let stream = try await model.generateContentStream(testPrompt)

    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
      // Reaching this line means the stream finished without throwing.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(context) = decodingError {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
    }
  }
  /// A streamed candidate with malformed content ends the stream with
  /// `InvalidCandidateError.malformedContent` wrapped in `GenerateContentError.internalError`.
  /// The innermost error must be a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )
    let stream = try await model.generateContentStream(testPrompt)

    do {
      for try await content in stream {
        XCTFail("Unexpected content in stream: \(content)")
      }
    } catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) {
      guard case let .malformedContent(contentError) = underlyingError else {
        XCTFail("Not a malformed content error: \(underlyingError)")
        return
      }
      // Include the actual error in the failure message; a bare `is` check reports nothing useful.
      XCTAssert(contentError is DecodingError, "Not a decoding error: \(contentError)")
      return
    }
    XCTFail("Expected an internal decoding error.")
  }
  /// A custom `RequestOptions.timeout` is applied to streaming requests. The request handler
  /// asserts the timeout, and the short fixture still yields one response with text.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let timeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: timeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: timeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let stream = try await model.generateContentStream(testPrompt)
    // Count the responses while checking each one has text.
    let responseCount = try await stream.reduce(0) { count, response in
      XCTAssertNotNil(response.text)
      return count + 1
    }

    XCTAssertEqual(responseCount, 1)
  }
  1029. // MARK: - Count Tokens
  /// `countTokens` decodes both the total token count and the billable character count.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens", withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Counting tokens for an image-only input reports a token count and zero billable characters.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters", withExtension: "json"
    )
    let imagePart = ModelContent.Part.data(mimetype: "image/jpeg", Data())

    let tokenCount = try await model.countTokens(imagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 0)
  }
  /// A 404 "model not found" reply makes `countTokens` throw an internal `RPCError` with
  /// status `.notFound` and the server's message.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found", withExtension: "json",
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
    } catch let CountTokensError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 404)
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
      return
    }
    // Only reached when `countTokens` returned normally. Report that once. Previously an
    // in-`do` XCTFail fell through to a second, trailing XCTFail.
    XCTFail("Request should not have succeeded.")
  }
  /// A custom `RequestOptions.timeout` is applied to `countTokens` requests. The request handler
  /// asserts the timeout, and the token count still decodes.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let timeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      timeout: timeout
    )
    model = GenerativeModel(
      name: testModelResourceName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: timeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1089. // MARK: - Helpers
  /// Builds a `MockURLProtocol` handler that answers every request with a plain `URLResponse`
  /// (not an `HTTPURLResponse`) and no body lines.
  ///
  /// - Throws: `XCTSkip` on watchOS, where custom URL protocols cannot be registered.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // MockURLProtocol cannot run on watchOS: custom URL protocols are unsupported from watchOS 2
    // onward (https://developer.apple.com/documentation/foundation/urlprotocol).
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    return { request in
      let url = request.url!
      // Intentionally the base `URLResponse` class, so HTTP-specific handling cannot apply.
      let plainResponse = URLResponse(
        url: url,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (plainResponse, nil)
    }
  }
  /// Builds a `MockURLProtocol` handler that validates each outgoing request and replies with
  /// the given bundled fixture, streamed line by line.
  ///
  /// Each request is checked for:
  /// - exactly one `models/` segment in its path;
  /// - the expected `timeoutInterval`;
  /// - the SDK language and Firebase version tags in `x-goog-api-client`;
  /// - the expected `X-Firebase-AppCheck` value (absent when `appCheckToken` is `nil`);
  /// - an `Authorization: Firebase <token>` header only when `authToken` is given.
  ///
  /// - Parameters:
  ///   - name: Fixture resource name in `Bundle.module`.
  ///   - ext: Fixture file extension.
  ///   - statusCode: HTTP status code of the mocked response.
  ///   - timeout: Timeout the request is expected to carry.
  ///   - appCheckToken: Expected App Check header value.
  ///   - authToken: Expected Firebase Auth token, if any.
  /// - Throws: `XCTSkip` on watchOS, or an unwrap failure if the fixture is missing.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    // MockURLProtocol cannot run on watchOS: custom URL protocols are unsupported from watchOS 2
    // onward (https://developer.apple.com/documentation/foundation/urlprotocol).
    #if os(watchOS)
      throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
    #endif // os(watchOS)
    let fixtureURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext))
    return { request in
      let url = try XCTUnwrap(request.url)
      XCTAssertEqual(url.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let clientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let clientTags = clientHeader.components(separatedBy: " ")
      XCTAssert(clientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(clientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
      let authorization = request.value(forHTTPHeaderField: "Authorization")
      if let authToken {
        XCTAssertEqual(authorization, "Firebase \(authToken)")
      } else {
        XCTAssertNil(authorization)
      }

      let httpResponse = try XCTUnwrap(HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (httpResponse, fixtureURL.lines)
    }
  }
  1148. }
  private extension String {
    /// Counts the non-overlapping occurrences of `substring` in this string.
    ///
    /// Splitting on `substring` yields one more piece than there are separators, so the count
    /// is the number of pieces after the first.
    func occurrenceCount(of substring: String) -> Int {
      components(separatedBy: substring).dropFirst().count
    }
  }
  private extension URLRequest {
    /// The `timeoutInterval` a newly created `URLRequest` gets when none is specified explicitly.
    static func defaultTimeoutInterval() -> TimeInterval {
      let freshRequest = URLRequest(url: URL(string: "https://example.com")!)
      return freshRequest.timeoutInterval
    }
  }
  /// Test double for `AppCheckInterop`. Every `getToken` call returns the configured token and
  /// error; `forcingRefresh` is ignored.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// Token value reported by fakes created with `init(error:)`, alongside their error.
    static let placeholderTokenValue = "placeholder-token"

    var token: String
    var error: Error?

    /// Creates a fake whose token results carry `token` and no error.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake whose token results carry `error` and the placeholder token value.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
    }

    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      AppCheckTokenResultInteropFake(token: token, error: error)
    }

    // The notification-related requirements are not exercised by these tests; calling one traps.
    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    /// Plain value holder conforming to `FIRAppCheckTokenResultInterop`.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
      }
    }
  }
  /// Opaque error handed to `AppCheckInteropFake(error:)` to simulate a failed App Check token
  /// fetch.
  struct AppCheckErrorFake: Swift.Error {}
  /// Test-only ordering of `SafetyRating` values by the raw value of their `category`.
  ///
  /// NOTE(review): presumably used to sort ratings before comparing them in tests — confirm.
  /// Ratings that share a category are unordered by `<` even if `==` distinguishes them, so this
  /// is only a consistent total order if categories are unique within the compared collection.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension SafetyRating: Comparable {
    public static func < (lhs: FirebaseVertexAI.SafetyRating,
                          rhs: FirebaseVertexAI.SafetyRating) -> Bool {
      let (leftKey, rightKey) = (lhs.category.rawValue, rhs.category.rawValue)
      return leftKey < rightKey
    }
  }