EricLBuehler/candle-lora

Is there any way to save lora-converted model?

Adamska1008 opened this issue · 5 comments

I tried to fine-tune TinyLlama with this crate. I used candle-lora/candle-lora-transformers/examples/llama.rs to load model.safetensors and ran my training, but eventually found that there is no way to save the model in safetensors format.

I tried to implement a save method myself by wrapping candle_core::safetensors::save(), but how can I get the weights of the LoRA part? All I can get is the raw model from before it was converted to a LoRA model.

For example, if you run /candle-lora-macro/examples/linear.rs and add println!("{:?}", model.a);, you will see it printed as a Linear struct, not a LoraLinear struct, and you cannot reach ff_a or ff_b through model.a, even though the model has been converted to a LoRA model.
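What I tried looks roughly like this (a rough sketch only; collect_lora_tensors is a hypothetical helper standing in for exactly the part I cannot write, because the converted model only exposes plain Linear layers):

use std::collections::HashMap;

use candle_core::{Result, Tensor};

// Thin wrapper around candle_core::safetensors::save: it only needs a
// name -> Tensor map and an output path.
fn save_lora(tensors: &HashMap<String, Tensor>, path: &str) -> Result<()> {
    candle_core::safetensors::save(tensors, path)
}

// The missing piece: something like this would have to walk the model and
// collect the LoRA A/B matrices, but I cannot reach them from the outside.
// fn collect_lora_tensors(model: &MyModel) -> HashMap<String, Tensor> { ... }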

This is implemented/fixed in #13, which has been merged. Please note that the weight naming is incompatible with peft at the moment. If this is a problem, please feel free to raise an issue and I will fix it.

Thank you very much! I tried this and got a 536KB safetensors file with this header:

{"lora_llamaa0.weight":{"data_offsets":[0,512000],"dtype":"F16","shape":[8,32000]},"lora_llamab0.weight":{"data_offsets":[512000,544768],"dtype":"F16","shape":[2048,8]}}

Is this as expected? I also want to know how to apply the LoRA tensors after loading a VarBuilder from the original model.
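For reference, I checked what was saved by loading the file back with candle_core::safetensors::load (a quick sketch; the file name is just a placeholder):

use candle_core::Device;

// Load the saved file back and print each tensor name with its shape.
let tensors = candle_core::safetensors::load("lora_llama.safetensors", &Device::Cpu)?;
for (name, tensor) in &tensors {
    println!("{name}: {:?}", tensor.shape());
}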

No, the prefix was incorrect, but it should be fixed now. To load the LoRA tensors, pass the VarBuilder returned by from_mmaped_safetensors to get_lora_model. Here is an example of loading the VarBuilder:

let vb = from_mmaped_safetensors(&filenames, dtype, &device, false)?;

That vb is then passed to get_lora_model:

if merge {
    this.get_merged_lora_model(
        lora_config,
        &vb.pp("lora_llama"),
        Some(linear_config),
        None,
        None,
        Some(embed_config),
    )
} else {
    this.get_lora_model(
        lora_config,
        &vb.pp("lora_llama"),
        Some(linear_config),
        None,
        None,
        Some(embed_config),
    )
}

Really helpful, thanks again!

Glad to help!