GenerativeModelGoogleAITests.swift 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. // Copyright 2025 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAppCheckInterop
  15. import FirebaseAuthInterop
  16. import FirebaseCore
  17. import XCTest
  18. @testable import FirebaseAI
  19. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  20. final class GenerativeModelGoogleAITests: XCTestCase {
  21. let testPrompt = "What sorts of questions can I ask you?"
  22. let safetyRatingsNegligible: [SafetyRating] = [
  23. .init(
  24. category: .sexuallyExplicit,
  25. probability: .negligible,
  26. probabilityScore: 0.0,
  27. severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
  28. severityScore: 0.0,
  29. blocked: false
  30. ),
  31. .init(
  32. category: .hateSpeech,
  33. probability: .negligible,
  34. probabilityScore: 0.0,
  35. severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
  36. severityScore: 0.0,
  37. blocked: false
  38. ),
  39. .init(
  40. category: .harassment,
  41. probability: .negligible,
  42. probabilityScore: 0.0,
  43. severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
  44. severityScore: 0.0,
  45. blocked: false
  46. ),
  47. .init(
  48. category: .dangerousContent,
  49. probability: .negligible,
  50. probabilityScore: 0.0,
  51. severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
  52. severityScore: 0.0,
  53. blocked: false
  54. ),
  55. ].sorted()
  56. let testModelName = "test-model"
  57. let testModelResourceName = "projects/test-project-id/models/test-model"
  58. let apiConfig = FirebaseAI.defaultVertexAIAPIConfig
  59. let googleAISubdirectory = "mock-responses/googleai"
  60. var urlSession: URLSession!
  61. var model: GenerativeModel!
  62. override func setUp() async throws {
  63. let configuration = URLSessionConfiguration.default
  64. configuration.protocolClasses = [MockURLProtocol.self]
  65. urlSession = try XCTUnwrap(URLSession(configuration: configuration))
  66. model = GenerativeModel(
  67. modelName: testModelName,
  68. modelResourceName: testModelResourceName,
  69. firebaseInfo: GenerativeModelTestUtil.testFirebaseInfo(),
  70. apiConfig: apiConfig,
  71. tools: nil,
  72. requestOptions: RequestOptions(),
  73. urlSession: urlSession
  74. )
  75. }
  76. override func tearDown() {
  77. MockURLProtocol.requestHandler = nil
  78. }
  79. // MARK: - Generate Content
  80. func testGenerateContent_success_basicReplyLong() async throws {
  81. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  82. forResource: "unary-success-basic-reply-long",
  83. withExtension: "json",
  84. subdirectory: googleAISubdirectory
  85. )
  86. let response = try await model.generateContent(testPrompt)
  87. XCTAssertEqual(response.candidates.count, 1)
  88. let candidate = try XCTUnwrap(response.candidates.first)
  89. let finishReason = try XCTUnwrap(candidate.finishReason)
  90. XCTAssertEqual(finishReason, .stop)
  91. XCTAssertEqual(candidate.safetyRatings.count, 4)
  92. XCTAssertEqual(candidate.content.parts.count, 1)
  93. let part = try XCTUnwrap(candidate.content.parts.first)
  94. let partText = try XCTUnwrap(part as? TextPart).text
  95. XCTAssertTrue(partText.hasPrefix("Making professional-quality"))
  96. XCTAssertEqual(response.text, partText)
  97. XCTAssertEqual(response.functionCalls, [])
  98. XCTAssertEqual(response.inlineDataParts, [])
  99. }
  100. func testGenerateContent_success_basicReplyShort() async throws {
  101. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  102. forResource: "unary-success-basic-reply-short",
  103. withExtension: "json",
  104. subdirectory: googleAISubdirectory
  105. )
  106. let response = try await model.generateContent(testPrompt)
  107. XCTAssertEqual(response.candidates.count, 1)
  108. let candidate = try XCTUnwrap(response.candidates.first)
  109. let finishReason = try XCTUnwrap(candidate.finishReason)
  110. XCTAssertEqual(finishReason, .stop)
  111. XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible)
  112. XCTAssertEqual(candidate.content.parts.count, 1)
  113. let part = try XCTUnwrap(candidate.content.parts.first)
  114. let textPart = try XCTUnwrap(part as? TextPart)
  115. XCTAssertTrue(textPart.text.hasPrefix("Google's headquarters"))
  116. XCTAssertEqual(response.text, textPart.text)
  117. XCTAssertEqual(response.functionCalls, [])
  118. XCTAssertEqual(response.inlineDataParts, [])
  119. }
  120. func testGenerateContent_success_citations() async throws {
  121. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  122. forResource: "unary-success-citations",
  123. withExtension: "json",
  124. subdirectory: googleAISubdirectory
  125. )
  126. let response = try await model.generateContent(testPrompt)
  127. XCTAssertEqual(response.candidates.count, 1)
  128. let candidate = try XCTUnwrap(response.candidates.first)
  129. XCTAssertEqual(candidate.content.parts.count, 1)
  130. let text = try XCTUnwrap(response.text)
  131. XCTAssertTrue(text.hasPrefix("Okay, let's break down quantum mechanics."))
  132. let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
  133. XCTAssertEqual(citationMetadata.citations.count, 4)
  134. let citationSource1 = try XCTUnwrap(citationMetadata.citations[0])
  135. XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
  136. XCTAssertEqual(citationSource1.startIndex, 548)
  137. XCTAssertEqual(citationSource1.endIndex, 690)
  138. XCTAssertNil(citationSource1.title)
  139. XCTAssertEqual(citationSource1.license, "mit")
  140. XCTAssertNil(citationSource1.publicationDate)
  141. let citationSource2 = try XCTUnwrap(citationMetadata.citations[1])
  142. XCTAssertEqual(citationSource2.uri, "https://www.example.com/some-citation-1")
  143. XCTAssertEqual(citationSource2.startIndex, 1240)
  144. XCTAssertEqual(citationSource2.endIndex, 1407)
  145. XCTAssertNil(citationSource2.title, "some-citation-2")
  146. XCTAssertNil(citationSource2.license)
  147. XCTAssertNil(citationSource2.publicationDate)
  148. let citationSource3 = try XCTUnwrap(citationMetadata.citations[2])
  149. XCTAssertEqual(citationSource3.startIndex, 1942)
  150. XCTAssertEqual(citationSource3.endIndex, 2149)
  151. XCTAssertNil(citationSource3.uri)
  152. XCTAssertNil(citationSource3.license)
  153. XCTAssertNil(citationSource3.title)
  154. XCTAssertNil(citationSource3.publicationDate)
  155. let citationSource4 = try XCTUnwrap(citationMetadata.citations[3])
  156. XCTAssertEqual(citationSource4.startIndex, 2036)
  157. XCTAssertEqual(citationSource4.endIndex, 2175)
  158. XCTAssertNil(citationSource4.uri)
  159. XCTAssertNil(citationSource4.license)
  160. XCTAssertNil(citationSource4.title)
  161. XCTAssertNil(citationSource4.publicationDate)
  162. }
  163. func testGenerateContent_usageMetadata() async throws {
  164. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  165. forResource: "unary-success-basic-reply-short",
  166. withExtension: "json",
  167. subdirectory: googleAISubdirectory
  168. )
  169. let response = try await model.generateContent(testPrompt)
  170. let usageMetadata = try XCTUnwrap(response.usageMetadata)
  171. XCTAssertEqual(usageMetadata.promptTokenCount, 7)
  172. XCTAssertEqual(usageMetadata.promptTokensDetails.count, 1)
  173. XCTAssertEqual(usageMetadata.promptTokensDetails[0].modality, .text)
  174. XCTAssertEqual(usageMetadata.promptTokensDetails[0].tokenCount, 7)
  175. XCTAssertEqual(usageMetadata.candidatesTokenCount, 22)
  176. XCTAssertEqual(usageMetadata.candidatesTokensDetails.count, 1)
  177. XCTAssertEqual(usageMetadata.candidatesTokensDetails[0].modality, .text)
  178. XCTAssertEqual(usageMetadata.candidatesTokensDetails[0].tokenCount, 22)
  179. }
  180. func testGenerateContent_failure_invalidAPIKey() async throws {
  181. let expectedStatusCode = 400
  182. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  183. forResource: "unary-failure-api-key",
  184. withExtension: "json",
  185. subdirectory: googleAISubdirectory,
  186. statusCode: expectedStatusCode
  187. )
  188. do {
  189. _ = try await model.generateContent(testPrompt)
  190. XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
  191. } catch let GenerateContentError.internalError(error as BackendError) {
  192. XCTAssertEqual(error.httpResponseCode, 400)
  193. XCTAssertEqual(error.status, .invalidArgument)
  194. XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
  195. XCTAssertTrue(error.localizedDescription.contains(error.message))
  196. XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
  197. XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
  198. let nsError = error as NSError
  199. XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
  200. XCTAssertEqual(nsError.code, error.httpResponseCode)
  201. return
  202. } catch {
  203. XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
  204. }
  205. }
  206. func testGenerateContent_failure_finishReasonSafety() async throws {
  207. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  208. forResource: "unary-failure-finish-reason-safety",
  209. withExtension: "json",
  210. subdirectory: googleAISubdirectory
  211. )
  212. do {
  213. _ = try await model.generateContent(testPrompt)
  214. XCTFail("Should throw")
  215. } catch let GenerateContentError.responseStoppedEarly(reason, response) {
  216. XCTAssertEqual(reason, .safety)
  217. XCTAssertEqual(response.text, "Safety error incoming in 5, 4, 3, 2...")
  218. } catch {
  219. XCTFail("Should throw a responseStoppedEarly")
  220. }
  221. }
  222. func testGenerateContent_failure_unknownModel() async throws {
  223. let expectedStatusCode = 404
  224. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  225. forResource: "unary-failure-unknown-model",
  226. withExtension: "json",
  227. subdirectory: googleAISubdirectory,
  228. statusCode: 404
  229. )
  230. do {
  231. _ = try await model.generateContent(testPrompt)
  232. XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
  233. } catch let GenerateContentError.internalError(underlying: rpcError as BackendError) {
  234. XCTAssertEqual(rpcError.status, .notFound)
  235. XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
  236. XCTAssertTrue(rpcError.message.hasPrefix("models/gemini-5.0-flash is not found"))
  237. } catch {
  238. XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
  239. }
  240. }
  241. // MARK: - Generate Content (Streaming)
  242. func testGenerateContentStream_successBasicReplyLong() async throws {
  243. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  244. forResource: "streaming-success-basic-reply-long",
  245. withExtension: "txt",
  246. subdirectory: googleAISubdirectory
  247. )
  248. var responses = 0
  249. let stream = try model.generateContentStream("Hi")
  250. for try await content in stream {
  251. XCTAssertNotNil(content.text)
  252. responses += 1
  253. }
  254. XCTAssertEqual(responses, 36)
  255. }
  256. func testGenerateContentStream_successBasicReplyShort() async throws {
  257. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  258. forResource: "streaming-success-basic-reply-short",
  259. withExtension: "txt",
  260. subdirectory: googleAISubdirectory
  261. )
  262. var responses = 0
  263. let stream = try model.generateContentStream("Hi")
  264. for try await content in stream {
  265. XCTAssertNotNil(content.text)
  266. responses += 1
  267. }
  268. XCTAssertEqual(responses, 3)
  269. }
  270. func testGenerateContentStream_successWithCitations() async throws {
  271. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  272. forResource: "streaming-success-citations",
  273. withExtension: "txt",
  274. subdirectory: googleAISubdirectory
  275. )
  276. let stream = try model.generateContentStream("Hi")
  277. var citations = [Citation]()
  278. var responses = [GenerateContentResponse]()
  279. for try await content in stream {
  280. responses.append(content)
  281. XCTAssertNotNil(content.text)
  282. let candidate = try XCTUnwrap(content.candidates.first)
  283. if let sources = candidate.citationMetadata?.citations {
  284. citations.append(contentsOf: sources)
  285. }
  286. }
  287. let lastCandidate = try XCTUnwrap(responses.last?.candidates.first)
  288. XCTAssertEqual(lastCandidate.finishReason, .stop)
  289. XCTAssertEqual(citations.count, 1)
  290. let citation = try XCTUnwrap(citations.first)
  291. XCTAssertEqual(citation.startIndex, 111)
  292. XCTAssertEqual(citation.endIndex, 236)
  293. let citationURI = try XCTUnwrap(citation.uri)
  294. XCTAssertTrue(citationURI.starts(with: "https://www."))
  295. XCTAssertNil(citation.license)
  296. XCTAssertNil(citation.title)
  297. XCTAssertNil(citation.publicationDate)
  298. }
  299. func testGenerateContentStream_failureInvalidAPIKey() async throws {
  300. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  301. forResource: "unary-failure-api-key",
  302. withExtension: "json",
  303. subdirectory: googleAISubdirectory
  304. )
  305. do {
  306. let stream = try model.generateContentStream("Hi")
  307. for try await _ in stream {
  308. XCTFail("No content is there, this shouldn't happen.")
  309. }
  310. } catch let GenerateContentError.internalError(error as BackendError) {
  311. XCTAssertEqual(error.httpResponseCode, 400)
  312. XCTAssertEqual(error.status, .invalidArgument)
  313. XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
  314. XCTAssertTrue(error.localizedDescription.contains(error.message))
  315. XCTAssertTrue(error.localizedDescription.contains(error.status.rawValue))
  316. XCTAssertTrue(error.localizedDescription.contains("\(error.httpResponseCode)"))
  317. let nsError = error as NSError
  318. XCTAssertEqual(nsError.domain, "\(Constants.baseErrorDomain).\(BackendError.self)")
  319. XCTAssertEqual(nsError.code, error.httpResponseCode)
  320. return
  321. }
  322. XCTFail("Should have caught an error.")
  323. }
  324. func testGenerateContentStream_failureFinishRecitation() async throws {
  325. MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
  326. forResource: "streaming-failure-recitation-no-content",
  327. withExtension: "txt",
  328. subdirectory: googleAISubdirectory
  329. )
  330. var responses = [GenerateContentResponse]()
  331. do {
  332. let stream = try model.generateContentStream("Hi")
  333. for try await response in stream {
  334. responses.append(response)
  335. }
  336. XCTFail("Expected a GenerateContentError.responseStoppedEarly error, but got no error.")
  337. } catch let GenerateContentError.responseStoppedEarly(reason, response) {
  338. XCTAssertEqual(reason, .recitation)
  339. let candidate = try XCTUnwrap(response.candidates.first)
  340. XCTAssertEqual(candidate.finishReason, reason)
  341. } catch {
  342. XCTFail("Expected a GenerateContentError.responseStoppedEarly error, but got error: \(error)")
  343. }
  344. XCTAssertEqual(responses.count, 8)
  345. let firstResponse = try XCTUnwrap(responses.first)
  346. XCTAssertEqual(firstResponse.text, "text1")
  347. let lastResponse = try XCTUnwrap(responses.last)
  348. XCTAssertEqual(lastResponse.text, "text8")
  349. }
  350. }