BlinkDL/RWKV-LM

KeyError: "attribute 'weight' already exists"

ByUnal opened this issue · 1 comment

I'm trying to train RWKV/rwkv-4-world-430m with LoRA using the Hugging Face Transformers `Trainer`. I chunked my data (chunk size = 128) and started training. Training runs normally, but it throws an error at the end of the epoch, just before evaluation.
The error is: KeyError: "attribute 'weight' already exists"

Here is my training code and full error:

Code:

import math
from datetime import datetime  # used by the optional run_name below

from transformers import (
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Fine-tuning hyperparameters.
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 32
LEARNING_RATE = 5e-5
LR_WARMUP_STEPS = 100
WEIGHT_DECAY = 1e-4
EPOCH = 2
GRAD_ACCUM_STEPS = 1

# Optimizer steps per epoch, computed from the dataset the Trainer actually
# iterates (the chunked `lm_train_datasets`, not the raw `train_dataset`).
# ceil() counts the final partial batch; max(1, ...) keeps logging_steps
# valid (> 0) even on a tiny dataset.
# NOTE(review): assumes a single device; divide by the world size under DDP.
steps_per_epoch = max(
    1,
    math.ceil(len(lm_train_datasets) / (TRAIN_BATCH_SIZE * GRAD_ACCUM_STEPS)),
)

# mlm=False selects causal language modeling (no random token masking).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./rwkv-output",
    overwrite_output_dir=True,
    logging_dir="./RWKVlogs",
    logging_steps=steps_per_epoch,  # log the training loss roughly once per epoch
    num_train_epochs=EPOCH,
    do_train=True,
    do_eval=True,
    bf16=True,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    warmup_steps=LR_WARMUP_STEPS,
    optim="paged_adamw_8bit",
    save_total_limit=1,
    weight_decay=WEIGHT_DECAY,
    learning_rate=LEARNING_RATE,
    # Evaluate and checkpoint at every epoch boundary; eval_steps/save_steps
    # would be ignored with the 'epoch' strategy, so they are not set.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    seed=SEED_TRAIN,
    report_to="none",
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    # run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    data_collator=data_collator,
    # Chunked (chunk size = 128) datasets; the raw train_dataset/valid_dataset
    # are not used for training.
    train_dataset=lm_train_datasets,
    eval_dataset=lm_valid_datasets,
    tokenizer=tokenizer,
)

trainer.train()

Full error:

KeyError                                  Traceback (most recent call last)
Cell In[52], line 57
     17 training_args = TrainingArguments(
     18     output_dir='./rwkv-output',
     19     overwrite_output_dir=True,
   (...)
     43     # run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}" 
     44 )
     46 trainer = Trainer(
     47     model=peft_model,
     48     args=training_args,
   (...)
     54     tokenizer=tokenizer
     55     )
---> 57 trainer.train()

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/trainer.py:1627, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1625         hf_hub_utils.enable_progress_bars()
   1626 else:
-> 1627     return inner_training_loop(
   1628         args=args,
   1629         resume_from_checkpoint=resume_from_checkpoint,
   1630         trial=trial,
   1631         ignore_keys_for_eval=ignore_keys_for_eval,
   1632     )

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/trainer.py:2052, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2049     self.control.should_training_stop = True
   2051 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 2052 self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2054 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
   2055     if is_torch_tpu_available():
   2056         # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/trainer.py:2415, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2413 metrics = None
   2414 if self.control.should_evaluate:
-> 2415     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2416     self._report_to_hp_search(trial, self.state.global_step, metrics)
   2418     # Run delayed LR scheduler now that metrics are populated

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/trainer.py:3232, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3229 start_time = time.time()
   3231 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3232 output = eval_loop(
   3233     eval_dataloader,
   3234     description="Evaluation",
   3235     # No point gathering the predictions if there are no metrics, otherwise we defer to
   3236     # self.args.prediction_loss_only
   3237     prediction_loss_only=True if self.compute_metrics is None else None,
   3238     ignore_keys=ignore_keys,
   3239     metric_key_prefix=metric_key_prefix,
   3240 )
   3242 total_batch_size = self.args.eval_batch_size * self.args.world_size
   3243 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/trainer.py:3421, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   3418         batch_size = observed_batch_size
   3420 # Prediction step
-> 3421 loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   3422 main_input_name = getattr(self.model, "main_input_name", "input_ids")
   3423 inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/trainer.py:3638, in Trainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
   3636 if has_labels or loss_without_labels:
   3637     with self.compute_loss_context_manager():
-> 3638         loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
   3639     loss = loss.mean().detach()
   3641     if isinstance(outputs, dict):

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/trainer.py:2928, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2926 else:
   2927     labels = None
-> 2928 outputs = model(**inputs)
   2929 # Save past state if it exists
   2930 # TODO: this needs to be fixed and made cleaner later.
   2931 if self.args.past_index >= 0:

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    821 def forward(*args, **kwargs):
--> 822     return model_forward(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    809 def __call__(self, *args, **kwargs):
--> 810     return convert_to_fp32(self.model_forward(*args, **kwargs))

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    821 def forward(*args, **kwargs):
--> 822     return model_forward(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    809 def __call__(self, *args, **kwargs):
--> 810     return convert_to_fp32(self.model_forward(*args, **kwargs))

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    821 def forward(*args, **kwargs):
--> 822     return model_forward(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    809 def __call__(self, *args, **kwargs):
--> 810     return convert_to_fp32(self.model_forward(*args, **kwargs))

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/peft/peft_model.py:1083, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1081     if peft_config.peft_type == PeftType.POLY:
   1082         kwargs["task_ids"] = task_ids
-> 1083     return self.base_model(
   1084         input_ids=input_ids,
   1085         attention_mask=attention_mask,
   1086         inputs_embeds=inputs_embeds,
   1087         labels=labels,
   1088         output_attentions=output_attentions,
   1089         output_hidden_states=output_hidden_states,
   1090         return_dict=return_dict,
   1091         **kwargs,
   1092     )
   1094 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1095 if attention_mask is not None:
   1096     # concat prompt attention mask

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/models/rwkv/modeling_rwkv.py:839, in RwkvForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, state, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    831 r"""
    832 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
    833     Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
    834     `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
    835     are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
    836 """
    837 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
--> 839 rwkv_outputs = self.rwkv(
    840     input_ids,
    841     inputs_embeds=inputs_embeds,
    842     state=state,
    843     use_cache=use_cache,
    844     output_attentions=output_attentions,
    845     output_hidden_states=output_hidden_states,
    846     return_dict=return_dict,
    847 )
    848 hidden_states = rwkv_outputs[0]
    850 logits = self.head(hidden_states)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/models/rwkv/modeling_rwkv.py:642, in RwkvModel.forward(self, input_ids, attention_mask, inputs_embeds, state, use_cache, output_attentions, output_hidden_states, return_dict)
    639 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    641 if self.training == self.layers_are_rescaled:
--> 642     self._rescale_layers()
    644 if input_ids is not None and inputs_embeds is not None:
    645     raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/models/rwkv/modeling_rwkv.py:728, in RwkvModel._rescale_layers(self)
    726 elif hasattr(block.attention.output.weight, "quant_state"):
    727     self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
--> 728     self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
    729 else:
    730     block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/transformers/models/rwkv/modeling_rwkv.py:754, in RwkvModel._bnb_4bit_dequantize_and_rescale(self, target_layer, block_id)
    748 # re-quantize the model:
    749 # we need to put it first on CPU then back to the device
    750 # this will create an overhead :/
    751 # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
    752 # bugs with bnb
    753 quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
--> 754 setattr(target_layer, "weight", quant_weight)

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/nn/modules/module.py:1705, in Module.__setattr__(self, name, value)
   1702         raise AttributeError(
   1703             "cannot assign parameters before Module.__init__() call")
   1704     remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
-> 1705     self.register_parameter(name, value)
   1706 elif params is not None and name in params:
   1707     if value is not None:

File ~/anaconda3/envs/cht/lib/python3.10/site-packages/torch/nn/modules/module.py:575, in Module.register_parameter(self, name, param)
    573     raise KeyError("parameter name can't be empty string \"\"")
    574 elif hasattr(self, name) and name not in self._parameters:
--> 575     raise KeyError(f"attribute '{name}' already exists")
    577 if param is None:
    578     self._parameters[name] = None

KeyError: "attribute 'weight' already exists"

Please help me with this. Could it be related to LoRA?

I've found that the problem here is quantization:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization with double quantization; matmuls run in bf16.
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Why rescale_every=0:
# RwkvModel.forward calls _rescale_layers() whenever the train/eval mode
# differs from the current rescale state, e.g. on the first evaluation step.
# For 4-bit weights it dequantizes, rescales, and re-quantizes the
# attention.output and feed_forward.value layers. It then reassigns each
# layer's `weight` via setattr.
# The traceback above shows that setattr raising
# KeyError: "attribute 'weight' already exists". That happens when the layer
# has a `weight` attribute that is not a registered parameter. Presumably
# this is the PEFT LoRA wrapper around the target module (TODO: confirm
# against the installed peft version).
# rescale_every=0 disables that inference-time rescaling. The rescaling only
# guards against fp16 overflow, so it should not be needed with bf16 compute;
# verify the eval loss looks sane.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=config,
    device_map="auto",
    rescale_every=0,
)

The model raises the error above when loaded with this 4-bit quantization config, but training and evaluation work without quantization.