TypeError: unsupported operand type(s) for *: 'Parameter' and 'NoneType'
misonsky opened this issue · 1 comment
misonsky commented
System Info
Adalora
```python
def update_ipt(self, model):
    # Update the sensitivity and uncertainty for every weight
    for n, p in model.named_parameters():
        if "lora_" in n and self.adapter_name in n:
            if n not in self.ipt:
                self.ipt[n] = torch.zeros_like(p)
                self.exp_avg_ipt[n] = torch.zeros_like(p)
                self.exp_avg_unc[n] = torch.zeros_like(p)
            with torch.no_grad():
                self.ipt[n] = (p * p.grad).abs().detach()
                # Sensitivity smoothing
                self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
                # Uncertainty quantification
                self.exp_avg_unc[n] = (
                    self.beta2 * self.exp_avg_unc[n] + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()
                )
```
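The failing expression is `p * p.grad`: for a parameter that never receives a gradient, `p.grad` stays `None`, and multiplying a `Parameter` by `None` raises a TypeError like the one in the title. A minimal sketch reproducing it with plain PyTorch, independent of PEFT (the variable names are illustrative only):

```python
import torch
import torch.nn as nn

# A frozen parameter: autograd never populates its .grad attribute.
p = nn.Parameter(torch.randn(2, 3), requires_grad=False)

message = None
try:
    # Mirrors the sensitivity computation inside update_ipt.
    with torch.no_grad():
        ipt = (p * p.grad).abs().detach()
except TypeError as exc:
    message = str(exc)
    print(message)
```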
When using AdaLoRA with PEFT, the classification head contains the following parameters:
base_model.model.classifier.original_module.dense.base_layer.weight
base_model.model.classifier.original_module.dense.base_layer.bias
base_model.model.classifier.original_module.dense.lora_A.default
base_model.model.classifier.original_module.dense.lora_B.default
base_model.model.classifier.original_module.dense.lora_E.default
base_model.model.classifier.original_module.dense.ranknum.default
base_model.model.classifier.original_module.out_proj.weight
base_model.model.classifier.original_module.out_proj.bias
base_model.model.classifier.modules_to_save.default.dense.base_layer.weight
base_model.model.classifier.modules_to_save.default.dense.base_layer.bias
base_model.model.classifier.modules_to_save.default.dense.lora_A.default
base_model.model.classifier.modules_to_save.default.dense.lora_B.default
base_model.model.classifier.modules_to_save.default.dense.lora_E.default
base_model.model.classifier.modules_to_save.default.dense.ranknum.default
base_model.model.classifier.modules_to_save.default.out_proj.weight
base_model.model.classifier.modules_to_save.default.out_proj.bias
However, the following parameters have no gradient:
base_model.model.classifier.original_module.dense.lora_A.default
base_model.model.classifier.original_module.dense.lora_B.default
base_model.model.classifier.original_module.dense.lora_E.default
Their requires_grad attribute is False, yet their names still contain both the "lora_" substring and the adapter name ("default"), so update_ipt processes them anyway. Because p.grad is None for these parameters, p * p.grad raises the TypeError above. I think a gradient check should be added to the update_ipt function.
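To confirm which parameters are affected, one can list every parameter that the name filter in update_ipt matches but that is frozen. A minimal sketch, assuming a hypothetical helper `find_frozen_lora_params` and toy modules (`ToyDense`, `ToyClassifier`) that only imitate the `original_module` / `modules_to_save` naming shown above; they are not PEFT classes:

```python
import torch
import torch.nn as nn


def find_frozen_lora_params(model: nn.Module, adapter_name: str = "default") -> list[str]:
    """Hypothetical helper: names that update_ipt's filter matches but that are frozen."""
    return [
        n
        for n, p in model.named_parameters()
        if "lora_" in n and adapter_name in n and not p.requires_grad
    ]


class ToyDense(nn.Module):
    """Hypothetical stand-in for an AdaLoRA-wrapped dense layer."""

    def __init__(self) -> None:
        super().__init__()
        self.lora_A = nn.ParameterDict({"default": nn.Parameter(torch.zeros(2, 4))})
        self.lora_B = nn.ParameterDict({"default": nn.Parameter(torch.zeros(4, 2))})


class ToyClassifier(nn.Module):
    """Hypothetical stand-in for a classifier with a frozen copy and a trainable copy."""

    def __init__(self) -> None:
        super().__init__()
        self.original_module = nn.ModuleDict({"dense": ToyDense()})
        self.modules_to_save = nn.ModuleDict({"default": nn.ModuleDict({"dense": ToyDense()})})
        # Freeze the original copy; only the modules_to_save copy stays trainable.
        for p in self.original_module.parameters():
            p.requires_grad_(False)


model = ToyClassifier()
frozen = find_frozen_lora_params(model)
print(frozen)
```

Since the helper only relies on `named_parameters()`, the same filter could be applied to the real PEFT model to see which parameter names trip the name-based check.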
This error occurs when calling model.update_and_allocate(global_step).
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
This error occurs when calling model.update_and_allocate(global_step).
The config is:
```python
from peft import AdaLoraConfig

# `rank` is set elsewhere in the training script
peft_config = AdaLoraConfig(
    peft_type="ADALORA",
    task_type="SEQ_CLS",
    r=rank,
    lora_alpha=32,
    lora_dropout=0.01,
)
```
The model is RoBERTa.
Expected behavior
I think a gradient check should be added to the update_ipt function, for example:
```python
def update_ipt(self, model):
    # Update the sensitivity and uncertainty for every weight
    for n, p in model.named_parameters():
        # Skip frozen parameters: they never receive a gradient, so p.grad is None
        if not p.requires_grad:
            continue
        if "lora_" in n and self.adapter_name in n:
            if n not in self.ipt:
                self.ipt[n] = torch.zeros_like(p)
                self.exp_avg_ipt[n] = torch.zeros_like(p)
                self.exp_avg_unc[n] = torch.zeros_like(p)
            with torch.no_grad():
                self.ipt[n] = (p * p.grad).abs().detach()
                # Sensitivity smoothing
                self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
                # Uncertainty quantification
                self.exp_avg_unc[n] = (
                    self.beta2 * self.exp_avg_unc[n] + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()
                )
```
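The effect of the proposed `requires_grad` guard can be checked without PEFT. The sketch below copies the update logic into a hypothetical, stripped-down `ToySensitivityTracker` class (not PEFT's actual allocator; the beta values are illustrative) and runs it on a hypothetical `ToyModel` with one trainable and one frozen parameter whose names both contain "lora_" and "default":

```python
import torch
import torch.nn as nn


class ToySensitivityTracker:
    """Hypothetical stand-in for the object that owns update_ipt."""

    def __init__(self, adapter_name: str = "default", beta1: float = 0.85, beta2: float = 0.85) -> None:
        self.adapter_name = adapter_name
        self.beta1 = beta1
        self.beta2 = beta2
        self.ipt: dict[str, torch.Tensor] = {}
        self.exp_avg_ipt: dict[str, torch.Tensor] = {}
        self.exp_avg_unc: dict[str, torch.Tensor] = {}

    def update_ipt(self, model: nn.Module) -> None:
        for n, p in model.named_parameters():
            # Proposed guard: frozen parameters never receive a gradient.
            if not p.requires_grad:
                continue
            if "lora_" in n and self.adapter_name in n:
                if n not in self.ipt:
                    self.ipt[n] = torch.zeros_like(p)
                    self.exp_avg_ipt[n] = torch.zeros_like(p)
                    self.exp_avg_unc[n] = torch.zeros_like(p)
                with torch.no_grad():
                    self.ipt[n] = (p * p.grad).abs().detach()
                    # Sensitivity smoothing
                    self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
                    # Uncertainty quantification
                    self.exp_avg_unc[n] = (
                        self.beta2 * self.exp_avg_unc[n]
                        + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()
                    )


class ToyModel(nn.Module):
    """Hypothetical model: one trainable and one frozen LoRA-style parameter."""

    def __init__(self) -> None:
        super().__init__()
        self.lora_A = nn.ParameterDict({"default": nn.Parameter(torch.randn(2, 4))})
        self.frozen_lora_A = nn.ParameterDict(
            {"default": nn.Parameter(torch.randn(2, 4), requires_grad=False)}
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ (self.lora_A["default"] + self.frozen_lora_A["default"]).T


model = ToyModel()
model(torch.randn(3, 4)).sum().backward()
tracker = ToySensitivityTracker()
tracker.update_ipt(model)  # no TypeError: the frozen entry is skipped
print(sorted(tracker.ipt))
```

A `p.grad is None` check, instead of or in addition to `requires_grad`, would also skip trainable parameters that happened to receive no gradient in a given step; which guard is preferable is a judgment call for the maintainers.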
BenjaminBossan commented
Thanks for reporting. Could you please paste the full error message? Also, do you have a reproducer or are you using one of the examples from PEFT?