ModelContent.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. #if canImport(UIKit)
  16. import UIKit
  17. #endif // canImport(UIKit)
  /// A result builder that collects ``ModelContent/Part`` values — and anything else that is
  /// ``PartsRepresentable`` — into a single flat array, preserving declaration order.
  ///
  /// Besides plain sequences of expressions, builder bodies may use `if`, `if`/`else`, `switch`,
  /// `for`-`in` and `if #available`; branches that are not taken contribute no parts.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  @resultBuilder
  public struct PartsBuilder {
    /// Concatenates the parts produced by each statement of the builder body, in order.
    public static func buildBlock(_ parts: [any ModelContent.Part]...) -> [any ModelContent.Part] {
      return parts.flatMap { $0 }
    }

    /// Lifts a single ``PartsRepresentable`` expression into its parts.
    public static func buildExpression(_ parts: any PartsRepresentable) -> [any ModelContent.Part] {
      return parts.partsValue
    }

    #if canImport(UIKit)
    /// Lifts an array of images into parts by concatenating each image's `partsValue`.
    public static func buildExpression(_ images: [UIImage]) -> [any ModelContent.Part] {
      return images.flatMap { $0.partsValue }
    }
    #endif // canImport(UIKit)

    /// Supports an `if` without `else`; a skipped branch contributes no parts.
    public static func buildOptional(_ parts: [any ModelContent.Part]?) -> [any ModelContent.Part] {
      return parts ?? []
    }

    /// Supports the first branch of an `if`/`else` or `switch` statement.
    public static func buildEither(first parts: [any ModelContent.Part]) -> [any ModelContent.Part] {
      return parts
    }

    /// Supports the second branch of an `if`/`else` or `switch` statement.
    public static func buildEither(second parts: [any ModelContent.Part]) -> [any ModelContent.Part] {
      return parts
    }

    /// Supports `for`-`in` loops by concatenating the parts of every iteration, in order.
    public static func buildArray(_ parts: [[any ModelContent.Part]]) -> [any ModelContent.Part] {
      return parts.flatMap { $0 }
    }

    /// Supports `if #available` checks; the guarded parts are passed through unchanged.
    public static func buildLimitedAvailability(
      _ parts: [any ModelContent.Part]
    ) -> [any ModelContent.Part] {
      return parts
    }
  }
  /// Default implementations shared by every ``ModelContent/Part``.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public extension ModelContent.Part {
    /// The text carried by this part when it is a ``TextPart``; `nil` for every other part type.
    var text: String? {
      guard let textPart = self as? TextPart else {
        return nil
      }
      return textPart.textValue
    }

    /// This part wrapped in a one-element array, so a single part can be used anywhere a
    /// ``PartsRepresentable`` is expected.
    var partsValue: [any ModelContent.Part] {
      let wrapped: [any ModelContent.Part] = [self]
      return wrapped
    }
  }
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension [ModelContent] {
    // TODO: Rename and refactor this.
    /// Scans every content's parts in order and throws the first ``ErrorPart`` encountered.
    ///
    /// - Throws: The first `ErrorPart` found; returns normally when none is present.
    func throwIfError() throws {
      // NOTE(review): `ModelContent.parts` as written in this file only produces text,
      // inline-data, file-data, function-call and function-response parts, so this may never
      // throw in practice — confirm whether `ErrorPart`s can actually reach this point.
      for content in self {
        for case let errorPart as ErrorPart in content.parts {
          throw errorPart
        }
      }
    }
  }
  /// A ``ModelContent/Part`` holding a plain-text string.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct TextPart: ModelContent.Part {
    /// The text carried by this part.
    public let textValue: String

    /// Creates a part wrapping the given text.
    ///
    /// - Parameter textValue: The text to store.
    public init(
      textValue: String
    ) {
      self.textValue = textValue
    }
  }
  /// Raw bytes paired with a MIME type string (for example `"image/png"`).
  ///
  /// Uses the compiler-synthesized `Codable` conformance, so the property names are the keys.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct InlineData: Codable, Equatable, Sendable {
    /// The MIME type string describing `data`.
    public let mimeType: String

    /// The raw bytes.
    public let data: Data

    /// Creates a value from a MIME type and its bytes.
    ///
    /// - Parameters:
    ///   - mimeType: The MIME type string describing `data`.
    ///   - data: The raw bytes.
    public init(
      mimeType: String,
      data: Data
    ) {
      self.data = data
      self.mimeType = mimeType
    }
  }
  /// A ``ModelContent/Part`` carrying an ``InlineData`` payload.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct InlineDataPart: ModelContent.Part {
    /// The MIME-typed bytes carried by this part.
    public let inlineData: InlineData

    /// Creates a part wrapping the given inline data.
    ///
    /// - Parameter inlineData: The MIME-typed bytes to store.
    public init(
      inlineData: InlineData
    ) {
      self.inlineData = inlineData
    }
  }
  /// A reference to a file by URI, paired with the file's MIME type.
  ///
  /// Encoded with snake_case keys: `mime_type` and `file_uri`.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct FileData: Codable, Equatable, Sendable {
    /// The MIME type string of the referenced file.
    public let mimeType: String

    /// The URI of the referenced file.
    public let uri: String

    /// Maps the Swift property names onto the serialized key names.
    enum CodingKeys: String, CodingKey {
      case mimeType = "mime_type"
      case uri = "file_uri"
    }

    /// Creates a file reference.
    ///
    /// - Parameters:
    ///   - mimeType: The MIME type string of the referenced file.
    ///   - uri: The URI of the referenced file.
    public init(
      mimeType: String,
      uri: String
    ) {
      self.uri = uri
      self.mimeType = mimeType
    }
  }
  /// A ``ModelContent/Part`` carrying a ``FileData`` reference.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct FileDataPart: ModelContent.Part {
    /// The file reference carried by this part.
    public let fileData: FileData

    /// Creates a part wrapping the given file reference.
    ///
    /// - Parameter fileData: The file reference to store.
    public init(
      fileData: FileData
    ) {
      self.fileData = fileData
    }
  }
  /// A ``ModelContent/Part`` carrying a `FunctionCall`.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct FunctionCallPart: ModelContent.Part {
    /// The function call carried by this part.
    public let functionCall: FunctionCall

    /// Creates a part wrapping the given function call.
    ///
    /// - Parameter functionCall: The function call to store.
    public init(
      functionCall: FunctionCall
    ) {
      self.functionCall = functionCall
    }
  }
  /// A ``ModelContent/Part`` carrying a `FunctionResponse`.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct FunctionResponsePart: ModelContent.Part {
    /// The function response carried by this part.
    public let functionResponse: FunctionResponse

    /// Creates a part wrapping the given function response.
    ///
    /// - Parameter functionResponse: The function response to store.
    public init(
      functionResponse: FunctionResponse
    ) {
      self.functionResponse = functionResponse
    }
  }
  /// An empty marker part that is also an `Error`, so it can be thrown directly
  /// (as `throwIfError()` on `[ModelContent]` does when it finds one).
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  struct ErrorPart: Error, ModelContent.Part, Equatable {
  }
  /// A type describing data in media formats interpretable by an AI model. Each generative AI
  /// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value
  /// may comprise multiple heterogeneous ``ModelContent/Part``s.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  public struct ModelContent: Equatable, Sendable {
    /// A discrete piece of data that can be placed in a ``ModelContent``.
    public protocol Part: PartsRepresentable, Codable, Sendable, Equatable {
      /// Returns the text contents of this ``Part``, if it contains text.
      var text: String? { get }
    }

    /// The internal, serializable representation of a single ``Part``. Each value holds exactly
    /// one kind of data; different data types may not mix within a single value.
    enum InternalPart: Equatable, Sendable {
      /// Text value.
      case text(String)

      /// Data with a specified media type. Not all media types may be supported by the AI model.
      case inlineData(mimetype: String, Data)

      /// File data stored in Cloud Storage for Firebase, referenced by URI.
      ///
      /// > Note: Supported media types depend on the model; see [media requirements
      /// > ](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#media_requirements)
      /// > for details.
      ///
      /// - Parameters:
      ///   - mimetype: The IANA standard MIME type of the uploaded file, for example, `"image/jpeg"`
      ///     or `"video/mp4"`; see [media requirements
      ///     ](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#media_requirements)
      ///     for supported values.
      ///   - uri: The `"gs://"`-prefixed URI of the file in Cloud Storage for Firebase, for example,
      ///     `"gs://bucket-name/path/image.jpg"`.
      case fileData(mimetype: String, uri: String)

      /// A predicted function call returned from the model.
      case functionCall(FunctionCall)

      /// A response to a function call.
      case functionResponse(FunctionResponse)

      // MARK: Convenience Initializers

      /// Convenience function for populating a Part with JPEG data.
      public static func jpeg(_ data: Data) -> Self {
        return .inlineData(mimetype: "image/jpeg", data)
      }

      /// Convenience function for populating a Part with PNG data.
      public static func png(_ data: Data) -> Self {
        return .inlineData(mimetype: "image/png", data)
      }
    }

    /// The role of the entity creating the ``ModelContent``. For user-generated client requests,
    /// for example, the role is `user`.
    public let role: String?

    /// The data parts comprising this ``ModelContent`` value.
    ///
    /// Every access builds new public ``Part`` values from the stored internal representation,
    /// in the same order they are stored.
    public var parts: [any Part] {
      return internalParts.map { (internalPart: InternalPart) -> any Part in
        switch internalPart {
        case let .text(text):
          return TextPart(textValue: text)
        case let .inlineData(mimetype, data):
          return InlineDataPart(inlineData: InlineData(mimeType: mimetype, data: data))
        case let .fileData(mimetype, uri):
          return FileDataPart(fileData: FileData(mimeType: mimetype, uri: uri))
        case let .functionCall(functionCall):
          return FunctionCallPart(functionCall: functionCall)
        case let .functionResponse(functionResponse):
          return FunctionResponsePart(functionResponse: functionResponse)
        }
      }
    }

    // TODO: Refactor this
    /// Stored parts in their internal representation; this is the property that is encoded and
    /// decoded (under the `"parts"` key).
    let internalParts: [InternalPart]

    /// Creates a new value from any data or `Array` of data interpretable as a
    /// ``Part``. See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
    public init(role: String? = "user", parts: some PartsRepresentable) {
      self.role = role
      internalParts = ModelContent.makeInternalParts(from: parts.partsValue)
    }

    /// Creates a new value from a list of ``Part``s.
    public init(role: String? = "user", parts: [any Part]) {
      self.role = role
      internalParts = ModelContent.makeInternalParts(from: parts)
    }

    /// Creates a new value from the ``Part``s produced by a ``PartsBuilder`` closure.
    public init(role: String? = "user", @PartsBuilder parts: () -> [any Part]) {
      self.init(role: role, parts: parts())
    }

    /// Creates a new value from any data interpretable as a ``Part``.
    /// See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
    public init(role: String? = "user", _ parts: any PartsRepresentable...) {
      let content = parts.flatMap { $0.partsValue }
      self.init(role: role, parts: content)
    }

    /// Creates a new value from any data interpretable as a ``Part``.
    /// See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
    public init(role: String? = "user", _ parts: [any PartsRepresentable]) {
      let content = parts.flatMap { $0.partsValue }
      self.init(role: role, parts: content)
    }

    // MARK: - Private helpers

    /// Converts public ``Part`` values into the internal representation stored by this type,
    /// preserving order.
    ///
    /// Only ``TextPart``, ``InlineDataPart``, ``FileDataPart``, ``FunctionCallPart`` and
    /// ``FunctionResponsePart`` have an internal representation; any other conforming type traps.
    private static func makeInternalParts(from parts: [any Part]) -> [InternalPart] {
      return parts.map { (part: any Part) -> InternalPart in
        switch part {
        case let textPart as TextPart:
          return .text(textPart.textValue)
        case let inlineDataPart as InlineDataPart:
          let inlineData = inlineDataPart.inlineData
          return .inlineData(mimetype: inlineData.mimeType, inlineData.data)
        case let fileDataPart as FileDataPart:
          let fileData = fileDataPart.fileData
          return .fileData(mimetype: fileData.mimeType, uri: fileData.uri)
        case let functionCallPart as FunctionCallPart:
          return .functionCall(functionCallPart.functionCall)
        case let functionResponsePart as FunctionResponsePart:
          return .functionResponse(functionResponsePart.functionResponse)
        default:
          // No internal case exists for this type (e.g. `ErrorPart` or an external conformer);
          // name the offending type instead of crashing with no diagnostic.
          fatalError("Unsupported ModelContent.Part type: \(type(of: part))")
        }
      }
    }
  }
  // MARK: - Codable Conformances

  /// Serialization of ``ModelContent`` is compiler-synthesized from the keys below.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension ModelContent: Codable {
    /// `role` keeps its own name; the stored `internalParts` array is written and read under
    /// the wire name `"parts"`.
    enum CodingKeys: String, CodingKey {
      case role, internalParts = "parts"
    }
  }
  /// Custom serialization for ``ModelContent/InternalPart``: each part is a JSON object holding
  /// exactly one of the keys in `CodingKeys`, which selects the enum case.
  @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
  extension ModelContent.InternalPart: Codable {
    /// Top-level keys; the one that is present determines the decoded case.
    enum CodingKeys: String, CodingKey {
      case text
      case inlineData
      case fileData
      case functionCall
      case functionResponse
    }

    /// Keys of the object nested under `inlineData`.
    enum InlineDataKeys: String, CodingKey {
      case mimeType = "mime_type"
      case bytes = "data"
    }

    /// A coding key that accepts any string; used only to report unrecognized keys in errors,
    /// since a `CodingKeys`-keyed container never lists keys it cannot represent.
    private struct AnyCodingKey: CodingKey {
      let stringValue: String
      let intValue: Int?

      init(stringValue: String) {
        self.stringValue = stringValue
        intValue = nil
      }

      init(intValue: Int) {
        stringValue = String(intValue)
        self.intValue = intValue
      }
    }

    public func encode(to encoder: Encoder) throws {
      var container = encoder.container(keyedBy: CodingKeys.self)
      switch self {
      case let .text(text):
        try container.encode(text, forKey: .text)
      case let .inlineData(mimetype, bytes):
        var inlineDataContainer = container.nestedContainer(
          keyedBy: InlineDataKeys.self,
          forKey: .inlineData
        )
        try inlineDataContainer.encode(mimetype, forKey: .mimeType)
        try inlineDataContainer.encode(bytes, forKey: .bytes)
      case let .fileData(mimetype, uri):
        // `FileData` supplies the nested `mime_type` / `file_uri` keys.
        try container.encode(FileData(mimeType: mimetype, uri: uri), forKey: .fileData)
      case let .functionCall(functionCall):
        try container.encode(functionCall, forKey: .functionCall)
      case let .functionResponse(functionResponse):
        try container.encode(functionResponse, forKey: .functionResponse)
      }
    }

    /// Decodes the case selected by the first recognized key, checked in the order text,
    /// inline data, function call, function response, file data.
    ///
    /// - Throws: `DecodingError.dataCorrupted` listing the object's keys when none is recognized,
    ///   or any error raised while decoding the selected payload.
    public init(from decoder: Decoder) throws {
      let values = try decoder.container(keyedBy: CodingKeys.self)
      if values.contains(.text) {
        self = try .text(values.decode(String.self, forKey: .text))
      } else if values.contains(.inlineData) {
        let dataContainer = try values.nestedContainer(
          keyedBy: InlineDataKeys.self,
          forKey: .inlineData
        )
        let mimetype = try dataContainer.decode(String.self, forKey: .mimeType)
        let bytes = try dataContainer.decode(Data.self, forKey: .bytes)
        self = .inlineData(mimetype: mimetype, bytes)
      } else if values.contains(.functionCall) {
        self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
      } else if values.contains(.functionResponse) {
        self = try .functionResponse(values.decode(FunctionResponse.self, forKey: .functionResponse))
      } else if values.contains(.fileData) {
        // Mirrors `encode(to:)`, which writes file data as a nested `FileData` object. Checked
        // last so inputs that previously decoded as another case are unaffected.
        let fileData = try values.decode(FileData.self, forKey: .fileData)
        self = .fileData(mimetype: fileData.mimeType, uri: fileData.uri)
      } else {
        let unexpectedKeys = (try? decoder.container(keyedBy: AnyCodingKey.self))?
          .allKeys.map { $0.stringValue } ?? []
        throw DecodingError.dataCorrupted(DecodingError.Context(
          codingPath: values.codingPath,
          debugDescription: "Unexpected ModelContent.Part type(s): \(unexpectedKeys)"
        ))
      }
    }
  }