| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- // Copyright 2020 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- import Foundation
- import TensorFlowLite
- import FirebaseMLModelDownloader
/// View model for the model-downloader sample: wraps FirebaseMLModelDownloader
/// calls and publishes their outcome.
///
/// `modelName` is the input (which hosted model to act on). The remaining
/// published properties report the result of the most recent download,
/// delete, or list request. Presumably observed by SwiftUI views; confirm
/// against the view layer.
class Downloader: ObservableObject {
  /// Latest value passed to `getModel`'s progress handler.
  @Published var downloadProgress: Float = 0.0
  /// Name of the model to download or delete.
  @Published var modelName = ""
  /// Local path of the most recently downloaded model ("" when none).
  @Published var filePath = ""
  /// Human-readable description of the last failure ("" when none).
  @Published var error = ""
  /// True after a successful download; cleared on reset, failure, or delete.
  @Published var isDownloaded = false
  /// True after a successful deletion.
  @Published var isDeleted = false
  /// True when `error` holds a message for the last request.
  @Published var isError = false
  /// Names returned by the last successful `listModel()` call.
  @Published var modelNames = [String]()

  /// Clears every published result so a new request starts from a clean slate.
  /// `modelName` is intentionally left untouched because it is user input.
  private func resetState() {
    isDownloaded = false
    isDeleted = false
    downloadProgress = 0.0
    filePath = ""
    error = ""
    isError = false
    modelNames = []
  }

  /// Returns a closure that resets the published state and then runs `action`.
  /// Shared by the `*Helper` methods, which hand these closures to the UI.
  private func makeResettingAction(_ action: @escaping () -> Void) -> () -> Void {
    return {
      self.resetState()
      action()
    }
  }

  /// Returns a closure that resets state and downloads `modelName`.
  /// - Parameter downloadType: Forwarded unchanged to `ModelDownloader.getModel`.
  func downloadModelHelper(downloadType: ModelDownloadType) -> () -> Void {
    return makeResettingAction { self.downloadModel(downloadType: downloadType) }
  }

  /// Requests the model named `modelName` and publishes progress and the outcome.
  ///
  /// On success, sets `isDownloaded` and `filePath`, logs the file size, and
  /// checks that TensorFlow Lite can build an `Interpreter` from the file. On
  /// failure, sets `isError` and `error`.
  /// - Parameter downloadType: Forwarded unchanged to `ModelDownloader.getModel`.
  func downloadModel(downloadType: ModelDownloadType) {
    let modelDownloader = ModelDownloader.modelDownloader()
    let conditions = ModelDownloadConditions()
    modelDownloader.getModel(
      name: modelName,
      downloadType: downloadType,
      conditions: conditions,
      progressHandler: { progress in
        self.downloadProgress = progress
      }
    ) { result in
      switch result {
      case let .success(model):
        self.isDownloaded = true
        self.filePath = model.path
        self.logFileSize(atPath: model.path)
        self.verifyInterpreterLoads(modelPath: model.path)
      case let .failure(error):
        self.isDownloaded = false
        self.isError = true
        self.error = "Model download failed with error: \(error)"
      }
    }
  }

  /// Diagnostic only: prints the size of the file at `path`, or the file-access error.
  private func logFileSize(atPath path: String) {
    do {
      // The returned Bool is not needed; a missing file surfaces as a thrown error.
      _ = try URL(fileURLWithPath: path).checkResourceIsReachable()
      let attributes = try FileManager.default.attributesOfItem(atPath: path)
      if let size = attributes[.size] {
        print("File size: \(size)")
      } else {
        print("Error - could not get file size.")
      }
    } catch {
      print("File access error - \(error)")
    }
  }

  /// Checks that TensorFlow Lite can build an `Interpreter` from `modelPath`.
  ///
  /// A failure is printed as before. It is also published through `isError`
  /// and `error`, so the UI does not present an unloadable file as a good download.
  private func verifyInterpreterLoads(modelPath: String) {
    do {
      _ = try Interpreter(modelPath: modelPath)
    } catch {
      print("Tensorflow error - \(error)")
      isError = true
      self.error = "Downloaded model could not be loaded by TensorFlow Lite: \(error)"
    }
  }

  /// Returns a closure that resets state and deletes the local copy of `modelName`.
  func deleteModelHelper() -> () -> Void {
    return makeResettingAction { self.deleteModel() }
  }

  /// Deletes the downloaded copy of `modelName`.
  ///
  /// On success, sets `isDeleted`, clears `isDownloaded` and `filePath`. On
  /// failure, sets `isError` and `error`.
  func deleteModel() {
    let modelDownloader = ModelDownloader.modelDownloader()
    modelDownloader.deleteDownloadedModel(name: modelName) { result in
      switch result {
      case .success:
        self.isDeleted = true
        self.isDownloaded = false
        self.filePath = ""
      case let .failure(error):
        self.isDeleted = false
        self.isError = true
        self.error = "Model deletion failed with error: \(error)"
      }
    }
  }

  /// Returns a closure that resets state and lists the models stored on device.
  func listModelHelper() -> () -> Void {
    return makeResettingAction { self.listModel() }
  }

  /// Publishes the names of all downloaded models into `modelNames`.
  ///
  /// An empty result is reported as an error ("No models found on device.").
  /// A failed request sets `isError` and `error`.
  func listModel() {
    let modelDownloader = ModelDownloader.modelDownloader()
    modelDownloader.listDownloadedModels { result in
      switch result {
      case let .success(models):
        // Assign rather than append, so calling listModel() directly (without the
        // helper's reset) does not accumulate duplicate names. Sort so the display
        // order does not depend on the returned collection's iteration order.
        self.modelNames = models.map(\.name).sorted()
        if models.isEmpty {
          self.isError = true
          self.error = "No models found on device."
        }
      case let .failure(error):
        self.isError = true
        self.error = "Listing models failed with error: \(error)"
      }
    }
  }
}
|