ModelContent.swift 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension [ModelContent] {
    // TODO: Rename and refactor this.
    /// Throws the error carried by the first ``ErrorPart`` encountered, scanning each content's
    /// `parts` in array order; returns normally when no ``ErrorPart`` is present.
    ///
    /// NOTE(review): `ModelContent.parts` is rebuilt from `internalParts`, and the conversion in
    /// this file only produces text, inline-data, file-data, function-call and function-response
    /// parts — so this appears unable to throw as written. Confirm intent.
    func throwIfError() throws {
      for content in self {
        // Pattern-matching `for case` visits only the parts that are `ErrorPart`s.
        for case let errorPart as ErrorPart in content.parts {
          throw errorPart.error
        }
      }
    }
  }
  /// A type describing data in media formats interpretable by an AI model. Each generative AI
  /// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value
  /// may comprise multiple heterogeneous ``Part``s.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct ModelContent: Equatable, Sendable {
    /// Concrete, `Equatable` storage for a single part, used instead of storing `any Part`.
    enum InternalPart: Equatable, Sendable {
      case text(String)
      case inlineData(mimetype: String, Data)
      case fileData(mimetype: String, uri: String)
      case functionCall(FunctionCall)
      case functionResponse(FunctionResponse)
    }

    /// The role of the entity creating the ``ModelContent``. For user-generated client requests,
    /// for example, the role is `user`.
    public let role: String?

    /// The data parts comprising this ``ModelContent`` value.
    ///
    /// A fresh array of ``Part`` values is built from the stored representation on every access.
    public var parts: [any Part] {
      internalParts.map { stored -> any Part in
        switch stored {
        case let .text(text):
          return TextPart(text)
        case let .inlineData(mimetype, data):
          return InlineDataPart(data: data, mimeType: mimetype)
        case let .fileData(mimetype, uri):
          return FileDataPart(uri: uri, mimeType: mimetype)
        case let .functionCall(functionCall):
          return FunctionCallPart(functionCall)
        case let .functionResponse(functionResponse):
          return FunctionResponsePart(functionResponse)
        }
      }
    }

    // TODO: Refactor this
    let internalParts: [InternalPart]

    /// Creates a new value from a list of ``Part``s.
    ///
    /// - Parameters:
    ///   - role: The role of the entity creating the content; defaults to `"user"`.
    ///   - parts: The parts to store, in order. Only `TextPart`, `InlineDataPart`, `FileDataPart`,
    ///     `FunctionCallPart` and `FunctionResponsePart` are supported; any other ``Part`` type
    ///     traps via `fatalError`.
    public init(role: String? = "user", parts: [any Part]) {
      self.role = role
      internalParts = parts.map(ModelContent.makeInternalPart(from:))
    }

    /// Creates a new value from any data interpretable as a ``Part``.
    /// See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
    ///
    /// - Parameters:
    ///   - role: The role of the entity creating the content; defaults to `"user"`.
    ///   - parts: Values whose `partsValue` arrays are concatenated in argument order.
    public init(role: String? = "user", parts: any PartsRepresentable...) {
      let content = parts.flatMap { $0.partsValue }
      self.init(role: role, parts: content)
    }

    /// Converts a public ``Part`` into its stored ``InternalPart`` form.
    ///
    /// Traps for unsupported ``Part`` types; the message names the offending dynamic type so the
    /// crash is diagnosable.
    private static func makeInternalPart(from part: any Part) -> InternalPart {
      switch part {
      case let textPart as TextPart:
        return .text(textPart.text)
      case let inlineDataPart as InlineDataPart:
        let inlineData = inlineDataPart.inlineData
        return .inlineData(mimetype: inlineData.mimeType, inlineData.data)
      case let fileDataPart as FileDataPart:
        let fileData = fileDataPart.fileData
        return .fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)
      case let functionCallPart as FunctionCallPart:
        return .functionCall(functionCallPart.functionCall)
      case let functionResponsePart as FunctionResponsePart:
        return .functionResponse(functionResponsePart.functionResponse)
      default:
        fatalError("Unsupported Part type: \(type(of: part))")
      }
    }
  }
  98. // MARK: Codable Conformances
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension ModelContent: Codable {
    /// Serialized keys; the stored `internalParts` array is read and written under `"parts"`.
    /// `encode(to:)` is compiler-synthesized from these keys.
    enum CodingKeys: String, CodingKey {
      case role
      case internalParts = "parts"
    }

    /// Decodes a ``ModelContent``. A missing `role` decodes as `nil`, and a missing `parts`
    /// array decodes as an empty list.
    public init(from decoder: any Decoder) throws {
      let keyed = try decoder.container(keyedBy: CodingKeys.self)
      self.role = try keyed.decodeIfPresent(String.self, forKey: .role)
      let decodedParts = try keyed.decodeIfPresent([InternalPart].self, forKey: .internalParts)
      self.internalParts = decodedParts ?? []
    }
  }
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension ModelContent.InternalPart: Codable {
    /// Keys for the supported part kinds; encoding writes exactly one of them per part.
    enum CodingKeys: String, CodingKey {
      case text
      case inlineData
      case fileData
      case functionCall
      case functionResponse
    }

    /// A coding key that accepts any string or integer, used only to list the keys actually
    /// present in a part that could not be decoded.
    private struct AnyCodingKey: CodingKey {
      let stringValue: String
      let intValue: Int?

      init(stringValue: String) {
        self.stringValue = stringValue
        intValue = nil
      }

      init(intValue: Int) {
        stringValue = String(intValue)
        self.intValue = intValue
      }
    }

    /// Encodes the part as a single-key object whose key names the part kind.
    public func encode(to encoder: any Encoder) throws {
      var container = encoder.container(keyedBy: CodingKeys.self)
      switch self {
      case let .text(text):
        try container.encode(text, forKey: .text)
      case let .inlineData(mimetype, bytes):
        try container.encode(InlineData(data: bytes, mimeType: mimetype), forKey: .inlineData)
      case let .fileData(mimetype, uri):
        try container.encode(FileData(fileURI: uri, mimeType: mimetype), forKey: .fileData)
      case let .functionCall(functionCall):
        try container.encode(functionCall, forKey: .functionCall)
      case let .functionResponse(functionResponse):
        try container.encode(functionResponse, forKey: .functionResponse)
      }
    }

    /// Decodes a part from the first recognized key, checked in the order text, inlineData,
    /// fileData, functionCall, functionResponse.
    ///
    /// - Throws: `DecodingError.dataCorrupted` listing the keys that were present when none of
    ///   the recognized keys is found, or any error from decoding the matched value.
    public init(from decoder: any Decoder) throws {
      let values = try decoder.container(keyedBy: CodingKeys.self)
      if values.contains(.text) {
        self = try .text(values.decode(String.self, forKey: .text))
      } else if values.contains(.inlineData) {
        let inlineData = try values.decode(InlineData.self, forKey: .inlineData)
        self = .inlineData(mimetype: inlineData.mimeType, inlineData.data)
      } else if values.contains(.fileData) {
        let fileData = try values.decode(FileData.self, forKey: .fileData)
        self = .fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)
      } else if values.contains(.functionCall) {
        self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
      } else if values.contains(.functionResponse) {
        self = try .functionResponse(values.decode(FunctionResponse.self, forKey: .functionResponse))
      } else {
        // `values.allKeys` only reports keys convertible to `CodingKeys`, all of which were handled
        // above, so it is always empty here. Re-read the keys untyped to report what was present.
        let presentKeys = (try? decoder.container(keyedBy: AnyCodingKey.self))?.allKeys ?? []
        let unexpectedKeys = presentKeys.map { $0.stringValue }.sorted()
        throw DecodingError.dataCorrupted(DecodingError.Context(
          codingPath: values.codingPath,
          debugDescription: "Unexpected Part type(s): \(unexpectedKeys)"
        ))
      }
    }
  }
  158. // MARK: - ModelContentBuilder
  /// A result builder that assembles ``PartsRepresentable`` values into a single ``ModelContent``.
  ///
  /// Each expression in the builder body contributes its parts in source order. `if`/`else`,
  /// optional `if`, `for` loops and availability checks are supported. The resulting content is
  /// created with `ModelContent(parts:)`, so it uses that initializer's default role (`"user"`).
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  @resultBuilder
  public struct ModelContentBuilder {
    typealias Expression = any PartsRepresentable
    typealias Component = [any PartsRepresentable]
    typealias Result = ModelContent

    /// Wraps a single expression as a one-element component.
    public static func buildExpression(_ expression: any PartsRepresentable) -> [any PartsRepresentable] {
      [expression]
    }

    /// Concatenates the components of every statement in a block, preserving order.
    public static func buildBlock(_ components: [any PartsRepresentable]...)
      -> [any PartsRepresentable] {
      // Flatten explicitly so each element is an individual expression rather than a nested array;
      // this does not rely on `[any PartsRepresentable]` itself conforming to `PartsRepresentable`.
      components.flatMap { $0 }
    }

    /// Passes through the component of the `if` branch.
    public static func buildEither(first component: [any PartsRepresentable])
      -> [any PartsRepresentable] {
      component
    }

    /// Passes through the component of the `else` branch.
    public static func buildEither(second component: [any PartsRepresentable])
      -> [any PartsRepresentable] {
      component
    }

    /// Passes a flat array through unchanged.
    public static func buildArray(_ components: [any PartsRepresentable])
      -> [any PartsRepresentable] {
      components
    }

    /// Concatenates the components produced by each iteration of a `for` loop.
    public static func buildArray(_ components: [[any PartsRepresentable]])
      -> [any PartsRepresentable] {
      components.flatMap { $0 }
    }

    /// Maps an `if` without `else` that was not taken to an empty component.
    public static func buildOptional(_ component: [any PartsRepresentable]?)
      -> [any PartsRepresentable] {
      component ?? []
    }

    /// Passes through a component guarded by an `if #available` check.
    public static func buildLimitedAvailability(_ component: [any PartsRepresentable])
      -> [any PartsRepresentable] {
      component
    }

    /// Expands every collected value into its ``Part``s and wraps them in a ``ModelContent``.
    public static func buildFinalResult(_ component: [any PartsRepresentable]) -> ModelContent {
      // Expand to `[any Part]` so the array-taking initializer is selected unambiguously.
      ModelContent(parts: component.flatMap { $0.partsValue })
    }
  }