ModelContent.swift 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension [ModelContent] {
    // TODO: Rename and refactor this.
    /// Throws the `error` of the first ``ErrorPart`` encountered, visiting contents in array order
    /// and, within each content, its `parts` in order. Returns normally when no such part exists.
    ///
    /// NOTE(review): `ModelContent.parts` in this file only materializes `TextPart`,
    /// `InlineDataPart`, `FileDataPart`, `FunctionCallPart` and `FunctionResponsePart` values, so
    /// this scan may never find an `ErrorPart` — confirm whether the check is still reachable.
    func throwIfError() throws {
      for content in self {
        for case let errorPart as ErrorPart in content.parts {
          throw errorPart.error
        }
      }
    }
  }
  /// Wire-level representation of a single part of a ``ModelContent``.
  ///
  /// `data` is `nil` when the part's payload was not recognized while decoding; such parts are
  /// skipped by `ModelContent.parts`.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  struct InternalPart: Equatable, Sendable {
    /// The payload of a part: exactly one kind of data.
    enum OneOfData: Equatable, Sendable {
      case text(String)
      case inlineData(InlineData)
      case fileData(FileData)
      case functionCall(FunctionCall)
      case functionResponse(FunctionResponse)

      /// Thrown while decoding when a part's payload matches none of the supported cases.
      struct UnsupportedDataError: Error {
        let decodingError: DecodingError

        /// A description of the unsupported data, suitable for logging.
        ///
        /// Prefers the `debugDescription` stored in the underlying error's context.
        /// `DecodingError.localizedDescription` is bridged through Foundation and typically
        /// produces a generic message that omits those details.
        var localizedDescription: String {
          let context: DecodingError.Context?
          switch decodingError {
          case let .dataCorrupted(errorContext):
            context = errorContext
          case let .keyNotFound(_, errorContext):
            context = errorContext
          case let .typeMismatch(_, errorContext):
            context = errorContext
          case let .valueNotFound(_, errorContext):
            context = errorContext
          @unknown default:
            context = nil
          }
          if let details = context?.debugDescription, !details.isEmpty {
            return details
          }
          return decodingError.localizedDescription
        }
      }
    }

    /// The part's payload, or `nil` if it was not recognized when decoding.
    let data: OneOfData?
    /// Value of the part's `thought` field; `nil` when absent.
    let isThought: Bool?
    /// Value of the part's `thoughtSignature` field; `nil` when absent.
    let thoughtSignature: String?

    /// Creates a part carrying `data` and the given thought metadata.
    init(_ data: OneOfData, isThought: Bool?, thoughtSignature: String?) {
      self.data = data
      self.isThought = isThought
      self.thoughtSignature = thoughtSignature
    }
  }
  /// A type describing data in media formats interpretable by an AI model. Each generative AI
  /// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value
  /// may comprise multiple heterogeneous ``Part``s.
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct ModelContent: Equatable, Sendable {
    /// The role of the entity creating the ``ModelContent``. For user-generated client requests,
    /// for example, the role is `user`.
    public let role: String?

    /// The data parts comprising this ``ModelContent`` value.
    ///
    /// Parts whose data was missing or unrecognized when decoded are omitted.
    public var parts: [any Part] {
      return internalParts.compactMap { part -> (any Part)? in
        switch part.data {
        case let .text(text):
          return TextPart(text, isThought: part.isThought, thoughtSignature: part.thoughtSignature)
        case let .inlineData(inlineData):
          return InlineDataPart(
            inlineData, isThought: part.isThought, thoughtSignature: part.thoughtSignature
          )
        case let .fileData(fileData):
          return FileDataPart(
            fileData, isThought: part.isThought, thoughtSignature: part.thoughtSignature
          )
        case let .functionCall(functionCall):
          return FunctionCallPart(
            functionCall, isThought: part.isThought, thoughtSignature: part.thoughtSignature
          )
        case let .functionResponse(functionResponse):
          return FunctionResponsePart(
            functionResponse, isThought: part.isThought, thoughtSignature: part.thoughtSignature
          )
        case .none:
          // Filter out parts that contain missing or unrecognized data
          return nil
        }
      }
    }

    // TODO: Refactor this
    /// The wire-level parts backing ``parts``, including any whose data could not be decoded.
    let internalParts: [InternalPart]

    /// Creates a new value from a list of ``Part``s.
    ///
    /// - Parameters:
    ///   - role: The role of the entity creating the content; defaults to `"user"`.
    ///   - parts: The parts to include, in order. Each must be a ``TextPart``,
    ///     ``InlineDataPart``, ``FileDataPart``, ``FunctionCallPart`` or
    ///     ``FunctionResponsePart``; any other conforming type is a programmer error and traps.
    public init(role: String? = "user", parts: [any Part]) {
      self.role = role
      internalParts = parts.map(ModelContent.internalPart(from:))
    }

    /// Creates a new value from any data interpretable as a ``Part``.
    /// See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
    public init(role: String? = "user", parts: any PartsRepresentable...) {
      let content = parts.flatMap { $0.partsValue }
      self.init(role: role, parts: content)
    }

    /// Creates a value directly from wire-level parts.
    init(role: String?, parts: [InternalPart]) {
      self.role = role
      internalParts = parts
    }

    /// Converts a public ``Part`` into its wire-level ``InternalPart``, carrying over the
    /// thought flag and thought signature.
    private static func internalPart(from part: any Part) -> InternalPart {
      switch part {
      case let textPart as TextPart:
        return InternalPart(
          .text(textPart.text),
          isThought: textPart._isThought,
          thoughtSignature: textPart.thoughtSignature
        )
      case let inlineDataPart as InlineDataPart:
        return InternalPart(
          .inlineData(inlineDataPart.inlineData),
          isThought: inlineDataPart._isThought,
          thoughtSignature: inlineDataPart.thoughtSignature
        )
      case let fileDataPart as FileDataPart:
        return InternalPart(
          .fileData(fileDataPart.fileData),
          isThought: fileDataPart._isThought,
          thoughtSignature: fileDataPart.thoughtSignature
        )
      case let functionCallPart as FunctionCallPart:
        return InternalPart(
          .functionCall(functionCallPart.functionCall),
          isThought: functionCallPart._isThought,
          thoughtSignature: functionCallPart.thoughtSignature
        )
      case let functionResponsePart as FunctionResponsePart:
        return InternalPart(
          .functionResponse(functionResponsePart.functionResponse),
          isThought: functionResponsePart._isThought,
          thoughtSignature: functionResponsePart.thoughtSignature
        )
      default:
        // Only the part types above have a wire representation; name the offending type so the
        // trap is diagnosable.
        fatalError("Unsupported Part type: \(type(of: part))")
      }
    }
  }
  // MARK: - Codable Conformances

  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension ModelContent: Codable {
    /// JSON keys for the stored properties; `internalParts` is serialized under `"parts"`.
    enum CodingKeys: String, CodingKey {
      case role
      case internalParts = "parts"
    }

    /// Decodes a content value. A missing or `null` `role` becomes `nil`, and a missing or
    /// `null` `parts` array becomes empty.
    public init(from decoder: any Decoder) throws {
      let values = try decoder.container(keyedBy: CodingKeys.self)
      self.role = try values.decodeIfPresent(String.self, forKey: .role)
      let decodedParts = try values.decodeIfPresent([InternalPart].self, forKey: .internalParts)
      self.internalParts = decodedParts ?? []
    }
  }
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension InternalPart: Codable {
    /// Keys stored alongside the payload's key in the same JSON object.
    enum CodingKeys: String, CodingKey {
      case isThought = "thought"
      case thoughtSignature
    }

    /// Encodes the part as one object: the payload's key (for example `text`) plus the optional
    /// `thought` and `thoughtSignature` fields. A part whose `data` is `nil` is encoded without
    /// any payload key.
    public func encode(to encoder: Encoder) throws {
      // Encode the payload directly rather than via `Optional.encode(to:)`. The optional
      // conformance writes through a single-value container, so a `nil` payload would be
      // written as `null`. Requesting a keyed container on the same encoder afterwards is a
      // precondition failure (e.g. in Foundation's JSONEncoder).
      try data?.encode(to: encoder)
      var container = encoder.container(keyedBy: CodingKeys.self)
      try container.encodeIfPresent(isThought, forKey: .isThought)
      try container.encodeIfPresent(thoughtSignature, forKey: .thoughtSignature)
    }

    /// Decodes a part, tolerating payloads of an unsupported kind.
    ///
    /// An ``OneOfData/UnsupportedDataError`` is logged and leaves `data` as `nil` instead of
    /// failing the whole decode. Any other decoding error propagates to the caller.
    public init(from decoder: Decoder) throws {
      do {
        data = try OneOfData(from: decoder)
      } catch let error as OneOfData.UnsupportedDataError {
        AILog.error(code: .decodedUnsupportedPartData, error.localizedDescription)
        data = nil
      }
      let container = try decoder.container(keyedBy: CodingKeys.self)
      isThought = try container.decodeIfPresent(Bool.self, forKey: .isThought)
      thoughtSignature = try container.decodeIfPresent(String.self, forKey: .thoughtSignature)
    }
  }
  @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension InternalPart.OneOfData: Codable {
    /// JSON keys identifying which kind of payload a part carries.
    enum CodingKeys: String, CodingKey {
      case text
      case inlineData
      case fileData
      case functionCall
      case functionResponse
    }

    /// A coding key that accepts any string, used to enumerate every key in a part's object.
    private struct AnyCodingKey: CodingKey {
      let stringValue: String
      let intValue: Int?

      init(stringValue: String) {
        self.stringValue = stringValue
        intValue = nil
      }

      init(intValue: Int) {
        stringValue = String(intValue)
        self.intValue = intValue
      }
    }

    /// Encodes the payload under the single key matching its case, e.g. `{"text": "..."}`.
    public func encode(to encoder: Encoder) throws {
      var container = encoder.container(keyedBy: CodingKeys.self)
      switch self {
      case let .text(text):
        try container.encode(text, forKey: .text)
      case let .inlineData(inlineData):
        try container.encode(inlineData, forKey: .inlineData)
      case let .fileData(fileData):
        try container.encode(fileData, forKey: .fileData)
      case let .functionCall(functionCall):
        try container.encode(functionCall, forKey: .functionCall)
      case let .functionResponse(functionResponse):
        try container.encode(functionResponse, forKey: .functionResponse)
      }
    }

    /// Decodes the payload from the first recognized key, checked in the order `text`,
    /// `inlineData`, `fileData`, `functionCall`, `functionResponse`.
    ///
    /// - Throws: ``UnsupportedDataError`` if none of those keys is present; its message lists
    ///   the unrecognized keys. Errors from decoding a recognized key's value propagate as-is.
    public init(from decoder: Decoder) throws {
      let values = try decoder.container(keyedBy: CodingKeys.self)
      if values.contains(.text) {
        self = try .text(values.decode(String.self, forKey: .text))
      } else if values.contains(.inlineData) {
        self = try .inlineData(values.decode(InlineData.self, forKey: .inlineData))
      } else if values.contains(.fileData) {
        self = try .fileData(values.decode(FileData.self, forKey: .fileData))
      } else if values.contains(.functionCall) {
        self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
      } else if values.contains(.functionResponse) {
        self = try .functionResponse(values.decode(FunctionResponse.self, forKey: .functionResponse))
      } else {
        let unexpectedKeys = Self.unrecognizedKeys(in: decoder)
        throw UnsupportedDataError(decodingError: DecodingError.dataCorrupted(
          DecodingError.Context(
            codingPath: values.codingPath,
            debugDescription: "Unexpected Part type(s): \(unexpectedKeys)"
          )
        ))
      }
    }

    /// Returns, sorted, the keys in the part's object that `InternalPart` itself does not
    /// consume (i.e. other than `thought` and `thoughtSignature`).
    ///
    /// `allKeys` on a `CodingKeys`-keyed container only reports keys convertible to
    /// `CodingKeys`, and none of those are present when this is called. A container keyed by
    /// `AnyCodingKey` sees every key.
    private static func unrecognizedKeys(in decoder: Decoder) -> [String] {
      guard let container = try? decoder.container(keyedBy: AnyCodingKey.self) else {
        return []
      }
      return container.allKeys
        .map(\.stringValue)
        .filter { InternalPart.CodingKeys(stringValue: $0) == nil }
        .sorted()
    }
  }