MultimodalOutputViewModel.swift 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import FirebaseVertexAI
  15. import Foundation
  16. import OSLog
  17. import PhotosUI
  18. import SwiftUI
  19. @MainActor
  20. class MultimodalOutputViewModel: ObservableObject {
  21. private var logger = Logger(subsystem: Bundle.main.bundleIdentifier!, category: "generative-ai")
  22. @Published
  23. var userInput: String = ""
  24. @Published
  25. var outputText: String? = nil
  26. @Published
  27. var errorMessage: String?
  28. @Published
  29. var inProgress = false
  30. @Published
  31. var responseModalities: [OutputContentModality]?
  32. private var model: GenerativeModel?
  33. init() {
  34. let config = GenerationConfig(responseModalities: [.text, .image, .audio])
  35. model = VertexAI.vertexAI().generativeModel(modelName: "gemini-2.0-flash-001", generationConfig: config)
  36. }
  37. func generate() async {
  38. defer {
  39. inProgress = false
  40. }
  41. guard let model else {
  42. return
  43. }
  44. do {
  45. inProgress = true
  46. errorMessage = nil
  47. outputText = ""
  48. let prompt = userInput
  49. let outputContentStream = try model.generateContentStream(prompt)
  50. for try await outputContent in outputContentStream {
  51. guard let line = outputContent.text else {
  52. return
  53. }
  54. outputText = (outputText ?? "") + line
  55. }
  56. } catch {
  57. logger.error("\(error.localizedDescription)")
  58. errorMessage = error.localizedDescription
  59. }
  60. }
  61. }