// Downloader.swift

// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. import Foundation
  15. import TensorFlowLite
  16. import FirebaseMLModelDownloader
  /// View model that drives the model-downloader sample UI.
  ///
  /// Each action (download, delete, list) has a `...Helper()` factory that returns a
  /// `() -> Void` closure. When run, the closure clears the previous result with
  /// `resetState()` and then starts the matching `ModelDownloader` operation. Outcomes are
  /// reported through the `@Published` properties below.
  class Downloader: ObservableObject {
    /// Progress value most recently passed to the `getModel` progress handler
    /// (presumably a fraction in 0.0...1.0 — confirm against the Firebase docs).
    @Published var downloadProgress: Float = 0.0
    /// Name of the model passed to `getModel` / `deleteDownloadedModel`.
    @Published var modelName = ""
    /// Local path of the last successfully downloaded model, or "" when there is none.
    @Published var filePath = ""
    /// Human-readable description of the last failure; meaningful when `isError` is true.
    @Published var error = ""
    /// True after `getModel` succeeded for `modelName`.
    @Published var isDownloaded = false
    /// True after `deleteDownloadedModel` succeeded for `modelName`.
    @Published var isDeleted = false
    /// True when the last operation, or the post-download file check, failed; see `error`.
    @Published var isError = false
    /// Names of the models reported on device by the last `listModel()` call.
    @Published var modelNames = [String]()

    /// Clears every result from a previous operation. `modelName` is intentionally kept.
    private func resetState() {
      isDownloaded = false
      isDeleted = false
      downloadProgress = 0.0
      filePath = ""
      error = ""
      isError = false
      modelNames = []
    }

    /// Returns a closure that resets state and then downloads `modelName` using `downloadType`.
    func downloadModelHelper(downloadType: ModelDownloadType) -> () -> Void {
      return {
        self.resetState()
        self.downloadModel(downloadType: downloadType)
      }
    }

    /// Requests `modelName` from `ModelDownloader` with `downloadType` and publishes the outcome.
    ///
    /// On success `isDownloaded` and `filePath` are set, and the file is then sanity-checked.
    /// If that check fails, `isError`/`error` describe the problem while `isDownloaded` stays
    /// true, because the download itself did succeed. On failure, `isError` and `error` are set.
    func downloadModel(downloadType: ModelDownloadType) {
      let modelDownloader = ModelDownloader.modelDownloader()
      let conditions = ModelDownloadConditions()
      modelDownloader.getModel(
        name: modelName,
        downloadType: downloadType,
        conditions: conditions,
        progressHandler: { [weak self] progress in
          self?.downloadProgress = progress
        }
      ) { [weak self] result in
        // The view model may have been released while the download was in flight.
        guard let self = self else { return }
        switch result {
        case let .success(model):
          self.isDownloaded = true
          self.filePath = model.path
          if let problem = self.validateModelFile(atPath: model.path) {
            // Previously these failures were only printed, so the UI showed a clean success
            // for a model that could not actually be loaded.
            self.isError = true
            self.error = problem
          }
        case let .failure(error):
          self.isDownloaded = false
          self.isError = true
          self.error = "Model download failed with error: \(error)"
        }
      }
    }

    /// Logs the size of the model file at `path` and checks that TensorFlow Lite can load it.
    ///
    /// - Parameter path: Local file path reported by `ModelDownloader` for the downloaded model.
    /// - Returns: A description of the first failed check, or `nil` when the file is reachable,
    ///   its attributes can be read, and an `Interpreter` can be created from it.
    private func validateModelFile(atPath path: String) -> String? {
      var problem: String?
      do {
        _ = try URL(fileURLWithPath: path).checkResourceIsReachable()
        let attributes = try FileManager.default.attributesOfItem(atPath: path)
        if let size = attributes[.size] {
          print("File size: \(size)")
        } else {
          print("Error - could not get file size.")
        }
      } catch {
        print("File access error - \(error)")
        problem = "File access error - \(error)"
      }
      // The load is attempted even if the file check failed, so both diagnostics are logged.
      do {
        _ = try Interpreter(modelPath: path)
      } catch {
        print("Tensorflow error - \(error)")
        problem = problem ?? "Tensorflow error - \(error)"
      }
      return problem
    }

    /// Returns a closure that resets state and then deletes the local copy of `modelName`.
    func deleteModelHelper() -> () -> Void {
      return {
        self.resetState()
        self.deleteModel()
      }
    }

    /// Deletes the downloaded copy of `modelName` and publishes the outcome.
    ///
    /// On success `isDeleted` becomes true and the download state is cleared. On failure,
    /// `isError` and `error` are set.
    func deleteModel() {
      let modelDownloader = ModelDownloader.modelDownloader()
      modelDownloader.deleteDownloadedModel(name: modelName) { [weak self] result in
        guard let self = self else { return }
        switch result {
        case .success:
          self.isDeleted = true
          self.isDownloaded = false
          self.filePath = ""
        case let .failure(error):
          self.isDeleted = false
          self.isError = true
          self.error = "Model deletion failed with error: \(error)"
        }
      }
    }

    /// Returns a closure that resets state and then lists the models stored on device.
    func listModelHelper() -> () -> Void {
      return {
        self.resetState()
        self.listModel()
      }
    }

    /// Publishes the names of all downloaded models in `modelNames`, sorted for a stable order.
    ///
    /// An empty result, or a listing failure, is reported through `isError` and `error`.
    func listModel() {
      let modelDownloader = ModelDownloader.modelDownloader()
      modelDownloader.listDownloadedModels { [weak self] result in
        guard let self = self else { return }
        switch result {
        case let .success(models):
          // Assign rather than append, so calling listModel() without resetState() first
          // does not duplicate names.
          self.modelNames = models.map { $0.name }.sorted()
          if models.isEmpty {
            self.isError = true
            self.error = "No models found on device."
          }
        case let .failure(error):
          self.isError = true
          self.error = "Listing models failed with error: \(error)"
        }
      }
    }
  }