GenerationConfigTests.swift 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. // Copyright 2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseVertexAI
  15. import Foundation
  16. import XCTest
  /// Verifies that `GenerationConfig` encodes to the expected JSON.
  ///
  /// Each test compares the encoder's complete output against a string literal, so the shared
  /// encoder is configured in `setUp()` to produce deterministic, pretty-printed JSON.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  final class GenerationConfigTests: XCTestCase {
    let encoder = JSONEncoder()

    override func setUp() {
      super.setUp()
      // `.sortedKeys` makes key order stable so a whole-document string comparison is reliable;
      // `.withoutEscapingSlashes` keeps MIME types such as "application/json" from being written
      // as "application\/json".
      encoder.outputFormatting = [.prettyPrinted, .sortedKeys, .withoutEscapingSlashes]
    }

    // MARK: GenerationConfig Encoding

    /// A config with no options set is expected to encode to an empty JSON object (no keys).
    func testEncodeGenerationConfig_default() throws {
      let generationConfig = GenerationConfig()

      let jsonData = try encoder.encode(generationConfig)

      let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
      // NOTE(review): Foundation's pretty printer writes an empty object as "{", an empty line,
      // then "}" — confirm on every supported platform if this expectation ever differs.
      XCTAssertEqual(json, """
      {

      }
      """)
    }

    /// Setting every option is expected to emit each one under its wire key.
    ///
    /// The expectation pins a few non-obvious encodings: `responseMIMEType` is written as
    /// `responseMimeType`, response modalities are upper-case strings, and schemas carry an
    /// explicit `"nullable" : false`.
    func testEncodeGenerationConfig_allOptions() throws {
      let temperature: Float = 0.5
      let topP: Float = 0.75
      let topK = 40
      let candidateCount = 2
      let maxOutputTokens = 256
      let presencePenalty: Float = 0.5
      let frequencyPenalty: Float = 0.75
      let stopSequences = ["END", "DONE"]
      let responseMIMEType = "application/json"

      let generationConfig = GenerationConfig(
        temperature: temperature,
        topP: topP,
        topK: topK,
        candidateCount: candidateCount,
        maxOutputTokens: maxOutputTokens,
        presencePenalty: presencePenalty,
        frequencyPenalty: frequencyPenalty,
        stopSequences: stopSequences,
        responseMIMEType: responseMIMEType,
        responseSchema: .array(items: .string()),
        responseModalities: [.text, .image]
      )

      let jsonData = try encoder.encode(generationConfig)

      let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
      // Pretty-printed output indents each nesting level by two spaces; the literal mirrors that.
      XCTAssertEqual(json, """
      {
        "candidateCount" : \(candidateCount),
        "frequencyPenalty" : \(frequencyPenalty),
        "maxOutputTokens" : \(maxOutputTokens),
        "presencePenalty" : \(presencePenalty),
        "responseMimeType" : "\(responseMIMEType)",
        "responseModalities" : [
          "TEXT",
          "IMAGE"
        ],
        "responseSchema" : {
          "items" : {
            "nullable" : false,
            "type" : "STRING"
          },
          "nullable" : false,
          "type" : "ARRAY"
        },
        "stopSequences" : [
          "END",
          "DONE"
        ],
        "temperature" : \(temperature),
        "topK" : \(topK),
        "topP" : \(topP)
      }
      """)
    }

    /// A JSON response config with an object schema.
    ///
    /// No `required` list is passed to `.object(properties:)`, yet the expected output lists
    /// every property under `"required"` (in sorted order).
    func testEncodeGenerationConfig_jsonResponse() throws {
      let mimeType = "application/json"
      let generationConfig = GenerationConfig(
        responseMIMEType: mimeType,
        responseSchema: .object(properties: [
          "firstName": .string(),
          "lastName": .string(),
          "age": .integer(),
        ])
      )

      let jsonData = try encoder.encode(generationConfig)

      let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
      XCTAssertEqual(json, """
      {
        "responseMimeType" : "\(mimeType)",
        "responseSchema" : {
          "nullable" : false,
          "properties" : {
            "age" : {
              "nullable" : false,
              "type" : "INTEGER"
            },
            "firstName" : {
              "nullable" : false,
              "type" : "STRING"
            },
            "lastName" : {
              "nullable" : false,
              "type" : "STRING"
            }
          },
          "required" : [
            "age",
            "firstName",
            "lastName"
          ],
          "type" : "OBJECT"
        }
      }
      """)
    }
  }