Download progress is not accurate
DePasqualeOrg opened this issue · 2 comments
DePasqualeOrg commented
In Libraries/LLM/Load.swift
, the tokenizer is first loaded. As part of this process, config.json
is downloaded if it hasn't been downloaded already. Then that same file is included in modelFiles
(which are subsequently downloaded), and this small file contributes an equal amount to the download progress as a multi-gigabyte .safetensors
file. This causes the download progress to perceptibly start or end at 50% when the model consists of one file. Wouldn't it make sense not to include config.json
in modelFiles
?
/// Load and return the model and tokenizer
public func load(
hub: HubApi = HubApi(), configuration: ModelConfiguration,
progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) {
do {
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
let modelDirectory: URL
switch configuration.id {
case .id(let id):
// download the model weights and config
let repo = Hub.Repo(id: id)
let modelFiles = ["config.json", "*.safetensors"]
modelDirectory = try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)
case .directory(let directory):
modelDirectory = directory
}
// create the model (no weights loaded)
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
let model = try baseConfig.modelType.createModel(configuration: configurationURL)
// load the weights
var weights = [String: MLXArray]()
let enumerator = FileManager.default.enumerator(
at: modelDirectory, includingPropertiesForKeys: nil)!
for case let url as URL in enumerator {
if url.pathExtension == "safetensors" {
let w = try loadArrays(url: url)
for (key, value) in w {
weights[key] = value
}
}
}
// quantize if needed
if let quantization = baseConfig.quantization {
quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
}
// apply the loaded weights
let parameters = ModuleParameters.unflattened(weights)
try model.update(parameters: parameters, verify: [.all])
eval(model)
return (model, tokenizer)
} catch Hub.HubClientError.authorizationRequired {
// an authorizationRequired means (typically) that the named repo doesn't exist on
// on the server so retry with local only configuration
var newConfiguration = configuration
newConfiguration.id = .directory(configuration.modelDirectory(hub: hub))
return try await load(
hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
}
}
davidkoski commented
That seems like a reasonable idea -- do you want to make a PR for it?
DePasqualeOrg commented