GenerationConfigTests.swift 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. import FirebaseVertexAI
  15. import Foundation
  16. import XCTest
  17. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  18. final class GenerationConfigTests: XCTestCase {
  19. let encoder = JSONEncoder()
  20. override func setUp() {
  21. encoder.outputFormatting = .init(
  22. arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
  23. )
  24. }
  25. // MARK: GenerationConfig Encoding
  26. func testEncodeGenerationConfig_default() throws {
  27. let generationConfig = GenerationConfig()
  28. let jsonData = try encoder.encode(generationConfig)
  29. let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
  30. XCTAssertEqual(json, """
  31. {
  32. }
  33. """)
  34. }
  35. func testEncodeGenerationConfig_allOptions() throws {
  36. let temperature: Float = 0.5
  37. let topP: Float = 0.75
  38. let topK = 40
  39. let candidateCount = 2
  40. let maxOutputTokens = 256
  41. let presencePenalty: Float = 0.5
  42. let frequencyPenalty: Float = 0.75
  43. let stopSequences = ["END", "DONE"]
  44. let responseMIMEType = "application/json"
  45. let responseModalities: [ContentModality] = [.text, .image, .audio]
  46. let generationConfig = GenerationConfig(
  47. temperature: temperature,
  48. topP: topP,
  49. topK: topK,
  50. candidateCount: candidateCount,
  51. maxOutputTokens: maxOutputTokens,
  52. presencePenalty: presencePenalty,
  53. frequencyPenalty: frequencyPenalty,
  54. stopSequences: stopSequences,
  55. responseMIMEType: responseMIMEType,
  56. responseSchema: .array(items: .string()),
  57. responseModalities: responseModalities
  58. )
  59. let jsonData = try encoder.encode(generationConfig)
  60. let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
  61. XCTAssertEqual(json, """
  62. {
  63. "candidateCount" : \(candidateCount),
  64. "frequencyPenalty" : \(frequencyPenalty),
  65. "maxOutputTokens" : \(maxOutputTokens),
  66. "presencePenalty" : \(presencePenalty),
  67. "responseMimeType" : "\(responseMIMEType)",
  68. "responseModalities" : [
  69. "TEXT",
  70. "IMAGE",
  71. "AUDIO"
  72. ],
  73. "responseSchema" : {
  74. "items" : {
  75. "nullable" : false,
  76. "type" : "STRING"
  77. },
  78. "nullable" : false,
  79. "type" : "ARRAY"
  80. },
  81. "stopSequences" : [
  82. "END",
  83. "DONE"
  84. ],
  85. "temperature" : \(temperature),
  86. "topK" : \(topK),
  87. "topP" : \(topP)
  88. }
  89. """)
  90. }
  91. func testEncodeGenerationConfig_jsonResponse() throws {
  92. let mimeType = "application/json"
  93. let generationConfig = GenerationConfig(
  94. responseMIMEType: mimeType,
  95. responseSchema: .object(properties: [
  96. "firstName": .string(),
  97. "lastName": .string(),
  98. "age": .integer(),
  99. ])
  100. )
  101. let jsonData = try encoder.encode(generationConfig)
  102. let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
  103. XCTAssertEqual(json, """
  104. {
  105. "responseMimeType" : "\(mimeType)",
  106. "responseSchema" : {
  107. "nullable" : false,
  108. "properties" : {
  109. "age" : {
  110. "nullable" : false,
  111. "type" : "INTEGER"
  112. },
  113. "firstName" : {
  114. "nullable" : false,
  115. "type" : "STRING"
  116. },
  117. "lastName" : {
  118. "nullable" : false,
  119. "type" : "STRING"
  120. }
  121. },
  122. "required" : [
  123. "age",
  124. "firstName",
  125. "lastName"
  126. ],
  127. "type" : "OBJECT"
  128. }
  129. }
  130. """)
  131. }
  132. func testEncodeGenerationConfig_with_responseModalities() throws {
  133. let responseModalities: [ContentModality] = [.text, .image, .audio]
  134. let generationConfig = GenerationConfig(responseModalities: responseModalities)
  135. let jsonData = try encoder.encode(generationConfig)
  136. let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
  137. XCTAssertEqual(json, """
  138. {
  139. "responseModalities" : [
  140. "TEXT",
  141. "IMAGE",
  142. "AUDIO"
  143. ]
  144. }
  145. """)
  146. }
  147. }
  148. func testEncodeGenerationConfig_jsonResponse() throws {
  149. let mimeType = "application/json"
  150. let generationConfig = GenerationConfig(
  151. responseMIMEType: mimeType,
  152. responseSchema: .object(properties: [
  153. "firstName": .string(),
  154. "lastName": .string(),
  155. "age": .integer(),
  156. ])
  157. )
  158. let jsonData = try encoder.encode(generationConfig)
  159. let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
  160. XCTAssertEqual(json, """
  161. {
  162. "responseMimeType" : "\(mimeType)",
  163. "responseSchema" : {
  164. "nullable" : false,
  165. "properties" : {
  166. "age" : {
  167. "nullable" : false,
  168. "type" : "INTEGER"
  169. },
  170. "firstName" : {
  171. "nullable" : false,
  172. "type" : "STRING"
  173. },
  174. "lastName" : {
  175. "nullable" : false,
  176. "type" : "STRING"
  177. }
  178. },
  179. "required" : [
  180. "age",
  181. "firstName",
  182. "lastName"
  183. ],
  184. "type" : "OBJECT"
  185. }
  186. }
  187. """)
  188. }
  189. }