|
|
@@ -56,10 +56,10 @@ public class ModelDownloader {
|
|
|
/// DispatchQueue to manage download task dictionary.
|
|
|
let taskSerialQueue = DispatchQueue(label: "downloadtask.serial.queue")
|
|
|
|
|
|
- /// Handler that always runs on the main thread
|
|
|
- let mainQueueHandler = { handler in
|
|
|
+ /// Re-dispatch a function on the main queue.
|
|
|
+ func asyncOnMainQueue(_ work: @autoclosure @escaping () -> Void) {
|
|
|
DispatchQueue.main.async {
|
|
|
- handler
|
|
|
+ work()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -132,13 +132,18 @@ public class ModelDownloader {
|
|
|
conditions: ModelDownloadConditions,
|
|
|
progressHandler: ((Float) -> Void)? = nil,
|
|
|
completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
|
|
|
+ guard !modelName.isEmpty else {
|
|
|
+ asyncOnMainQueue(completion(.failure(.emptyModelName)))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
switch downloadType {
|
|
|
case .localModel:
|
|
|
if let localModel = getLocalModel(modelName: modelName) {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: ModelDownloader.DebugDescription.localModelFound,
|
|
|
messageCode: .localModelFound)
|
|
|
- mainQueueHandler(completion(.success(localModel)))
|
|
|
+ asyncOnMainQueue(completion(.success(localModel)))
|
|
|
} else {
|
|
|
getRemoteModel(
|
|
|
modelName: modelName,
|
|
|
@@ -153,7 +158,7 @@ public class ModelDownloader {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: ModelDownloader.DebugDescription.localModelFound,
|
|
|
messageCode: .localModelFound)
|
|
|
- mainQueueHandler(completion(.success(localModel)))
|
|
|
+ asyncOnMainQueue(completion(.success(localModel)))
|
|
|
telemetryLogger?.logModelDownloadEvent(
|
|
|
eventName: .modelDownload,
|
|
|
status: .scheduled,
|
|
|
@@ -225,7 +230,7 @@ public class ModelDownloader {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: description,
|
|
|
messageCode: .modelNameParseError)
|
|
|
- mainQueueHandler(completion(.failure(.internalError(description: description))))
|
|
|
+ asyncOnMainQueue(completion(.failure(.internalError(description: description))))
|
|
|
return
|
|
|
}
|
|
|
// Check if model information corresponding to model is stored in UserDefaults.
|
|
|
@@ -234,7 +239,7 @@ public class ModelDownloader {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: description,
|
|
|
messageCode: .noLocalModelInfo)
|
|
|
- mainQueueHandler(completion(.failure(.internalError(description: description))))
|
|
|
+ asyncOnMainQueue(completion(.failure(.internalError(description: description))))
|
|
|
return
|
|
|
}
|
|
|
// Ensure that local model path is as expected, and reachable.
|
|
|
@@ -242,11 +247,11 @@ public class ModelDownloader {
|
|
|
appName: appName,
|
|
|
modelName: modelName
|
|
|
),
|
|
|
- url == modelURL, ModelFileManager.isFileReachable(at: modelURL) else {
|
|
|
+ ModelFileManager.isFileReachable(at: modelURL) else {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: ModelDownloader.ErrorDescription.outdatedModelPath,
|
|
|
messageCode: .outdatedModelPathError)
|
|
|
- mainQueueHandler(completion(.failure(.internalError(description: ModelDownloader
|
|
|
+ asyncOnMainQueue(completion(.failure(.internalError(description: ModelDownloader
|
|
|
.ErrorDescription.outdatedModelPath))))
|
|
|
return
|
|
|
}
|
|
|
@@ -262,12 +267,12 @@ public class ModelDownloader {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: ModelDownloader.ErrorDescription.listModelsFailed(error),
|
|
|
messageCode: .listModelsError)
|
|
|
- mainQueueHandler(completion(.failure(error)))
|
|
|
+ asyncOnMainQueue(completion(.failure(error)))
|
|
|
} catch {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: ModelDownloader.ErrorDescription.listModelsFailed(error),
|
|
|
messageCode: .listModelsError)
|
|
|
- mainQueueHandler(completion(.failure(.internalError(description: error
|
|
|
+ asyncOnMainQueue(completion(.failure(.internalError(description: error
|
|
|
.localizedDescription))))
|
|
|
}
|
|
|
}
|
|
|
@@ -290,7 +295,7 @@ public class ModelDownloader {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: ModelDownloader.ErrorDescription.modelNotFound(modelName),
|
|
|
messageCode: .modelNotFound)
|
|
|
- mainQueueHandler(completion(.failure(.notFound)))
|
|
|
+ asyncOnMainQueue(completion(.failure(.notFound)))
|
|
|
return
|
|
|
}
|
|
|
do {
|
|
|
@@ -305,7 +310,7 @@ public class ModelDownloader {
|
|
|
eventName: .remoteModelDeleteOnDevice,
|
|
|
isSuccessful: true
|
|
|
)
|
|
|
- mainQueueHandler(completion(.success(())))
|
|
|
+ asyncOnMainQueue(completion(.success(())))
|
|
|
} catch let error as DownloadedModelError {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: ModelDownloader.ErrorDescription.modelDeletionFailed(error),
|
|
|
@@ -314,7 +319,7 @@ public class ModelDownloader {
|
|
|
eventName: .remoteModelDeleteOnDevice,
|
|
|
isSuccessful: false
|
|
|
)
|
|
|
- mainQueueHandler(completion(.failure(error)))
|
|
|
+ asyncOnMainQueue(completion(.failure(error)))
|
|
|
} catch {
|
|
|
DeviceLogger.logEvent(level: .debug,
|
|
|
message: ModelDownloader.ErrorDescription.modelDeletionFailed(error),
|
|
|
@@ -323,7 +328,7 @@ public class ModelDownloader {
|
|
|
eventName: .remoteModelDeleteOnDevice,
|
|
|
isSuccessful: false
|
|
|
)
|
|
|
- mainQueueHandler(completion(.failure(.internalError(description: error
|
|
|
+ asyncOnMainQueue(completion(.failure(.internalError(description: error
|
|
|
.localizedDescription))))
|
|
|
}
|
|
|
}
|
|
|
@@ -417,28 +422,28 @@ extension ModelDownloader {
|
|
|
// Progress handler for model file download.
|
|
|
let taskProgressHandler: ModelDownloadTask.ProgressHandler = { progress in
|
|
|
if let progressHandler = progressHandler {
|
|
|
- self.mainQueueHandler(progressHandler(progress))
|
|
|
+ self.asyncOnMainQueue(progressHandler(progress))
|
|
|
}
|
|
|
}
|
|
|
// Completion handler for model file download.
|
|
|
let taskCompletion: ModelDownloadTask.Completion = { result in
|
|
|
switch result {
|
|
|
case let .success(model):
|
|
|
- self.mainQueueHandler(completion(.success(model)))
|
|
|
+ self.asyncOnMainQueue(completion(.success(model)))
|
|
|
case let .failure(error):
|
|
|
switch error {
|
|
|
case .notFound:
|
|
|
- self.mainQueueHandler(completion(.failure(.notFound)))
|
|
|
+ self.asyncOnMainQueue(completion(.failure(.notFound)))
|
|
|
case .invalidArgument:
|
|
|
- self.mainQueueHandler(completion(.failure(.invalidArgument)))
|
|
|
+ self.asyncOnMainQueue(completion(.failure(.invalidArgument)))
|
|
|
case .permissionDenied:
|
|
|
- self.mainQueueHandler(completion(.failure(.permissionDenied)))
|
|
|
+ self.asyncOnMainQueue(completion(.failure(.permissionDenied)))
|
|
|
// This is the error returned when model download URL has expired.
|
|
|
case .expiredDownloadURL:
|
|
|
// Retry model info and model file download, if allowed.
|
|
|
guard self.numberOfRetries > 0 else {
|
|
|
self
|
|
|
- .mainQueueHandler(
|
|
|
+ .asyncOnMainQueue(
|
|
|
completion(.failure(.internalError(description: ModelDownloader
|
|
|
.ErrorDescription
|
|
|
.expiredModelInfo)))
|
|
|
@@ -458,7 +463,7 @@ extension ModelDownloader {
|
|
|
completion: completion
|
|
|
)
|
|
|
default:
|
|
|
- self.mainQueueHandler(completion(.failure(error)))
|
|
|
+ self.asyncOnMainQueue(completion(.failure(error)))
|
|
|
}
|
|
|
}
|
|
|
self.taskSerialQueue.async {
|
|
|
@@ -502,15 +507,15 @@ extension ModelDownloader {
|
|
|
guard let localModel = self.getLocalModel(modelName: modelName) else {
|
|
|
// This can only happen if either local model info or the model file was wiped out after model info request but before server response.
|
|
|
self
|
|
|
- .mainQueueHandler(completion(.failure(.internalError(description: ModelDownloader
|
|
|
+ .asyncOnMainQueue(completion(.failure(.internalError(description: ModelDownloader
|
|
|
.ErrorDescription.deletedLocalModelInfoOrFile))))
|
|
|
return
|
|
|
}
|
|
|
- self.mainQueueHandler(completion(.success(localModel)))
|
|
|
+ self.asyncOnMainQueue(completion(.success(localModel)))
|
|
|
}
|
|
|
// Error retrieving model info.
|
|
|
case let .failure(error):
|
|
|
- self.mainQueueHandler(completion(.failure(error)))
|
|
|
+ self.asyncOnMainQueue(completion(.failure(error)))
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -530,6 +535,8 @@ public enum DownloadError: Error, Equatable {
|
|
|
case notEnoughSpace
|
|
|
/// Malformed model name or Firebase app options.
|
|
|
case invalidArgument
|
|
|
+ /// Model name is empty.
|
|
|
+ case emptyModelName
|
|
|
/// Other errors with description.
|
|
|
case internalError(description: String)
|
|
|
}
|