GenerationConfigTests.swift 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. // Copyright 2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseVertexAI
  15. import Foundation
  16. import XCTest
  /// Unit tests for the JSON encoding of `GenerationConfig`.
  ///
  /// Each test encodes a configuration with the shared `encoder` and compares the UTF-8
  /// output against an exact expected string. The encoder's formatting options (set in
  /// `setUp()`) therefore determine the exact layout of every expected literal below.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  final class GenerationConfigTests: XCTestCase {
    /// Encoder used by every test in this case; its output formatting is set in `setUp()`.
    let encoder = JSONEncoder()

    override func setUp() {
      super.setUp()
      // Pretty printing plus sorted keys makes the output deterministic, so it can be compared
      // to string literals. Slashes are left unescaped so MIME types such as "application/json"
      // appear verbatim rather than as "application\/json".
      encoder.outputFormatting = [.prettyPrinted, .sortedKeys, .withoutEscapingSlashes]
    }

    // MARK: - GenerationConfig Encoding

    /// A `GenerationConfig` built with no arguments encodes as an empty JSON object.
    func testEncodeGenerationConfig_default() throws {
      let generationConfig = GenerationConfig()

      let jsonData = try encoder.encode(generationConfig)

      let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
      // With `.prettyPrinted`, an empty object is written as "{", an empty line, then "}".
      XCTAssertEqual(json, """
      {

      }
      """)
    }

    /// Every option passed to `GenerationConfig` appears in the encoded JSON under its own key.
    func testEncodeGenerationConfig_allOptions() throws {
      let temperature: Float = 0.5
      let topP: Float = 0.75
      let topK = 40
      let candidateCount = 2
      let maxOutputTokens = 256
      let stopSequences = ["END", "DONE"]
      let responseMIMEType = "application/json"
      let generationConfig = GenerationConfig(
        temperature: temperature,
        topP: topP,
        topK: topK,
        candidateCount: candidateCount,
        maxOutputTokens: maxOutputTokens,
        stopSequences: stopSequences,
        responseMIMEType: responseMIMEType,
        responseSchema: .array(items: .string())
      )

      let jsonData = try encoder.encode(generationConfig)

      let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
      // Object keys are alphabetical because of `.sortedKeys`; array elements
      // (e.g. "stopSequences") keep their input order.
      XCTAssertEqual(json, """
      {
        "candidateCount" : \(candidateCount),
        "maxOutputTokens" : \(maxOutputTokens),
        "responseMIMEType" : "\(responseMIMEType)",
        "responseSchema" : {
          "items" : {
            "nullable" : false,
            "type" : "STRING"
          },
          "nullable" : false,
          "type" : "ARRAY"
        },
        "stopSequences" : [
          "END",
          "DONE"
        ],
        "temperature" : \(temperature),
        "topK" : \(topK),
        "topP" : \(topP)
      }
      """)
    }

    /// An object response schema encodes each property's schema and lists the property
    /// names under "required".
    func testEncodeGenerationConfig_jsonResponse() throws {
      let mimeType = "application/json"
      let generationConfig = GenerationConfig(
        responseMIMEType: mimeType,
        responseSchema: .object(properties: [
          "firstName": .string(),
          "lastName": .string(),
          "age": .integer(),
        ])
      )

      let jsonData = try encoder.encode(generationConfig)

      let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
      // `.sortedKeys` only orders object keys, not array elements. The alphabetical
      // "required" array below therefore relies on the schema encoding itself emitting
      // a sorted list.
      XCTAssertEqual(json, """
      {
        "responseMIMEType" : "\(mimeType)",
        "responseSchema" : {
          "nullable" : false,
          "properties" : {
            "age" : {
              "nullable" : false,
              "type" : "INTEGER"
            },
            "firstName" : {
              "nullable" : false,
              "type" : "STRING"
            },
            "lastName" : {
              "nullable" : false,
              "type" : "STRING"
            }
          },
          "required" : [
            "age",
            "firstName",
            "lastName"
          ],
          "type" : "OBJECT"
        }
      }
      """)
    }
  }