GenerativeModelTests.swift 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseCore
  16. import XCTest
  17. @testable import FirebaseVertexAI
  18. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
  19. final class GenerativeModelTests: XCTestCase {
  /// Prompt text sent by most tests in this class.
  let testPrompt = "What sorts of questions can I ask you?"

  /// Safety ratings expected on benign fixtures. The array is pre-sorted so tests can compare it
  /// against `safetyRatings.sorted()` regardless of the order used in the JSON fixture.
  let safetyRatingsNegligible: [SafetyRating] = [
    SafetyRating(category: .sexuallyExplicit, probability: .negligible),
    SafetyRating(category: .hateSpeech, probability: .negligible),
    SafetyRating(category: .harassment, probability: .negligible),
    SafetyRating(category: .dangerousContent, probability: .negligible),
  ].sorted()

  /// Session configured in `setUp()` to route traffic through `MockURLProtocol`.
  var urlSession: URLSession!
  /// Default model under test, created in `setUp()` on top of `urlSession`.
  var model: GenerativeModel!
  /// Builds a fresh mock-backed `URLSession` and a default `GenerativeModel` for each test.
  override func setUp() async throws {
    let configuration = URLSessionConfiguration.default
    // Every request made through this session is answered by `MockURLProtocol`.
    configuration.protocolClasses = [MockURLProtocol.self]
    // `URLSession(configuration:)` is non-failable, so there is no optional to unwrap.
    urlSession = URLSession(configuration: configuration)
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      urlSession: urlSession
    )
  }
  /// Resets the shared mock handler and releases the per-test session and model.
  ///
  /// XCTest typically keeps test-case instances alive for the whole run, so the implicitly
  /// unwrapped properties are cleared here. The session is also invalidated so it does not
  /// outlive the test that created it.
  override func tearDown() {
    MockURLProtocol.requestHandler = nil
    // Optional chaining: `setUp()` may have failed before the session was assigned.
    urlSession?.invalidateAndCancel()
    model = nil
    urlSession = nil
    super.tearDown()
  }
  45. // MARK: - Generate Content
  /// A long single-candidate reply decodes with `.stop`, negligible safety ratings and one text part.
  func testGenerateContent_success_basicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-long",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(result.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let text = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(text.hasPrefix("You can ask me a wide range of questions"))
    // The convenience accessors should mirror the single text part.
    XCTAssertEqual(result.text, text)
    XCTAssertEqual(result.functionCalls, [])
  }
  /// A short single-candidate reply decodes into exactly one known text part.
  func testGenerateContent_success_basicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(result.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    XCTAssertEqual(firstPart.text, "Mountain View, California, United States")
    XCTAssertEqual(result.text, firstPart.text)
    XCTAssertEqual(result.functionCalls, [])
  }
  /// Citation metadata on a candidate is decoded, including the absence of a license.
  func testGenerateContent_success_citations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-citations",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(result.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    XCTAssertEqual(result.text, "Some information cited from an external source")

    let metadata = try XCTUnwrap(onlyCandidate.citationMetadata)
    XCTAssertEqual(metadata.citationSources.count, 1)
    let source = try XCTUnwrap(metadata.citationSources.first)
    XCTAssertEqual(source.uri, "https://www.example.com/some-citation")
    XCTAssertEqual(source.startIndex, 179)
    XCTAssertEqual(source.endIndex, 366)
    XCTAssertNil(source.license)
  }
  /// A quoted reply decodes, and its prompt feedback has no block reason and negligible ratings.
  func testGenerateContent_success_quoteReply() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-quote-reply",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(result.candidates.first)
    let reason = try XCTUnwrap(onlyCandidate.finishReason)
    XCTAssertEqual(reason, .stop)
    XCTAssertEqual(onlyCandidate.safetyRatings.sorted(), safetyRatingsNegligible)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    let text = try XCTUnwrap(firstPart.text)
    XCTAssertTrue(text.hasPrefix("Google"))
    XCTAssertEqual(result.text, firstPart.text)

    let feedback = try XCTUnwrap(result.promptFeedback)
    XCTAssertNil(feedback.blockReason)
    XCTAssertEqual(feedback.safetyRatings.sorted(), safetyRatingsNegligible)
  }
  /// Unrecognized safety categories/probabilities decode as `.unknown` instead of failing.
  func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-unknown-enum-safety-ratings",
      withExtension: "json"
    )
    let expected: [SafetyRating] = [
      SafetyRating(category: .harassment, probability: .medium),
      SafetyRating(category: .dangerousContent, probability: .unknown),
      SafetyRating(category: .unknown, probability: .high),
    ]

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.text, "Some text")
    XCTAssertEqual(result.candidates.first?.safetyRatings, expected)
    XCTAssertEqual(result.promptFeedback?.safetyRatings, expected)
  }
  /// A model name that already carries the "models/" prefix can still generate content.
  func testGenerateContent_success_prefixedModelName() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json"
    )
    let prefixedModel = GenerativeModel(
      name: "models/test-model", // Already prefixed with "models/".
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil,
      urlSession: urlSession
    )

    _ = try await prefixedModel.generateContent(testPrompt)
  }
  /// A function call whose arguments object is empty decodes with an empty `args` map.
  func testGenerateContent_success_functionCall_emptyArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-empty-arguments",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(result.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(result.functionCalls, [call])
  }
  /// A function call with no arguments field at all still decodes with an empty `args` map.
  func testGenerateContent_success_functionCall_noArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-no-arguments",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(result.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "current_time")
    XCTAssertTrue(call.args.isEmpty)
    XCTAssertEqual(result.functionCalls, [call])
  }
  /// A function call with numeric arguments exposes them through its `args` map.
  func testGenerateContent_success_functionCall_withArguments() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-function-call-with-arguments",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
    let onlyCandidate = try XCTUnwrap(result.candidates.first)
    XCTAssertEqual(onlyCandidate.content.parts.count, 1)
    let firstPart = try XCTUnwrap(onlyCandidate.content.parts.first)
    guard case .functionCall(let call) = firstPart else {
      XCTFail("Part is not a FunctionCall.")
      return
    }
    XCTAssertEqual(call.name, "sum")
    XCTAssertEqual(call.args.count, 2)
    let x = try XCTUnwrap(call.args["x"])
    XCTAssertEqual(x, .number(4))
    let y = try XCTUnwrap(call.args["y"])
    XCTAssertEqual(y, .number(5))
    XCTAssertEqual(result.functionCalls, [call])
  }
  /// A model configured with App Check succeeds when the fake supplies a valid token.
  func testGenerateContent_appCheck_validToken() async throws {
    let token = "test-valid-token"
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: token),
      urlSession: urlSession
    )
    // The mock is told which token to expect; presumably it checks the request header — confirm
    // in `httpRequestHandler`.
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: token
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// When App Check token refresh fails, the request still proceeds, and the mock expects
  /// `AppCheckInteropFake.placeholderTokenValue` as the token.
  func testGenerateContent_appCheck_tokenRefreshError() async throws {
    let failingAppCheck = AppCheckInteropFake(error: AppCheckErrorFake())
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: failingAppCheck,
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    _ = try await model.generateContent(testPrompt)
  }
  /// A 400 response with an RPC error body surfaces as `GenerateContentError.internalError`
  /// wrapping an `RPCError` that carries the HTTP code, RPC status and server message.
  func testGenerateContent_failure_invalidAPIKey() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-api-key",
        withExtension: "json",
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(error as RPCError) {
      // Assert against the same constant the mock was configured with so they cannot drift apart.
      XCTAssertEqual(error.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(error.status, .invalidArgument)
      XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
    }
  }
  /// A candidate without content is reported as `InvalidCandidateError.emptyContent` caused by a
  /// `DecodingError`, wrapped in `GenerateContentError.internalError`.
  func testGenerateContent_failure_emptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-empty-content",
      withExtension: "json"
    )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError
      .internalError(underlying: candidateError as InvalidCandidateError) {
      if case let .emptyContent(cause) = candidateError {
        _ = try XCTUnwrap(cause as? DecodingError, "Not a DecodingError: \(cause)")
      } else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(candidateError)")
      }
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A response stopped for safety throws `responseStoppedEarly(.safety, _)`, and the partial
  /// response is still attached.
  func testGenerateContent_failure_finishReasonSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertEqual(response.text, "No")
    } catch {
      // Include the unexpected error so the failure is diagnosable.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A safety stop with no content throws `responseStoppedEarly(.safety, _)` whose response has
  /// no text.
  func testGenerateContent_failure_finishReasonSafety_noContent() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-finish-reason-safety-no-content",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .safety)
      XCTAssertNil(response.text)
    } catch {
      // Include the unexpected error so the failure is diagnosable.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// A 400 "invalid argument" response surfaces as an `RPCError` inside
  /// `GenerateContentError.internalError`.
  func testGenerateContent_failure_imageRejected() async throws {
    let expectedStatusCode = 400
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-image-rejected",
        withExtension: "json",
        // Use the shared constant so the mock and the assertion below cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A prompt blocked for safety throws `GenerateContentError.promptBlocked` with a text-less
  /// response.
  func testGenerateContent_failure_promptBlockedSafety() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-prompt-blocked-safety",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      XCTAssertNil(response.text)
    } catch {
      // Include the unexpected error so the failure is diagnosable.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// An unrecognized finish reason decodes as `.unknown` and is reported via
  /// `responseStoppedEarly`, with the partial text preserved.
  func testGenerateContent_failure_unknownEnum_finishReason() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-finish-reason",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; no error thrown.")
    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
      XCTAssertEqual(reason, .unknown)
      XCTAssertEqual(response.text, "Some text")
    } catch {
      // Include the unexpected error so the failure is diagnosable.
      XCTFail("Should throw GenerateContentError.responseStoppedEarly; error thrown: \(error)")
    }
  }
  /// An unrecognized block reason decodes as `.unknown` and still throws `promptBlocked`.
  func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-enum-prompt-blocked",
        withExtension: "json"
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.promptBlocked; no error thrown.")
    } catch let GenerateContentError.promptBlocked(response) {
      let promptFeedback = try XCTUnwrap(response.promptFeedback)
      XCTAssertEqual(promptFeedback.blockReason, .unknown)
    } catch {
      // Include the unexpected error so the failure is diagnosable.
      XCTFail("Should throw GenerateContentError.promptBlocked; error thrown: \(error)")
    }
  }
  /// A 404 "not found" response surfaces as an `RPCError` with `.notFound` status.
  func testGenerateContent_failure_unknownModel() async throws {
    let expectedStatusCode = 404
    MockURLProtocol
      .requestHandler = try httpRequestHandler(
        forResource: "unary-failure-unknown-model",
        withExtension: "json",
        // Use the shared constant so the mock and the assertion below cannot drift apart.
        statusCode: expectedStatusCode
      )

    do {
      _ = try await model.generateContent(testPrompt)
      XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
      XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
    } catch {
      XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
    }
  }
  /// A non-HTTP URL response is reported as an internal error with a descriptive message.
  func testGenerateContent_failure_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

    var generated: GenerateContentResponse?
    var thrownError: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying: cause) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    XCTAssertEqual(cause.localizedDescription, "Response was not an HTTP response.")
  }
  /// A body that cannot be decoded yields `DecodingError.dataCorrupted` inside
  /// `GenerateContentError.internalError`.
  func testGenerateContent_failure_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-invalid-response",
      withExtension: "json"
    )

    var generated: GenerateContentResponse?
    var thrownError: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying: cause) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let decodeFailure = try XCTUnwrap(cause as? DecodingError)
    guard case let .dataCorrupted(corruption) = decodeFailure else {
      XCTFail("Not a data corrupted error: \(decodeFailure)")
      return
    }
    XCTAssertTrue(
      corruption.debugDescription.hasPrefix("Failed to decode GenerateContentResponse")
    )
  }
  /// Malformed candidate content yields `InvalidCandidateError.malformedContent` caused by a
  /// `DecodingError`, wrapped in `GenerateContentError.internalError`.
  func testGenerateContent_failure_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-malformed-content",
      withExtension: "json"
    )

    var generated: GenerateContentResponse?
    var thrownError: Error?
    do {
      generated = try await model.generateContent(testPrompt)
    } catch {
      thrownError = error
    }

    XCTAssertNil(generated)
    XCTAssertNotNil(thrownError)
    let contentError = try XCTUnwrap(thrownError as? GenerateContentError)
    guard case let .internalError(underlying: cause) = contentError else {
      XCTFail("Not an internal error: \(contentError)")
      return
    }
    let candidateError = try XCTUnwrap(cause as? InvalidCandidateError)
    guard case let .malformedContent(decodeFailure) = candidateError else {
      XCTFail("Not a malformed content error: \(candidateError)")
      return
    }
    _ = try XCTUnwrap(
      decodeFailure as? DecodingError,
      "Not a decoding error: \(decodeFailure)"
    )
  }
  /// Prompt feedback without a safety-ratings field decodes to an empty ratings array.
  func testGenerateContentMissingSafetyRatings() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-missing-safety-ratings",
      withExtension: "json"
    )

    let result = try await model.generateContent(testPrompt)

    let feedback = try XCTUnwrap(result.promptFeedback)
    XCTAssertEqual(feedback.safetyRatings.count, 0)
    XCTAssertEqual(result.text, "This is the generated content.")
  }
  /// A custom `RequestOptions.timeout` is passed through to the outgoing request. The mock is
  /// given the expected timeout, presumably to verify it — confirm in `httpRequestHandler`.
  func testGenerateContent_requestOptions_customTimeout() async throws {
    let timeoutSeconds = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-success-basic-reply-short",
      withExtension: "json",
      timeout: timeoutSeconds
    )
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: timeoutSeconds),
      appCheck: nil,
      urlSession: urlSession
    )

    let result = try await model.generateContent(testPrompt)

    XCTAssertEqual(result.candidates.count, 1)
  }
  514. // MARK: - Generate Content (Streaming)
  /// Streaming against an invalid-API-key error body throws an `RPCError` before yielding
  /// any content.
  func testGenerateContentStream_failureInvalidAPIKey() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "unary-failure-api-key",
      withExtension: "json"
    )

    do {
      for try await _ in model.generateContentStream("Hi") {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 400)
      XCTAssertEqual(rpcError.status, .invalidArgument)
      XCTAssertEqual(rpcError.message, "API key not valid. Please pass a valid API key.")
      return
    }
    // Reached only if the stream finished without throwing.
    XCTFail("Should have caught an error.")
  }
  /// A streamed chunk without content throws an `InvalidCandidateError` wrapped in
  /// `GenerateContentError.internalError`.
  func testGenerateContentStream_failureEmptyContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-empty-content",
      withExtension: "txt"
    )

    do {
      for try await _ in model.generateContentStream("Hi") {
        XCTFail("No content is there, this shouldn't happen.")
      }
    } catch GenerateContentError.internalError(underlying: is InvalidCandidateError) {
      // Only the type of the underlying error matters here.
      return
    }
    // Reached only if the stream finished without throwing.
    XCTFail("Should have caught an error.")
  }
  /// A stream that stops for safety throws `responseStoppedEarly(.safety, _)` without yielding.
  func testGenerateContentStream_failureFinishReasonSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-finish-reason-safety",
      withExtension: "txt"
    )

    do {
      for try await _ in model.generateContentStream("Hi") {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.responseStoppedEarly(stopReason, _) {
      XCTAssertEqual(stopReason, .safety)
      return
    }
    // Reached only if the stream finished without throwing.
    XCTFail("Should have caught an error.")
  }
  /// A stream whose prompt is blocked throws `promptBlocked` with a `.safety` block reason.
  func testGenerateContentStream_failurePromptBlockedSafety() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-prompt-blocked-safety",
      withExtension: "txt"
    )

    do {
      for try await _ in model.generateContentStream("Hi") {
        XCTFail("Content shouldn't be shown, this shouldn't happen.")
      }
    } catch let GenerateContentError.promptBlocked(blockedResponse) {
      XCTAssertEqual(blockedResponse.promptFeedback?.blockReason, .safety)
      return
    }
    // Reached only if the stream finished without throwing.
    XCTFail("Should have caught an error.")
  }
  /// An unrecognized finish reason mid-stream throws `responseStoppedEarly(.unknown, _)`. Any
  /// chunks yielded before it must carry text.
  func testGenerateContentStream_failureUnknownFinishEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-unknown-finish-enum",
      withExtension: "txt"
    )
    let chunks = model.generateContentStream("Hi")

    do {
      for try await chunk in chunks {
        XCTAssertNotNil(chunk.text)
      }
    } catch let GenerateContentError.responseStoppedEarly(stopReason, _) {
      XCTAssertEqual(stopReason, .unknown)
      return
    }
    // Reached only if the stream finished without throwing.
    XCTFail("Should have caught an error.")
  }
  /// The long streaming fixture yields eight chunks, each with text.
  func testGenerateContentStream_successBasicReplyLong() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-long",
      withExtension: "txt"
    )

    var chunkCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 8)
  }
  /// The short streaming fixture yields exactly one chunk with text.
  func testGenerateContentStream_successBasicReplyShort() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt"
    )

    var chunkCount = 0
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }

    XCTAssertEqual(chunkCount, 1)
  }
  /// At least one streamed chunk carries a safety rating whose category decoded as `.unknown`.
  func testGenerateContentStream_successUnknownSafetyEnum() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-unknown-safety-enum",
      withExtension: "txt"
    )

    var sawUnknownCategory = false
    for try await chunk in model.generateContentStream("Hi") {
      XCTAssertNotNil(chunk.text)
      let ratings = chunk.candidates.first?.safetyRatings ?? []
      if ratings.contains(where: { $0.category == .unknown }) {
        sawUnknownCategory = true
      }
    }

    XCTAssertTrue(sawUnknownCategory)
  }
  /// Citations spread across streamed chunks are all collected, and the final candidate
  /// finishes with `.stop`.
  func testGenerateContentStream_successWithCitations() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-citations",
      withExtension: "txt"
    )
    let chunks = model.generateContentStream("Hi")

    var collectedCitations: [Citation] = []
    var received: [GenerateContentResponse] = []
    for try await chunk in chunks {
      received.append(chunk)
      XCTAssertNotNil(chunk.text)
      let candidate = try XCTUnwrap(chunk.candidates.first)
      collectedCitations += candidate.citationMetadata?.citationSources ?? []
    }

    let finalCandidate = try XCTUnwrap(received.last?.candidates.first)
    XCTAssertEqual(finalCandidate.finishReason, .stop)
    XCTAssertEqual(collectedCitations.count, 3)

    let hasCitationOne = collectedCitations.contains { citation in
      citation.startIndex == 31 && citation.endIndex == 187 &&
        citation.uri == "https://www.example.com/citation-1" && citation.license == nil
    }
    XCTAssertTrue(hasCitationOne)

    let hasCitationThree = collectedCitations.contains { citation in
      citation.startIndex == 133 && citation.endIndex == 272 &&
        citation.uri == "https://www.example.com/citation-3" && citation.license == "mit"
    }
    XCTAssertTrue(hasCitationThree)
  }
  /// Streaming with App Check succeeds when the fake supplies a valid token.
  func testGenerateContentStream_appCheck_validToken() async throws {
    let token = "test-valid-token"
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(token: token),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: token
    )

    // Drain the stream; any thrown error fails the test.
    let chunks = model.generateContentStream(testPrompt)
    for try await _ in chunks {}
  }
  /// Checks the App Check header sent when the token fetch reports an error.
  ///
  /// The fake returns `AppCheckInteropFake.placeholderTokenValue` alongside the error, and the
  /// mock request handler asserts that this placeholder is what reaches the request header.
  func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
      urlSession: urlSession
    )
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      appCheckToken: AppCheckInteropFake.placeholderTokenValue
    )

    var chunks = model.generateContentStream(testPrompt).makeAsyncIterator()
    while try await chunks.next() != nil {}
  }
  /// Checks that an error arriving part-way through a stream surfaces as an internal `RPCError`.
  ///
  /// It also checks that the chunks preceding the error were delivered.
  func testGenerateContentStream_errorMidStream() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-error-mid-stream",
      withExtension: "txt"
    )

    var receivedChunks = 0
    do {
      for try await chunk in model.generateContentStream("Hi") {
        XCTAssertNotNil(chunk.text)
        receivedChunks += 1
      }
      // Getting here means the stream finished without raising the expected error.
      XCTFail("Expected an internalError with an RPCError.")
    } catch let GenerateContentError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 499)
      XCTAssertEqual(rpcError.status, .cancelled)
      // Two chunks are expected before the failure.
      XCTAssertEqual(receivedChunks, 2)
    }
  }
  /// Checks that a response which is not an `HTTPURLResponse` ends the stream with an internal
  /// error and yields no content.
  func testGenerateContentStream_nonHTTPResponse() async throws {
    MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
    let stream = model.generateContentStream("Hi")

    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      // The stream completed normally, which is itself a failure.
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(underlying) {
      XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
    }
  }
  /// Checks that an unparseable streamed payload surfaces as an internal
  /// `DecodingError.dataCorrupted`.
  func testGenerateContentStream_invalidResponse() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-invalid-json",
      withExtension: "txt"
    )
    let stream = model.generateContentStream(testPrompt)

    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      XCTFail("Expected an internal error.")
    } catch let GenerateContentError.internalError(decodingError as DecodingError) {
      if case let .dataCorrupted(context) = decodingError {
        XCTAssert(context.debugDescription.hasPrefix("Failed to decode GenerateContentResponse"))
      } else {
        XCTFail("Not a data corrupted error: \(decodingError)")
      }
    }
  }
  /// Checks that a candidate with undecodable content surfaces as
  /// `InvalidCandidateError.malformedContent` wrapping a `DecodingError`.
  func testGenerateContentStream_malformedContent() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-failure-malformed-content",
      withExtension: "txt"
    )
    let stream = model.generateContentStream(testPrompt)

    do {
      for try await unexpected in stream {
        XCTFail("Unexpected content in stream: \(unexpected)")
      }
      XCTFail("Expected an internal decoding error.")
    } catch let GenerateContentError.internalError(candidateError as InvalidCandidateError) {
      if case let .malformedContent(contentError) = candidateError {
        XCTAssert(contentError is DecodingError)
      } else {
        XCTFail("Not a malformed content error: \(candidateError)")
      }
    }
  }
  /// Checks that a custom `RequestOptions` timeout is applied to streaming requests.
  ///
  /// The mock handler asserts the request's `timeoutInterval`.
  func testGenerateContentStream_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "streaming-success-basic-reply-short",
      withExtension: "txt",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      urlSession: urlSession
    )

    var chunkCount = 0
    for try await chunk in model.generateContentStream(testPrompt) {
      XCTAssertNotNil(chunk.text)
      chunkCount += 1
    }
    // The short-reply fixture is expected to arrive as a single chunk.
    XCTAssertEqual(chunkCount, 1)
  }
  816. // MARK: - Count Tokens
  /// Checks the token and billable-character totals decoded from a successful `countTokens`
  /// response.
  func testCountTokens_succeeds() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "success-total-tokens",
      withExtension: "json"
    )

    let tokenCount = try await model.countTokens("Why is the sky blue?")

    XCTAssertEqual(tokenCount.totalTokens, 6)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 16)
  }
  /// Checks `countTokens` for a data-only (image) part, whose response reports zero billable
  /// characters.
  func testCountTokens_succeeds_noBillableCharacters() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "success-no-billable-characters",
      withExtension: "json"
    )
    let emptyJPEG = ModelContent.Part.data(mimetype: "image/jpeg", Data())

    let tokenCount = try await model.countTokens(emptyJPEG)

    XCTAssertEqual(tokenCount.totalTokens, 258)
    XCTAssertEqual(tokenCount.totalBillableCharacters, 0)
  }
  /// Checks that a 404 "model not found" reply to `countTokens` surfaces as an internal
  /// `RPCError` carrying the HTTP status, RPC status, and server message.
  func testCountTokens_modelNotFound() async throws {
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "failure-model-not-found", withExtension: "json",
      statusCode: 404
    )

    do {
      _ = try await model.countTokens("Why is the sky blue?")
      // Report an unexpected success exactly once (not once here and again after the do/catch).
      XCTFail("Request should not have succeeded; expected internal RPCError.")
    } catch let CountTokensError.internalError(rpcError as RPCError) {
      XCTAssertEqual(rpcError.httpResponseCode, 404)
      XCTAssertEqual(rpcError.status, .notFound)
      XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
    }
    // Any other error propagates out of this `throws` test and fails it.
  }
  /// Checks that a custom `RequestOptions` timeout is applied to `countTokens` requests.
  ///
  /// The mock handler asserts the request's `timeoutInterval`.
  func testCountTokens_requestOptions_customTimeout() async throws {
    let expectedTimeout = 150.0
    MockURLProtocol.requestHandler = try httpRequestHandler(
      forResource: "success-total-tokens",
      withExtension: "json",
      timeout: expectedTimeout
    )
    model = GenerativeModel(
      name: "my-model",
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(timeout: expectedTimeout),
      appCheck: nil,
      urlSession: urlSession
    )

    let tokenCount = try await model.countTokens(testPrompt)

    XCTAssertEqual(tokenCount.totalTokens, 6)
  }
  874. // MARK: - Model Resource Name
  /// A bare model name gets the `models/` prefix in its resource name.
  func testModelResourceName_noPrefix() async throws {
    let bareName = "my-model"

    model = GenerativeModel(
      name: bareName,
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil
    )

    XCTAssertEqual(model.modelResourceName, "models/" + bareName)
  }
  /// A name that already starts with `models/` is used verbatim as the resource name.
  func testModelResourceName_modelsPrefix() async throws {
    let prefixedName = "models/my-model"

    model = GenerativeModel(
      name: prefixedName,
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil
    )

    XCTAssertEqual(model.modelResourceName, prefixedName)
  }
  /// A `tunedModels/` name is used verbatim; it does not get a `models/` prefix.
  func testModelResourceName_tunedModelsPrefix() async throws {
    let tunedName = "tunedModels/my-model"

    model = GenerativeModel(
      name: tunedName,
      apiKey: "API_KEY",
      tools: nil,
      requestOptions: RequestOptions(),
      appCheck: nil
    )

    XCTAssertEqual(model.modelResourceName, tunedName)
  }
  909. // MARK: - Helpers
  /// Builds a `MockURLProtocol` handler that answers every request with a plain `URLResponse`
  /// (deliberately *not* an `HTTPURLResponse`) and no body.
  ///
  /// The request URL is unwrapped with `XCTUnwrap` rather than force-unwrapped. A missing URL
  /// therefore fails the current test instead of crashing the whole test process.
  ///
  /// - Returns: A handler closure for `MockURLProtocol.requestHandler`.
  /// - Throws: Currently never throws.
  private func nonHTTPRequestHandler() throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    return { request in
      let requestURL = try XCTUnwrap(request.url)
      // This is *not* an HTTPURLResponse.
      let response = URLResponse(
        url: requestURL,
        mimeType: nil,
        expectedContentLength: 0,
        textEncodingName: nil
      )
      return (response, nil)
    }
  }
  /// Builds a `MockURLProtocol` handler that validates each outgoing request and replies with
  /// the lines of a bundled test resource.
  ///
  /// For every request, the handler asserts that:
  /// - the URL path contains `models/` exactly once;
  /// - the request timeout equals `timeout`;
  /// - the `x-goog-api-client` header lists both the language tag and the Firebase version tag;
  /// - the `X-Firebase-AppCheck` header equals `appCheckToken` (i.e. is absent when `nil`).
  ///
  /// - Parameters:
  ///   - name: Resource name inside `Bundle.module`.
  ///   - ext: Resource file extension.
  ///   - statusCode: HTTP status code of the synthesized response.
  ///   - timeout: Timeout the request is expected to carry.
  ///   - appCheckToken: Expected App Check header value.
  /// - Throws: If the resource cannot be located in `Bundle.module`.
  private func httpRequestHandler(forResource name: String,
                                  withExtension ext: String,
                                  statusCode: Int = 200,
                                  timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
                                  appCheckToken: String? = nil) throws -> ((URLRequest) throws -> (
    URLResponse,
    AsyncLineSequence<URL.AsyncBytes>?
  )) {
    let resourceURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext))

    return { request in
      let url = try XCTUnwrap(request.url)
      XCTAssertEqual(url.path.occurrenceCount(of: "models/"), 1)
      XCTAssertEqual(request.timeoutInterval, timeout)

      let clientHeader = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client"))
      let clientTags = clientHeader.components(separatedBy: " ")
      for expectedTag in [GenerativeAIService.languageTag, GenerativeAIService.firebaseVersionTag] {
        XCTAssert(clientTags.contains(expectedTag))
      }

      XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)

      let response = try XCTUnwrap(HTTPURLResponse(
        url: url,
        statusCode: statusCode,
        httpVersion: nil,
        headerFields: nil
      ))
      return (response, resourceURL.lines)
    }
  }
  952. }
  private extension String {
    /// Counts how many times `substring` occurs in this string, as delimited by
    /// `components(separatedBy:)`.
    ///
    /// Splitting on `substring` always yields at least one piece and one more piece than there
    /// are separators. The count is therefore the number of pieces after the first.
    func occurrenceCount(of substring: String) -> Int {
      components(separatedBy: substring).dropFirst().count
    }
  }
  private extension URLRequest {
    /// The `timeoutInterval` a newly created `URLRequest` carries when none is specified.
    ///
    /// Obtained by building a throwaway request instead of hard-coding the platform default.
    static func defaultTimeoutInterval() -> TimeInterval {
      let throwawayRequest = URLRequest(url: URL(string: "https://example.com")!)
      return throwawayRequest.timeoutInterval
    }
  }
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
  class AppCheckInteropFake: NSObject, AppCheckInterop {
    /// Token reported alongside the error by fakes created with `init(error:)`.
    static let placeholderTokenValue = "placeholder-token"

    /// Token returned from every `getToken(forcingRefresh:)` call.
    var token: String
    /// Error returned from every `getToken(forcingRefresh:)` call; `nil` for success.
    var error: Error?

    private init(token: String, error: Error?) {
      self.token = token
      self.error = error
    }

    /// Creates a fake whose token results carry `token` and no error.
    convenience init(token: String) {
      self.init(token: token, error: nil)
    }

    /// Creates a fake whose token results carry `error` plus `placeholderTokenValue`.
    convenience init(error: Error) {
      self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
    }

    // MARK: AppCheckInterop

    /// Returns the configured token/error pair; `forcingRefresh` is ignored.
    func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
      AppCheckTokenResultInteropFake(token: token, error: error)
    }

    // The notification-related requirements are not exercised by these tests.

    func tokenDidChangeNotificationName() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationTokenKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    func notificationAppNameKey() -> String {
      fatalError("\(#function) not implemented.")
    }

    // MARK: Result fake

    /// Result object holding a fixed token/error pair.
    private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
      var token: String
      var error: Error?

      init(token: String, error: Error?) {
        self.token = token
        self.error = error
      }
    }
  }
  /// Stand-in error handed to `AppCheckInteropFake(error:)` to simulate a failed token fetch.
  struct AppCheckErrorFake: Error {
    // Intentionally empty: only the presence of an error matters to the tests.
  }
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
  extension SafetyRating: Comparable {
    /// Orders ratings by the raw value of their `category` only; all other fields are ignored.
    ///
    /// NOTE(review): presumably used to sort ratings before comparing them in tests; confirm
    /// against callers elsewhere in the file.
    public static func < (lhs: FirebaseVertexAI.SafetyRating,
                          rhs: FirebaseVertexAI.SafetyRating) -> Bool {
      let (lhsKey, rhsKey) = (lhs.category.rawValue, rhs.category.rawValue)
      return rhsKey > lhsKey
    }
  }