VertexAIAPITests.swift 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseCore
  15. import FirebaseVertexAI
  16. import XCTest
  17. #if canImport(AppKit)
  18. import AppKit // For NSImage extensions.
  19. #elseif canImport(UIKit)
  20. import UIKit // For UIImage extensions.
  21. #endif
  22. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  23. final class VertexAIAPITests: XCTestCase {
  24. func codeSamples() async throws {
  25. let app = FirebaseApp.app()
  26. let config = GenerationConfig(temperature: 0.2,
  27. topP: 0.1,
  28. topK: 16,
  29. candidateCount: 4,
  30. maxOutputTokens: 256,
  31. stopSequences: ["..."],
  32. responseMIMEType: "text/plain")
  33. let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]
  34. let systemInstruction = ModelContent(
  35. role: "system",
  36. parts: TextPart("Talk like a pirate.")
  37. )
  38. // Instantiate Vertex AI SDK - Default App
  39. let vertexAI = VertexAI.vertexAI()
  40. let _ = VertexAI.vertexAI(location: "my-location")
  41. // Instantiate Vertex AI SDK - Custom App
  42. let _ = VertexAI.vertexAI(app: app!)
  43. let _ = VertexAI.vertexAI(app: app!, location: "my-location")
  44. // Permutations without optional arguments.
  45. let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro")
  46. let _ = vertexAI.generativeModel(
  47. modelName: "gemini-1.0-pro",
  48. safetySettings: filters
  49. )
  50. let _ = vertexAI.generativeModel(
  51. modelName: "gemini-1.0-pro",
  52. generationConfig: config
  53. )
  54. let _ = vertexAI.generativeModel(
  55. modelName: "gemini-1.0-pro",
  56. systemInstruction: systemInstruction
  57. )
  58. // All arguments passed.
  59. let genAI = vertexAI.generativeModel(
  60. modelName: "gemini-1.0-pro",
  61. generationConfig: config, // Optional
  62. safetySettings: filters, // Optional
  63. systemInstruction: systemInstruction // Optional
  64. )
  65. // Full Typed Usage
  66. let pngData = Data() // ....
  67. let contents = [ModelContent(
  68. role: "user",
  69. parts: [
  70. TextPart("Is it a cat?"),
  71. InlineDataPart(data: pngData, mimeType: "image/png"),
  72. ]
  73. )]
  74. do {
  75. let response = try await genAI.generateContent(contents)
  76. print(response.text ?? "Couldn't get text... check status")
  77. } catch {
  78. print("Error generating content: \(error)")
  79. }
  80. // Content input combinations.
  81. let _ = try await genAI.generateContent("Constant String")
  82. let str = "String Variable"
  83. let _ = try await genAI.generateContent(str)
  84. let _ = try await genAI.generateContent([str])
  85. let _ = try await genAI.generateContent(str, "abc", "def")
  86. let _ = try await genAI.generateContent(
  87. str,
  88. FileDataPart(uri: "gs://test-bucket/image.jpg", mimeType: "image/jpeg")
  89. )
  90. #if canImport(UIKit)
  91. _ = try await genAI.generateContent(UIImage())
  92. _ = try await genAI.generateContent([UIImage()])
  93. _ = try await genAI.generateContent([str, UIImage(), TextPart(str)])
  94. _ = try await genAI.generateContent(str, UIImage(), "def", UIImage())
  95. _ = try await genAI.generateContent([str, UIImage(), "def", UIImage()])
  96. _ = try await genAI.generateContent([ModelContent(parts: "def", UIImage()),
  97. ModelContent(parts: "def", UIImage())])
  98. #elseif canImport(AppKit)
  99. _ = try await genAI.generateContent(NSImage())
  100. _ = try await genAI.generateContent([NSImage()])
  101. _ = try await genAI.generateContent(str, NSImage(), "def", NSImage())
  102. _ = try await genAI.generateContent([str, NSImage(), "def", NSImage()])
  103. #endif
  104. // PartsRepresentable combinations.
  105. let _ = ModelContent(parts: [TextPart(str)])
  106. let _ = ModelContent(role: "model", parts: [TextPart(str)])
  107. let _ = ModelContent(parts: "Constant String")
  108. let _ = ModelContent(parts: str)
  109. let _ = ModelContent(parts: [str])
  110. let _ = ModelContent(parts: [str, InlineDataPart(data: Data(), mimeType: "foo")])
  111. #if canImport(UIKit)
  112. _ = ModelContent(role: "user", parts: UIImage())
  113. _ = ModelContent(role: "user", parts: [UIImage()])
  114. _ = ModelContent(parts: [str, UIImage()])
  115. // Note: without explicitly specifying`: [any PartsRepresentable]` this will fail to compile
  116. // below with "Cannot convert value of type `[Any]` to expected type `[any Part]`.
  117. let representable2: [any PartsRepresentable] = [str, UIImage()]
  118. _ = ModelContent(parts: representable2)
  119. _ = ModelContent(parts: [str, UIImage(), TextPart(str)])
  120. #elseif canImport(AppKit)
  121. _ = ModelContent(role: "user", parts: NSImage())
  122. _ = ModelContent(role: "user", parts: [NSImage()])
  123. _ = ModelContent(parts: [str, NSImage()])
  124. // Note: without explicitly specifying`: [any PartsRepresentable]` this will fail to compile
  125. // below with "Cannot convert value of type `[Any]` to expected type `[any Part]`.
  126. let representable2: [any PartsRepresentable] = [str, NSImage()]
  127. _ = ModelContent(parts: representable2)
  128. _ = ModelContent(parts: [str, NSImage(), TextPart(str)])
  129. #endif
  130. // countTokens API
  131. let _: CountTokensResponse = try await genAI.countTokens("What color is the Sky?")
  132. #if canImport(UIKit)
  133. let _: CountTokensResponse = try await genAI.countTokens("What color is the Sky?",
  134. UIImage())
  135. let _: CountTokensResponse = try await genAI.countTokens([
  136. ModelContent(parts: "What color is the Sky?", UIImage()),
  137. ModelContent(parts: UIImage(), "What color is the Sky?", UIImage()),
  138. ])
  139. #endif
  140. // Chat
  141. _ = genAI.startChat()
  142. _ = genAI.startChat(history: [ModelContent(parts: "abc")])
  143. }
  144. // Public API tests for GenerateContentResponse.
  145. func generateContentResponseAPI() {
  146. let response = GenerateContentResponse(candidates: [])
  147. let _: [CandidateResponse] = response.candidates
  148. let _: PromptFeedback? = response.promptFeedback
  149. // Usage Metadata
  150. guard let usageMetadata = response.usageMetadata else { fatalError() }
  151. let _: Int = usageMetadata.promptTokenCount
  152. let _: Int = usageMetadata.candidatesTokenCount
  153. let _: Int = usageMetadata.totalTokenCount
  154. // Computed Properties
  155. let _: String? = response.text
  156. let _: [FunctionCallPart] = response.functionCalls
  157. }
  158. // Result builder alternative
  159. /*
  160. let pngData = Data() // ....
  161. let contents = [GenAIContent(role: "user",
  162. parts: [
  163. .text("Is it a cat?"),
  164. .png(pngData)
  165. ])]
  166. // Turns into...
  167. let contents = GenAIContent {
  168. Role("user") {
  169. Text("Is this a cat?")
  170. Image(png: pngData)
  171. }
  172. }
  173. GenAIContent {
  174. ForEach(myInput) { input in
  175. Role(input.role) {
  176. input.contents
  177. }
  178. }
  179. }
  180. // Thoughts: this looks great from a code demo, but since I assume most content will be
  181. // user generated, the result builder may not be the best API.
  182. */
  183. }