GenerationConfigTests.swift 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. // Copyright 2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseAILogic
  15. import Foundation
  16. import XCTest
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  final class GenerationConfigTests: XCTestCase {
    /// Encoder used by every test; its output formatting is (re)applied in `setUp()`.
    let encoder = JSONEncoder()

    override func setUp() {
      super.setUp()
      // `.sortedKeys` makes key order deterministic so whole-document string comparison is
      // stable; `.withoutEscapingSlashes` keeps values such as "application/json" from being
      // written with an escaped slash.
      encoder.outputFormatting = [.prettyPrinted, .sortedKeys, .withoutEscapingSlashes]
    }

    // MARK: GenerationConfig Encoding

    /// A config created with no arguments encodes as an empty JSON object.
    func testEncodeGenerationConfig_default() throws {
      let generationConfig = GenerationConfig()

      let json = try encodedJSON(generationConfig)

      // Foundation's pretty printer writes an empty object with an interior blank line.
      XCTAssertEqual(json, """
      {

      }
      """)
    }

    /// Every option, when supplied, is encoded under its wire-format key
    /// (e.g. `responseMIMEType` is written as "responseMimeType").
    func testEncodeGenerationConfig_allOptions() throws {
      let temperature: Float = 0.5
      let topP: Float = 0.75
      let topK = 40
      let candidateCount = 2
      let maxOutputTokens = 256
      let presencePenalty: Float = 0.5
      let frequencyPenalty: Float = 0.75
      let stopSequences = ["END", "DONE"]
      let responseMIMEType = "application/json"
      let generationConfig = GenerationConfig(
        temperature: temperature,
        topP: topP,
        topK: topK,
        candidateCount: candidateCount,
        maxOutputTokens: maxOutputTokens,
        presencePenalty: presencePenalty,
        frequencyPenalty: frequencyPenalty,
        stopSequences: stopSequences,
        responseMIMEType: responseMIMEType,
        responseSchema: .array(items: .string()),
        responseModalities: [.text, .image]
      )

      let json = try encodedJSON(generationConfig)

      // The Float values chosen here (0.5, 0.75) are exactly representable, so their
      // interpolated `description` matches the number text the encoder emits.
      XCTAssertEqual(json, """
      {
        "candidateCount" : \(candidateCount),
        "frequencyPenalty" : \(frequencyPenalty),
        "maxOutputTokens" : \(maxOutputTokens),
        "presencePenalty" : \(presencePenalty),
        "responseMimeType" : "\(responseMIMEType)",
        "responseModalities" : [
          "TEXT",
          "IMAGE"
        ],
        "responseSchema" : {
          "items" : {
            "nullable" : false,
            "type" : "STRING"
          },
          "nullable" : false,
          "type" : "ARRAY"
        },
        "stopSequences" : [
          "END",
          "DONE"
        ],
        "temperature" : \(temperature),
        "topK" : \(topK),
        "topP" : \(topP)
      }
      """)
    }

    /// A JSON response config with an object schema encodes each property's schema and lists
    /// all four properties under "required".
    func testEncodeGenerationConfig_jsonResponse() throws {
      let mimeType = "application/json"
      let generationConfig = GenerationConfig(
        responseMIMEType: mimeType,
        responseSchema: .object(properties: [
          "firstName": .string(),
          "middleNames": .array(items: .string(), minItems: 0, maxItems: 3),
          "lastName": .string(),
          "age": .integer(),
        ])
      )

      let json = try encodedJSON(generationConfig)

      XCTAssertEqual(json, """
      {
        "responseMimeType" : "\(mimeType)",
        "responseSchema" : {
          "nullable" : false,
          "properties" : {
            "age" : {
              "nullable" : false,
              "type" : "INTEGER"
            },
            "firstName" : {
              "nullable" : false,
              "type" : "STRING"
            },
            "lastName" : {
              "nullable" : false,
              "type" : "STRING"
            },
            "middleNames" : {
              "items" : {
                "nullable" : false,
                "type" : "STRING"
              },
              "maxItems" : 3,
              "minItems" : 0,
              "nullable" : false,
              "type" : "ARRAY"
            }
          },
          "required" : [
            "age",
            "firstName",
            "lastName",
            "middleNames"
          ],
          "type" : "OBJECT"
        }
      }
      """)
    }

    // MARK: Helpers

    /// Encodes `config` with `encoder` and returns the resulting JSON as a UTF-8 string.
    ///
    /// - Parameters:
    ///   - config: The configuration to encode.
    ///   - file: Forwarded to `XCTUnwrap` so a failure is reported at the calling test.
    ///   - line: Forwarded to `XCTUnwrap` so a failure is reported at the calling test.
    /// - Throws: Any encoding error, or an `XCTUnwrap` failure if the data is not valid UTF-8.
    private func encodedJSON(_ config: GenerationConfig,
                             file: StaticString = #filePath,
                             line: UInt = #line) throws -> String {
      let jsonData = try encoder.encode(config)
      return try XCTUnwrap(String(data: jsonData, encoding: .utf8), file: file, line: line)
    }
  }