Downloader.swift 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. // Copyright 2020 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. import Foundation
  15. import FirebaseMLModelDownloader
  /// Observable view model that drives downloading, deleting, and listing
  /// Firebase-hosted custom ML models.
  ///
  /// Each action writes its outcome into the `@Published` properties below.
  /// Views observing this object therefore refresh as an operation progresses
  /// and completes.
  class Downloader: ObservableObject {
    /// Fraction of the current download that has completed, as reported by
    /// the downloader's progress handler.
    @Published var downloadProgress: Float = 0.0
    /// Name of the model that download/delete actions operate on. It is user
    /// input and is not cleared between actions.
    @Published var modelName = ""
    /// Local path of the downloaded model, or "" when none is known.
    @Published var filePath = ""
    /// Message describing the most recent failure, or "" if there was none.
    @Published var error = ""
    /// Whether the most recent download succeeded.
    @Published var isDownloaded = false
    /// Whether the most recent delete succeeded.
    @Published var isDeleted = false
    /// Whether the most recent action ended in an error (see `error`).
    @Published var isError = false
    /// Names of the models found on the device by the last `listModel()`
    /// call, sorted alphabetically.
    @Published var modelNames = [String]()

    /// Clears all per-action result state so a new action starts fresh.
    /// `modelName` is intentionally kept, because it is the user's input.
    private func resetState() {
      isDownloaded = false
      isDeleted = false
      downloadProgress = 0.0
      filePath = ""
      error = ""
      isError = false
      modelNames = []
    }

    /// Records a failed action by raising `isError` and storing `message`.
    private func reportError(_ message: String) {
      isError = true
      error = message
    }

    /// Returns a closure (e.g. for a button action) that resets state and then
    /// starts a download of `modelName` using `downloadType`.
    func downloadModelHelper(downloadType: ModelDownloadType) -> () -> Void {
      return {
        self.resetState()
        self.downloadModel(downloadType: downloadType)
      }
    }

    /// Fetches the model named `modelName` with default download conditions.
    ///
    /// Progress is published to `downloadProgress`. On success, `isDownloaded`
    /// and `filePath` are set. On failure, both are cleared and the error is
    /// reported.
    /// - Parameter downloadType: How the downloader should obtain the model.
    func downloadModel(downloadType: ModelDownloadType) {
      let modelDownloader = ModelDownloader.modelDownloader()
      let conditions = ModelDownloadConditions()
      modelDownloader.getModel(
        name: modelName,
        downloadType: downloadType,
        conditions: conditions,
        progressHandler: { progress in
          self.downloadProgress = progress
        }
      ) { result in
        switch result {
        case let .success(model):
          self.isDownloaded = true
          self.filePath = model.path
        case let .failure(error):
          self.isDownloaded = false
          // Keep filePath consistent with isDownloaded == false. Otherwise a path
          // from an earlier success would linger when this method is called
          // without the helper's reset.
          self.filePath = ""
          self.reportError("Model download failed with error: \(error)")
        }
      }
    }

    /// Returns a closure (e.g. for a button action) that resets state and then
    /// deletes the locally downloaded copy of `modelName`.
    func deleteModelHelper() -> () -> Void {
      return {
        self.resetState()
        self.deleteModel()
      }
    }

    /// Deletes the downloaded copy of the model named `modelName`.
    ///
    /// On success, `isDeleted` is set and the download state is cleared.
    /// On failure, the error is reported.
    func deleteModel() {
      let modelDownloader = ModelDownloader.modelDownloader()
      modelDownloader.deleteDownloadedModel(name: modelName) { result in
        switch result {
        case .success:
          self.isDeleted = true
          self.isDownloaded = false
          self.filePath = ""
        case let .failure(error):
          self.isDeleted = false
          self.reportError("Model deletion failed with error: \(error)")
        }
      }
    }

    /// Returns a closure (e.g. for a button action) that resets state and then
    /// lists the models downloaded to this device.
    func listModelHelper() -> () -> Void {
      return {
        self.resetState()
        self.listModel()
      }
    }

    /// Publishes the sorted names of all downloaded models to `modelNames`.
    ///
    /// An empty result is reported as an error ("No models found on device.").
    /// A listing failure clears `modelNames` and reports the error.
    func listModel() {
      let modelDownloader = ModelDownloader.modelDownloader()
      modelDownloader.listDownloadedModels { result in
        switch result {
        case let .success(models):
          // Assign rather than append: calling listModel() directly (without the
          // helper's reset) must not accumulate duplicate names across calls.
          // Sorting gives a stable display order regardless of the order in
          // which the downloader yields models.
          self.modelNames = models.map { $0.name }.sorted()
          if models.isEmpty {
            self.reportError("No models found on device.")
          }
        case let .failure(error):
          self.modelNames = []
          self.reportError("Listing models failed with error: \(error)")
        }
      }
    }
  }