GenerativeModelTests.swift 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseVertexAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
  20. final class GenerativeModelTests: XCTestCase {
  /// Session built in `setUp()` whose `protocolClasses` contain `MockURLProtocol`.
  var urlSession: URLSession!
  /// Default model under test; built in `setUp()` and reassigned by tests that need a
  /// different configuration (App Check, Auth, custom request options).
  var model: GenerativeModel!

  /// Prompt sent by most tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"
  /// Expected ratings for "clean" responses: the four harm categories below, each at
  /// `.negligible`. Stored sorted so tests can compare against `safetyRatings.sorted()`.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(category: .sexuallyExplicit, probability: .negligible),
    SafetyRating(category: .hateSpeech, probability: .negligible),
    SafetyRating(category: .harassment, probability: .negligible),
    SafetyRating(category: .dangerousContent, probability: .negligible),
  ].sorted()
  /// Builds, before every test, a `URLSession` whose configuration registers `MockURLProtocol`,
  /// plus a default `GenerativeModel` (no tools, App Check, or Auth) that uses that session.
  override func setUp() async throws {
    try await super.setUp()
    let configuration = URLSessionConfiguration.default
    // Register the stub protocol; individual tests install responses via
    // `MockURLProtocol.requestHandler`.
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is non-failable, so there is nothing to unwrap.
    urlSession = URLSession(configuration: configuration)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )
  }
  /// Clears the stubbed request handler so a response installed by one test cannot be
  /// served to the next, then runs the base-class teardown.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    super.tearDown()
  }
  48. // MARK: - Generate Content
  /// Decodes the long single-candidate fixture and checks its finish reason, safety ratings,
  /// and the lone text part that `text` should mirror.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let body = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(body.hasPrefix("1. **Use Freshly Ground Coffee**:"))
    // The convenience accessor must expose the same text as the single part.
    XCTAssertEqual(reply.text, body)
    // A plain text reply carries no function calls.
    XCTAssertEqual(reply.functionCalls, [])
  }
  /// Decodes the short single-candidate fixture whose only part is "Mountain View, California".
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    XCTAssertEqual(firstPart.text, "Mountain View, California")
    XCTAssertEqual(reply.text, firstPart.text)
    XCTAssertEqual(reply.functionCalls, [])
  }
  /// Verifies that `citationMetadata` is decoded from the `unary-success-citations` fixture:
  /// three sources with their URIs, start/end indices and optional license.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-success-citations",
        withExtension: "json"
      )

    let response = try await model.generateContent(testPrompt)

    XCTAssertEqual(response.candidates.count, 1)
    let candidate = try XCTUnwrap(response.candidates.first)
    XCTAssertEqual(candidate.content.parts.count, 1)
    XCTAssertEqual(response.text, "Some information cited from an external source")
    let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
    let citationSources = citationMetadata.citationSources
    XCTAssertEqual(citationSources.count, 3)
    // Subscripting past the end would trap and abort the entire test process; the count
    // failure above is already recorded, so stop here instead of crashing.
    guard citationSources.count == 3 else { return }

    let citationSource1 = citationSources[0]
    XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
    XCTAssertEqual(citationSource1.startIndex, 0)
    XCTAssertEqual(citationSource1.endIndex, 128)
    XCTAssertNil(citationSource1.license)

    let citationSource2 = citationSources[1]
    XCTAssertEqual(citationSource2.uri, "https://www.example.com/some-citation-2")
    XCTAssertEqual(citationSource2.startIndex, 130)
    XCTAssertEqual(citationSource2.endIndex, 265)
    XCTAssertNil(citationSource2.license)

    let citationSource3 = citationSources[2]
    XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
    XCTAssertEqual(citationSource3.startIndex, 272)
    XCTAssertEqual(citationSource3.endIndex, 431)
    XCTAssertEqual(citationSource3.license, "mit")
  }
  /// Decodes a reply starting with "Google" and checks both candidate and prompt feedback
  /// safety ratings, plus the absence of a prompt block reason.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    // Candidate checks.
    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    let reason = try XCTUnwrap(firstCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(firstCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    let body = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(body.hasPrefix("Google"))
    XCTAssertEqual(reply.text, firstPart.text)

    // Prompt feedback checks.
    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.sorted(), safetyRatingsNegligible)
  }
  /// Unrecognized safety category/probability values in the fixture must decode as `.unknown`
  /// rather than failing the whole response.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )
    let expected: [SafetyRating] = [
      SafetyRating(category: .harassment, probability: .medium),
      SafetyRating(category: .dangerousContent, probability: .unknown),
      SafetyRating(category: .unknown, probability: .high),
    ]

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.text, "Some text")
    XCTAssertEqual(reply.candidates.first?.safetyRatings, expected)
    XCTAssertEqual(reply.promptFeedback?.safetyRatings, expected)
  }
  /// A model name already prefixed with "models/" must still produce a successful request.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    // Deliberately pass the fully qualified form of the model name.
    let prefixedName = "models/test-model"
    let prefixedModel = GenerativeModel(
      name: prefixedName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// A function call whose fixture has empty arguments decodes to a call with no args.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// A function call whose fixture omits arguments decodes to a call with an empty args map.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// A `sum` function call decodes with numeric arguments `x = 4` and `y = 5`.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(firstCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let xValue = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(xValue, .number(4))
    let yValue = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(yValue, .number(5))
    XCTAssertEqual(reply.functionCalls, [call])
  }
  /// A single candidate with three parts exposes three function calls.
  func testGenerateContent_success_functionCall_parallelCalls() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-parallel-calls",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 3)
    XCTAssertEqual(reply.functionCalls.count, 3)
  }
  /// A candidate mixing text and function-call parts (four parts in total) exposes two
  /// function calls and a combined `text` value.
  func testGenerateContent_success_functionCall_mixedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-mixed-content",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
    let firstCandidate = try XCTUnwrap(reply.candidates.first)
    XCTAssertEqual(firstCandidate.content.parts.count, 4)
    XCTAssertEqual(reply.functionCalls.count, 2)
    let combinedText = try XCTUnwrap(reply.text)
    XCTAssertEqual(combinedText, "The sum of [1, 2, 3] is")
  }
  /// With App Check returning a token, the stub is configured to expect that same token on
  /// the request.
  func testGenerateContent_appCheck_validToken() async throws {
    let appCheckToken = "test-valid-token"
    let appCheckFake = AppCheckInteropFake(token: appCheckToken)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: appCheckFake,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: appCheckToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When App Check fails to refresh, the request is still sent and the stub expects
  /// `AppCheckInteropFake.placeholderTokenValue` as the App Check token.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With Auth returning a token, the stub is configured to expect that same token on the
  /// request.
  func testGenerateContent_auth_validAuthToken() async throws {
    let authToken = "test-valid-token"
    let authFake = AuthInteropFake(token: authToken)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: authFake,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: authToken
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// With Auth present but returning no token, the stub is configured with `authToken: nil`.
  func testGenerateContent_auth_nilAuthToken() async throws {
    let signedOutAuth = AuthInteropFake(token: nil)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: signedOutAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// An Auth token refresh failure must abort the request and be surfaced as
  /// `GenerateContentError.internalError` wrapping the `AuthErrorFake`.
  func testGenerateContent_auth_authTokenRefreshError() async throws {
    let failingAuth = AuthInteropFake(error: AuthErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: failingAuth,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      authToken: nil
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw internalError(AuthErrorFake); no error.")
    } catch GenerateContentError.internalError(_ as AuthErrorFake) {
      // Expected outcome: the refresh failure is the underlying error.
    } catch {
      XCTFail("Should throw internalError(AuthErrorFake); error thrown: \(error)")
    }
  }
  /// Token counts from the short-reply fixture are exposed via `usageMetadata`.
  func testGenerateContent_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    let usage = try XCTUnwrap(reply.usageMetadata)
    XCTAssertEqual(usage.promptTokenCount, 6)
    XCTAssertEqual(usage.candidatesTokenCount, 7)
    XCTAssertEqual(usage.totalTokenCount, 13)
  }
  /// A 400 response carrying an "API key not valid" body must surface as
  /// `GenerateContentError.internalError` wrapping an `RPCError` with matching fields.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as RPCError) {
      // Assert against the same constant the stub was configured with so the two cannot drift.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A 403 "Firebase ML API has not been used" response must surface as an `RPCError` that
  /// reports `isFirebaseMLServiceDisabledError()`.
  func testGenerateContent_failure_firebaseMLAPINotEnabled() async throws {
    let expectedStatusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: expectedStatusCode
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A candidate with empty content must surface as
  /// `InvalidCandidateError.emptyContent` whose payload is a `DecodingError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: candidateError as InvalidCandidateError) {
      guard case .emptyContent(let cause) = candidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
        return
      }
      _ = try XCTUnwrap(cause as? DecodingError, "Not a DecodingError: \(cause)")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A candidate stopped for safety must throw `responseStoppedEarly(.safety, response)`,
  /// with the (redacted) partial text still available on the response.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "<redacted>")
    } catch {
      // Report what was actually thrown, matching the diagnostics of the other failure tests.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A safety stop with no content must throw `responseStoppedEarly(.safety, response)`
  /// where the response has no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Report what was actually thrown, matching the diagnostics of the other failure tests.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A 400 "invalid argument" response for a rejected image must surface as an `RPCError`
  /// with the configured status code.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        // Drive the stub from the same constant the assertion checks.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A prompt blocked for safety must throw `promptBlocked(response)` with no text.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
    } catch {
      // Report what was actually thrown, matching the diagnostics of the other failure tests.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// An unrecognized finish reason must decode as `.unknown` and throw
  /// `responseStoppedEarly(.unknown, response)` with the text preserved.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .unknown)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Report what was actually thrown, matching the diagnostics of the other failure tests.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// An unrecognized prompt block reason must decode as `.unknown` and throw
  /// `promptBlocked(response)`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, .unknown)
    } catch {
      // Report what was actually thrown, matching the diagnostics of the other failure tests.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// A 404 response for an unknown model must surface as an `RPCError` with `.notFound`
  /// status and the configured status code.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        // Drive the stub from the same constant the assertion checks.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A non-HTTP `URLResponse` must be rejected with an `internalError` whose underlying
  /// error describes "Response was not an HTTP response.".
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    // Capture both outcomes so each can be asserted independently below.
    var generated: GenerateContentResponse?
    var thrown: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)
    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
  }
  /// An undecodable response body must surface as `internalError` wrapping a
  /// `DecodingError.dataCorrupted` whose description names `GenerateContentResponse`.
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var generated: GenerateContentResponse?
    var thrown: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)
    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let decodeFailure = try XCTUnwrap(underlying as? DecodingError)
    guard case .dataCorrupted(let corruptionContext) = decodeFailure else {
      XCTFail("Not a data corrupted error: \(decodeFailure)")
      return
    }
    XCTAssertTrue(
      corruptionContext.debugDescription.hasPrefix("Failed to decode GenerateContentResponse")
    )
  }
  /// Malformed candidate content must surface as `InvalidCandidateError.malformedContent`
  /// wrapping a `DecodingError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var generated: GenerateContentResponse?
    var thrown: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrown = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrown)
    let contentError = try XCTUnwrap(thrown as? GenerateContentError)
    guard case .internalError(let underlying) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(underlying as? InvalidCandidateError)
    guard case .malformedContent(let cause) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(cause as? DecodingError, "Not a decoding error: \(cause)")
  }
  /// A response with no safety ratings still decodes: prompt feedback has zero ratings and
  /// the text is intact.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )

    let reply = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(reply.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(reply.text, "This is the generated content.")
  }
  /// A custom `RequestOptions.timeout` is passed to the stub, which is configured to expect
  /// the same timeout value.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let timeoutSeconds = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: timeoutSeconds
    )
    let customOptions = RequestOptions(timeout: timeoutSeconds)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: customOptions,
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let reply = try await model.generateContent(testPrompt)

    XCTAssertEqual(reply.candidates.count, 1)
  }
  661. // MARK: - Generate Content (Streaming)
  /// Verifies that the invalid-API-key payload ends the stream with an `RPCError`
  /// (code 400, `.invalidArgument`) wrapped in `GenerateContentError.internalError`.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 400)
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.message, "API key not valid. Please pass a valid API key.")
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a 403 "Firebase ML API has not been used" payload ends the stream with
  /// an `RPCError` for which `isFirebaseMLServiceDisabledError()` returns `true`.
  func testGenerateContentStream_failure_firebaseMLAPINotEnabled() async throws {
    let statusCode = 403
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-firebaseml-api-not-enabled",
      withExtension: "json",
      statusCode: statusCode
    )
    let stream = model.generateContentStream(testPrompt)

    do {
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, statusCode)
      XCTAssertEqual(rpcError.status, .permissionDenied)
      XCTAssertTrue(rpcError.message.starts(with: "Firebase ML API has not been used in project"))
      XCTAssertTrue(rpcError.isFirebaseMLServiceDisabledError())
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that the `streaming-failure-empty-content` fixture ends the stream with an
  /// `InvalidCandidateError` wrapped in `GenerateContentError.internalError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await _ in stream {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(is InvalidCandidateError) {
      // Matching the wrapped error's type is the entire check.
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a `.safety` finish reason in the stream surfaces as
  /// `GenerateContentError.responseStoppedEarly` before any content is yielded.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.responseStoppedEarly(let finishReason, _) {
      XCTAssertEqual(finishReason, .safety)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that a prompt blocked for safety surfaces as `GenerateContentError.promptBlocked`
  /// carrying a response whose `promptFeedback.blockReason` is `.safety`.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await _ in stream {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch GenerateContentError.promptBlocked(let blockedResponse) {
      XCTAssertEqual(blockedResponse.promptFeedback?.blockReason, .safety)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that the unknown-finish-enum fixture yields text-bearing responses and then
  /// stops early with a finish reason of `.unknown`.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
      }
    } catch GenerateContentError.responseStoppedEarly(let finishReason, _) {
      XCTAssertEqual(finishReason, .unknown)
      return
    }

    XCTFail("Should have caught an error.")
  }
  /// Verifies that the long basic-reply fixture streams exactly six responses, each with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    var receivedCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 6)
  }
  /// Verifies that the short basic-reply fixture streams exactly one response with text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var receivedCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 1)
  }
  /// Verifies that the unknown-safety-enum fixture streams without throwing and that at least
  /// one response's first candidate carries a rating whose category is `.unknown`.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )

    var sawUnknownCategory = false
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      guard let ratings = chunk.candidates.first?.safetyRatings,
            ratings.contains(where: { $0.category == .unknown }) else {
        continue
      }
      sawUnknownCategory = true
    }

    XCTAssertTrue(sawUnknownCategory)
  }
  /// Verifies that citation sources are accumulated across streamed responses (six in total,
  /// including two specific spans) and that the final candidate finishes with `.stop`.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )
    let stream = model.generateContentStream("Hi")

    var collectedCitations: [Citation] = []
    var chunks: [GenerateContentResponse] = []
    for try await chunk in stream {
      chunks.append(chunk)
      XCTAssertNotNil(chunk.text)
      let firstCandidate = try XCTUnwrap(chunk.candidates.first)
      if let sources = firstCandidate.citationMetadata?.citationSources {
        collectedCitations.append(contentsOf: sources)
      }
    }

    let finalCandidate = try XCTUnwrap(chunks.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 6)
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 574 && citation.endIndex == 705 && !citation.uri.isEmpty
        && citation.license == ""
    })
    XCTAssertTrue(collectedCitations.contains { citation in
      citation.startIndex == 899 && citation.endIndex == 1026 && !citation.uri.isEmpty
        && citation.license == ""
    })
  }
  /// Verifies that a valid App Check token is sent with streaming requests. The header check
  /// itself lives in the mock request handler.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let expectedToken = "test-valid-token"
    let appCheck = AppCheckInteropFake(token: expectedToken)
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: appCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: expectedToken
    )

    // Drain the stream so the request is issued; assertions run inside the handler.
    for try await _ in model.generateContentStream(testPrompt) {}
  }
  /// Verifies that when App Check reports an error, the streaming request still carries the
  /// fake's placeholder token in the App Check header (checked by the mock handler).
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      auth: nil,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    // Drain the stream so the request is issued; assertions run inside the handler.
    for try await _ in model.generateContentStream(testPrompt) {}
  }
  /// Verifies that only the final streamed response carries `usageMetadata`, and that its
  /// token counts match the fixture (6 prompt + 4 candidate = 10 total).
  ///
  /// Fails if the stream yields no responses at all, instead of passing vacuously.
  func testGenerateContentStream_usageMetadata() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )
    var responses = [GenerateContentResponse]()
    let stream = model.generateContentStream(testPrompt)
    for try await response in stream {
      responses.append(response)
    }

    // An empty stream previously skipped every assertion; require at least one response.
    let lastResponse = try XCTUnwrap(responses.last, "Expected at least one streamed response.")

    // Only the last streamed response contains usage metadata.
    for response in responses.dropLast() {
      XCTAssertNil(response.usageMetadata)
    }

    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
  }
  /// Verifies that an error arriving mid-stream surfaces as an `RPCError` (code 499,
  /// `.cancelled`) after exactly two text-bearing responses were delivered.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )
    var deliveredCount = 0
    let stream = model.generateContentStream("Hi")

    do {
      for try await chunk in stream {
        XCTAssertNotNil(chunk.text)
        deliveredCount += 1
      }
    } catch GenerateContentError.internalError(let rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 499)
      XCTAssertEqual(rpcError.status, .cancelled)
      // Responses received before the error must still have been delivered.
      XCTAssertEqual(deliveredCount, 2)
      return
    }

    XCTFail("Expected an internalError with an RPCError.")
  }
  /// Verifies that a non-HTTP `URLResponse` ends the stream with an internal error whose
  /// description is "Response was not an HTTP response."
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
    let stream = model.generateContentStream("Hi")

    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch GenerateContentError.internalError(let underlyingError) {
      XCTAssertEqual(underlyingError.localizedDescription, "Response was not an HTTP response.")
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// Verifies that invalid JSON in the stream surfaces as a `DecodingError.dataCorrupted`
  /// whose debug description starts with "Failed to decode GenerateContentResponse".
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )
    let stream = model.generateContentStream(testPrompt)

    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch GenerateContentError.internalError(let decodingError as DecodingError) {
      switch decodingError {
      case let .dataCorrupted(context):
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      default:
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
      return
    }

    XCTFail("Expected an internal error.")
  }
  /// Verifies that malformed candidate content in the stream surfaces as
  /// `InvalidCandidateError.malformedContent` wrapping a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )
    let stream = model.generateContentStream(testPrompt)

    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
    } catch GenerateContentError.internalError(let candidateError as InvalidCandidateError) {
      switch candidateError {
      case let .malformedContent(cause):
        XCTAssert(cause is DecodingError)
      default:
        XCTFail("Not a malformed content error: \(candidateError)")
      }
      return
    }

    XCTFail("Expected an internal decoding error.")
  }
  /// Verifies that `RequestOptions.timeout` reaches streaming requests. The mock handler
  /// asserts the request's `timeoutInterval`; the stream must still yield one text response.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let timeout: TimeInterval = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: timeout
    )
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: timeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    var receivedCount = 0
    for try await chunk in model.generateContentStream(testPrompt) {
      XCTAssertNotNil(chunk.text)
      receivedCount += 1
    }

    XCTAssertEqual(receivedCount, 1)
  }
  1012. // MARK: - Count Tokens
  /// Verifies that a text prompt's token count and billable-character count are decoded
  /// from the `unary-success-total-tokens` fixture.
  func testCountTokens_succeeds() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Verifies that counting tokens for an (empty) image part reports 258 tokens and zero
  /// billable characters, per the `unary-success-no-billable-characters` fixture.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    let handler = try httpRequestHandler(
      forResource: "unary-success-no-billable-characters",
      withExtension: "json"
    )
    MockURLProtocol.requestHandler = handler

    let imagePart = ModelContent.Part.data(mimetype: "image/jpeg", Data())
    let tokenCount = try await model.countTokens(imagePart)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 0)
  }
  /// Verifies that a 404 model-not-found response makes `countTokens` throw
  /// `CountTokensError.internalError` wrapping an `RPCError` with status `.notFound`.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-model-not-found",
      withExtension: "json",
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      XCTFail("Request should not have succeeded.")
    } catch CountTokensError.internalError(let notFoundError as RPCError) {
      XCTAssertEqual(notFoundError.httpResponseCode, 404)
      XCTAssertEqual(notFoundError.status, .notFound)
      XCTAssert(notFoundError.message.hasPrefix("models/test-model-name is not found"))
      return
    }

    XCTFail("Expected internal RPCError.")
  }
  /// Verifies that `RequestOptions.timeout` reaches `countTokens` requests. The mock handler
  /// asserts the request's `timeoutInterval`.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let timeout: TimeInterval = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-total-tokens",
      withExtension: "json",
      timeout: timeout
    )
    model = GenerativeModel(
      name: "my-model",
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: timeout),
      appCheck: nil,
      auth: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  1072. // MARK: - Model Resource Name
  /// Verifies that a bare model name is exposed with a `models/` prefix.
  func testModelResourceName_noPrefix() async throws {
    let bareName = "my-model"
    model = GenerativeModel(
      name: bareName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, "models/\(bareName)")
  }
  /// Verifies that a name already prefixed with `models/` is exposed unchanged.
  func testModelResourceName_modelsPrefix() async throws {
    let prefixedName = "models/my-model"
    model = GenerativeModel(
      name: prefixedName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, prefixedName)
  }
  /// Verifies that a `tunedModels/`-prefixed name is exposed unchanged (no `models/` added).
  func testModelResourceName_tunedModelsPrefix() async throws {
    let tunedName = "tunedModels/my-model"
    model = GenerativeModel(
      name: tunedName,
      projectID: "my-project-id",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      auth: nil
    )

    XCTAssertEqual(model.modelResourceName, tunedName)
  }
  1113. // MARK: - Helpers
  /// Builds a `MockURLProtocol` handler that answers with a plain `URLResponse` (not an
  /// `HTTPURLResponse`) and no body.
  ///
  /// The returned closure throws via `XCTUnwrap` when the intercepted request has no URL,
  /// matching `httpRequestHandler`. This fails only the current test instead of crashing
  /// the whole test process on a force unwrap.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    return { request in
      let requestURL = try XCTUnwrap(request.url)
      // This is *not* an HTTPURLResponse.
      let response = URLResponse(
        url: requestURL,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (response, nil)
    }
  }
  /// Builds a `MockURLProtocol` handler that serves the bundled fixture `name.ext` as the
  /// response body, after asserting properties of the intercepted request.
  ///
  /// - Parameters:
  ///   - name: Fixture file name (without extension) in the test bundle.
  ///   - ext: Fixture file extension, e.g. `"json"` or `"txt"`.
  ///   - statusCode: Status code reported by the mock `HTTPURLResponse`.
  ///   - timeout: Expected `timeoutInterval` of the intercepted request.
  ///   - appCheckToken: Expected `X-Firebase-AppCheck` header value; `nil` means the header
  ///     must be absent.
  ///   - authToken: Expected token in the `Authorization: Firebase <token>` header; `nil`
  ///     means the header must be absent.
  /// - Returns: A throwing handler yielding an `HTTPURLResponse` and the fixture's lines.
  /// - Throws: If the fixture is not found in `Bundle.module`.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
                                  appCheckToken: String? = nil,
                                  authToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    let fixtureURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext))
    return { request in
      let url = try XCTUnwrap(request.url)

      // The substring "models/" must appear exactly once in the request path.
      XCTAssertEqual(url.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let clientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let clientTags = clientHeader.components(separatedBy: " ")
      XCTAssert(clientTags.contains(GenerativeAIService.languageTag))
      XCTAssert(clientTags.contains(GenerativeAIService.firebaseVersionTag))

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)

      let authorization = request.value(forHTTPHeaderField: "Authorization")
      switch authToken {
      case let token?:
        XCTAssertEqual(authorization, "Firebase \(token)")
      case nil:
        XCTAssertNil(authorization)
      }

      let httpResponse = try XCTUnwrap(HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (httpResponse, fixtureURL.lines)
    }
  }
  1162. }
  1163. private extension String {
  1164. /// Returns the number of occurrences of `substring` in the `String`.
  1165. func occurrenceCount(of substring: String) -> Int {
  1166. return components(separatedBy: substring).count - 1
  1167. }
  1168. }
  1169. private extension URLRequest {
  1170. /// Returns the default `timeoutInterval` for a `URLRequest`.
  1171. static func defaultTimeoutInterval() -> TimeInterval {
  1172. let placeholderURL = URL(string: "https://example.com")!
  1173. return URLRequest(url: placeholderURL).timeoutInterval
  1174. }
  1175. }
  1176. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
  1177. class AppCheckInteropFake: NSObject, AppCheckInterop {
  1178. /// The placeholder token value returned when an error occurs
  1179. static let placeholderTokenValue = "placeholder-token"
  1180. var token: String
  1181. var error: Error?
  1182. private init(token: String, error: Error?) {
  1183. self.token = token
  1184. self.error = error
  1185. }
  1186. convenience init(token: String) {
  1187. self.init(token: token, error: nil)
  1188. }
  1189. convenience init(error: Error) {
  1190. self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
  1191. }
  1192. func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
  1193. return AppCheckTokenResultInteropFake(token: token, error: error)
  1194. }
  1195. func tokenDidChangeNotificationName() -> String {
  1196. fatalError("\(#function) not implemented.")
  1197. }
  1198. func notificationTokenKey() -> String {
  1199. fatalError("\(#function) not implemented.")
  1200. }
  1201. func notificationAppNameKey() -> String {
  1202. fatalError("\(#function) not implemented.")
  1203. }
  1204. private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
  1205. var token: String
  1206. var error: Error?
  1207. init(token: String, error: Error?) {
  1208. self.token = token
  1209. self.error = error
  1210. }
  1211. }
  1212. }
/// Arbitrary error passed to `AppCheckInteropFake(error:)` to simulate an App Check failure.
struct AppCheckErrorFake: Error {}
  1214. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
  1215. extension SafetyRating: Comparable {
  1216. public static func < (lhs: FirebaseVertexAI.SafetyRating,
  1217. rhs: FirebaseVertexAI.SafetyRating) -> Bool {
  1218. return lhs.category.rawValue < rhs.category.rawValue
  1219. }
  1220. }