IntegrationTests.swift 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. // Copyright 2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
import FirebaseCore
import FirebaseVertexAI
import XCTest
#if canImport(UIKit)
  import UIKit
#endif // canImport(UIKit)
  17. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  18. final class IntegrationTests: XCTestCase {
  19. // Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
  20. let generationConfig = GenerationConfig(
  21. temperature: 0.0,
  22. topP: 0.0,
  23. topK: 1,
  24. responseMIMEType: "text/plain"
  25. )
  26. let systemInstruction = ModelContent(
  27. role: "system",
  28. parts: "You are a friendly and helpful assistant."
  29. )
  30. let safetySettings = [
  31. SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
  32. SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
  33. SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove),
  34. SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
  35. SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
  36. ]
  37. var vertex: VertexAI!
  38. var model: GenerativeModel!
  39. override func setUp() async throws {
  40. try XCTSkipIf(ProcessInfo.processInfo.environment["VertexAIRunIntegrationTests"] == nil, """
  41. Vertex AI integration tests skipped; to enable them, set the VertexAIRunIntegrationTests \
  42. environment variable in Xcode or CI jobs.
  43. """)
  44. let plistPath = try XCTUnwrap(Bundle.module.path(
  45. forResource: "GoogleService-Info",
  46. ofType: "plist"
  47. ))
  48. let options = try XCTUnwrap(FirebaseOptions(contentsOfFile: plistPath))
  49. FirebaseApp.configure(options: options)
  50. vertex = VertexAI.vertexAI()
  51. model = vertex.generativeModel(
  52. modelName: "gemini-1.5-flash",
  53. generationConfig: generationConfig,
  54. safetySettings: safetySettings,
  55. tools: [],
  56. toolConfig: .init(functionCallingConfig: .init(mode: FunctionCallingConfig.Mode.none)),
  57. systemInstruction: systemInstruction
  58. )
  59. }
  60. override func tearDown() async throws {
  61. if let app = FirebaseApp.app() {
  62. await app.delete()
  63. }
  64. }
  65. // MARK: - Generate Content
  66. func testGenerateContent() async throws {
  67. let prompt = "Where is Google headquarters located? Answer with the city name only."
  68. let response = try await model.generateContent(prompt)
  69. let text = try XCTUnwrap(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
  70. XCTAssertEqual(text, "Mountain View")
  71. }
  72. // MARK: - Count Tokens
  73. func testCountTokens_text() async throws {
  74. let prompt = "Why is the sky blue?"
  75. model = vertex.generativeModel(
  76. modelName: "gemini-1.5-pro",
  77. generationConfig: generationConfig,
  78. safetySettings: [
  79. SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
  80. SafetySetting(harmCategory: .hateSpeech, threshold: .blockMediumAndAbove),
  81. SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockOnlyHigh),
  82. SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone),
  83. SafetySetting(harmCategory: .civicIntegrity, threshold: .off),
  84. ],
  85. toolConfig: .init(functionCallingConfig: .init(mode: .auto)),
  86. systemInstruction: systemInstruction
  87. )
  88. let response = try await model.countTokens(prompt)
  89. XCTAssertEqual(response.totalTokens, 14)
  90. XCTAssertEqual(response.totalBillableCharacters, 51)
  91. }
  92. func testCountTokens_image_inlineData() async throws {
  93. guard let image = UIImage(systemName: "cloud") else {
  94. XCTFail("Image not found.")
  95. return
  96. }
  97. let response = try await model.countTokens(image)
  98. XCTAssertEqual(response.totalTokens, 266)
  99. XCTAssertEqual(response.totalBillableCharacters, 35)
  100. }
  101. func testCountTokens_image_fileData() async throws {
  102. let fileData = FileDataPart(
  103. uri: "gs://ios-opensource-samples.appspot.com/ios/public/blank.jpg",
  104. mimeType: "image/jpeg"
  105. )
  106. let response = try await model.countTokens(fileData)
  107. XCTAssertEqual(response.totalTokens, 266)
  108. XCTAssertEqual(response.totalBillableCharacters, 35)
  109. }
  110. func testCountTokens_functionCalling() async throws {
  111. let sumDeclaration = FunctionDeclaration(
  112. name: "sum",
  113. description: "Adds two integers.",
  114. parameters: ["x": .integer(), "y": .integer()]
  115. )
  116. model = vertex.generativeModel(
  117. modelName: "gemini-1.5-flash",
  118. tools: [Tool(functionDeclarations: [sumDeclaration])],
  119. toolConfig: .init(functionCallingConfig: .init(mode: .any, allowedFunctionNames: ["sum"]))
  120. )
  121. let prompt = "What is 10 + 32?"
  122. let sumCall = FunctionCallPart(name: "sum", args: ["x": .number(10), "y": .number(32)])
  123. let sumResponse = FunctionResponsePart(name: "sum", response: ["result": .number(42)])
  124. let response = try await model.countTokens([
  125. ModelContent(role: "user", parts: prompt),
  126. ModelContent(role: "model", parts: sumCall),
  127. ModelContent(role: "function", parts: sumResponse),
  128. ])
  129. XCTAssertEqual(response.totalTokens, 24)
  130. XCTAssertEqual(response.totalBillableCharacters, 71)
  131. }
  132. func testCountTokens_jsonSchema() async throws {
  133. model = vertex.generativeModel(
  134. modelName: "gemini-1.5-flash",
  135. generationConfig: GenerationConfig(
  136. responseMIMEType: "application/json",
  137. responseSchema: Schema.object(properties: [
  138. "startDate": .string(format: .custom("date")),
  139. "yearsSince": .integer(format: .custom("int16")),
  140. "hoursSince": .integer(format: .int32),
  141. "minutesSince": .integer(format: .int64),
  142. ])
  143. )
  144. )
  145. let prompt = "It is 2050-01-01, how many years, hours and minutes since 2000-01-01?"
  146. let response = try await model.countTokens(prompt)
  147. XCTAssertEqual(response.totalTokens, 34)
  148. XCTAssertEqual(response.totalBillableCharacters, 59)
  149. }
  150. }