APITests.swift 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAI
  15. import FirebaseCore
  16. import XCTest
  17. #if canImport(AppKit)
  18. import AppKit // For NSImage extensions.
  19. #elseif canImport(UIKit)
  20. import UIKit // For UIImage extensions.
  21. #endif
  /// Compile-only checks of the FirebaseAI public API surface.
  ///
  /// None of the methods in this class start with `test`, so XCTest does not discover or run
  /// them; they only need to type-check, which makes a breaking change to a public signature
  /// fail the build.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  final class APITests: XCTestCase {
    /// Exercises model configuration, SDK/model instantiation, and the `generateContent`,
    /// `countTokens`, `ModelContent`, and `startChat` overloads, including platform image
    /// inputs on both UIKit and AppKit.
    func codeSamples() async throws {
      // Unwrap the default app once (throwing instead of trapping) rather than
      // force-unwrapping it at every custom-app call site below.
      let app = try XCTUnwrap(FirebaseApp.app())
      let config = GenerationConfig(temperature: 0.2,
                                    topP: 0.1,
                                    topK: 16,
                                    candidateCount: 4,
                                    maxOutputTokens: 256,
                                    stopSequences: ["..."],
                                    responseMIMEType: "text/plain")
      let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]
      let systemInstruction = ModelContent(
        role: "system",
        parts: TextPart("Talk like a pirate.")
      )
      let requestOptions = RequestOptions()
      _ = RequestOptions(timeout: 30.0)

      // Instantiate Firebase AI SDK - Default App
      let firebaseAI = FirebaseAI.firebaseAI()
      _ = FirebaseAI.firebaseAI(backend: .googleAI())
      _ = FirebaseAI.firebaseAI(backend: .vertexAI())
      _ = FirebaseAI.firebaseAI(backend: .vertexAI(location: "my-location"))

      // Instantiate Firebase AI SDK - Custom App
      _ = FirebaseAI.firebaseAI(app: app)
      _ = FirebaseAI.firebaseAI(app: app, backend: .googleAI())
      _ = FirebaseAI.firebaseAI(app: app, backend: .vertexAI())
      _ = FirebaseAI.firebaseAI(app: app, backend: .vertexAI(location: "my-location"))

      // Permutations without optional arguments.
      _ = firebaseAI.generativeModel(modelName: "gemini-2.0-flash")
      _ = firebaseAI.generativeModel(
        modelName: "gemini-2.0-flash",
        safetySettings: filters
      )
      _ = firebaseAI.generativeModel(
        modelName: "gemini-2.0-flash",
        generationConfig: config
      )
      _ = firebaseAI.generativeModel(
        modelName: "gemini-2.0-flash",
        systemInstruction: systemInstruction
      )

      // All arguments passed.
      let model = firebaseAI.generativeModel(
        modelName: "gemini-2.0-flash",
        generationConfig: config, // Optional
        safetySettings: filters, // Optional
        systemInstruction: systemInstruction, // Optional
        requestOptions: requestOptions // Optional
      )

      // Full Typed Usage
      let pngData = Data() // Placeholder; real image bytes would go here.
      let contents = [ModelContent(
        role: "user",
        parts: [
          TextPart("Is it a cat?"),
          InlineDataPart(data: pngData, mimeType: "image/png"),
        ]
      )]
      do {
        let response = try await model.generateContent(contents)
        print(response.text ?? "Couldn't get text... check status")
      } catch {
        print("Error generating content: \(error)")
      }

      // Content input combinations.
      _ = try await model.generateContent("Constant String")
      let str = "String Variable"
      _ = try await model.generateContent(str)
      _ = try await model.generateContent([str])
      _ = try await model.generateContent(str, "abc", "def")
      _ = try await model.generateContent(
        str,
        FileDataPart(uri: "gs://test-bucket/image.jpg", mimeType: "image/jpeg")
      )
      #if canImport(UIKit)
        _ = try await model.generateContent(UIImage())
        _ = try await model.generateContent([UIImage()])
        _ = try await model.generateContent([str, UIImage(), TextPart(str)])
        _ = try await model.generateContent(str, UIImage(), "def", UIImage())
        _ = try await model.generateContent([str, UIImage(), "def", UIImage()])
        _ = try await model.generateContent([ModelContent(parts: "def", UIImage()),
                                             ModelContent(parts: "def", UIImage())])
      #elseif canImport(AppKit)
        // Mirrors the UIKit branch so the NSImage overloads get the same coverage.
        _ = try await model.generateContent(NSImage())
        _ = try await model.generateContent([NSImage()])
        _ = try await model.generateContent([str, NSImage(), TextPart(str)])
        _ = try await model.generateContent(str, NSImage(), "def", NSImage())
        _ = try await model.generateContent([str, NSImage(), "def", NSImage()])
        _ = try await model.generateContent([ModelContent(parts: "def", NSImage()),
                                             ModelContent(parts: "def", NSImage())])
      #endif

      // PartsRepresentable combinations.
      _ = ModelContent(parts: [TextPart(str)])
      _ = ModelContent(role: "model", parts: [TextPart(str)])
      _ = ModelContent(parts: "Constant String")
      _ = ModelContent(parts: str)
      _ = ModelContent(parts: [str])
      _ = ModelContent(parts: [str, InlineDataPart(data: Data(), mimeType: "foo")])
      #if canImport(UIKit)
        _ = ModelContent(role: "user", parts: UIImage())
        _ = ModelContent(role: "user", parts: [UIImage()])
        _ = ModelContent(parts: [str, UIImage()])
        // Note: without the explicit `: [any PartsRepresentable]` annotation, the call below
        // fails to compile with "Cannot convert value of type `[Any]` to expected type
        // `[any Part]`".
        let representable2: [any PartsRepresentable] = [str, UIImage()]
        _ = ModelContent(parts: representable2)
        _ = ModelContent(parts: [str, UIImage(), TextPart(str)])
      #elseif canImport(AppKit)
        _ = ModelContent(role: "user", parts: NSImage())
        _ = ModelContent(role: "user", parts: [NSImage()])
        _ = ModelContent(parts: [str, NSImage()])
        // Note: without the explicit `: [any PartsRepresentable]` annotation, the call below
        // fails to compile with "Cannot convert value of type `[Any]` to expected type
        // `[any Part]`".
        let representable2: [any PartsRepresentable] = [str, NSImage()]
        _ = ModelContent(parts: representable2)
        _ = ModelContent(parts: [str, NSImage(), TextPart(str)])
      #endif

      // countTokens API
      let _: CountTokensResponse = try await model.countTokens("What color is the Sky?")
      #if canImport(UIKit)
        let _: CountTokensResponse = try await model.countTokens("What color is the Sky?",
                                                                 UIImage())
        let _: CountTokensResponse = try await model.countTokens([
          ModelContent(parts: "What color is the Sky?", UIImage()),
          ModelContent(parts: UIImage(), "What color is the Sky?", UIImage()),
        ])
      #elseif canImport(AppKit)
        // Previously missing: the image overloads of countTokens were only checked on UIKit.
        let _: CountTokensResponse = try await model.countTokens("What color is the Sky?",
                                                                 NSImage())
        let _: CountTokensResponse = try await model.countTokens([
          ModelContent(parts: "What color is the Sky?", NSImage()),
          ModelContent(parts: NSImage(), "What color is the Sky?", NSImage()),
        ])
      #endif

      // Chat
      _ = model.startChat()
      _ = model.startChat(history: [ModelContent(parts: "abc")])
    }

    /// Public API tests for `GenerateContentResponse`: checks the types of its stored and
    /// computed properties and of its `usageMetadata` fields.
    func generateContentResponseAPI() {
      let response = GenerateContentResponse(candidates: [])
      let _: [Candidate] = response.candidates
      let _: PromptFeedback? = response.promptFeedback

      // Usage Metadata
      // `usageMetadata` is optional; it must be unwrapped to reach its non-optional counts.
      guard let usageMetadata = response.usageMetadata else {
        fatalError("GenerateContentResponse.usageMetadata is nil")
      }
      let _: Int = usageMetadata.promptTokenCount
      let _: Int = usageMetadata.candidatesTokenCount
      let _: Int = usageMetadata.totalTokenCount

      // Computed Properties
      let _: String? = response.text
      let _: [FunctionCallPart] = response.functionCalls
    }
  }