Xlora cannot reload model from last checkpoint by using trainer.train(resume_from_checkpoint="checkpp")
Opened this issue · 0 comments
System Info
Peft v0.13.2
Transformers v4.44.0
Accelerate v0.33.0
Who can help?
Since this relates to an interaction with PEFT and Xlora maybe @BenjaminBossan @EricLBuehler
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the
examples
folder - My own task or dataset (give details below)
Reproduction
Hi there,
I try to use trainer.train(resume_from_checkpoint=checkpoint_directory) or trainer.train(resume_from_checkpoint="YES") to reload(retrain) the model with Xlora from the last checkpoint, here is my training code and base model is chatglm4-9b:
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, Split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
EvalPrediction,
GenerationConfig,
PreTrainedTokenizer,
Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional
app = typer.Typer(pretty_exceptions_show_locals=False)
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)
if output_ids is not None:
max_output_length = max(len(out) for out in output_ids)
if self.pad_to_multiple_of is not None:
max_output_length = (
(
max_output_length + self.pad_to_multiple_of - 1) //
self.pad_to_multiple_of * self.pad_to_multiple_of
)
for feature in features:
remainder = [self.tokenizer.pad_token_id] * (
max_output_length - len(feature['output_ids'])
)
if isinstance(feature['output_ids'], list):
feature['output_ids'] = feature['output_ids'] + remainder
else:
feature['output_ids'] = np.concatenate(
[feature['output_ids'], remainder]
).astype(np.int64)
return super().__call__(features, return_tensors)
class Seq2SeqTrainer(_Seq2SeqTrainer):
# Not Support for apex
def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:
model.train()
inputs = self._prepare_inputs(inputs)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
loss = loss.mean()
self.accelerator.backward(loss)
detached_loss = loss.detach() / self.args.gradient_accumulation_steps
del inputs
torch.cuda.empty_cache()
return detached_loss
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, Any],
prediction_loss_only: bool,
ignore_keys=None,
**gen_kwargs,
) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
with torch.no_grad(): # Ensure no gradient computation
if self.args.predict_with_generate:
output_ids = inputs.pop('output_ids')
input_ids = inputs['input_ids']
loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
)
generated_tokens = generated_tokens[:, input_ids.size()[1]:]
labels = output_ids
del inputs, input_ids, output_ids
torch.cuda.empty_cache()
return loss, generated_tokens, labels
@dc.dataclass
class DataConfig(object):
train_file: Optional[str] = None
val_file: Optional[str] = None
test_file: Optional[str] = None
num_proc: Optional[int] = None
@property
def data_format(self) -> str:
return Path(self.train_file).suffix
@property
def data_files(self) -> dict[NamedSplit, str]:
return {
split: data_file
for split, data_file in zip(
[Split.TRAIN, Split.VALIDATION, Split.TEST],
[self.train_file, self.val_file, self.test_file],
)
if data_file is not None
}
@dc.dataclass
class FinetuningConfig(object):
data_config: DataConfig
max_input_length: int
max_output_length: int
combine: bool
freezeV: bool
training_args: Seq2SeqTrainingArguments = dc.field(
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
)
peft_config: Optional[PeftConfig] = None
def __post_init__(self):
if not self.training_args.do_eval or self.data_config.val_file is None:
self.training_args.do_eval = False
self.training_args.evaluation_strategy = 'no'
self.data_config.val_file = None
else:
self.training_args.per_device_eval_batch_size = (
self.training_args.per_device_eval_batch_size
or self.training_args.per_device_train_batch_size
)
@classmethod
def from_dict(cls, **kwargs) -> 'FinetuningConfig':
training_args = kwargs.get('training_args', None)
if training_args is not None and not isinstance(
training_args, Seq2SeqTrainingArguments
):
gen_config = training_args.get('generation_config')
if not isinstance(gen_config, GenerationConfig):
training_args['generation_config'] = GenerationConfig(
**gen_config
)
kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)
data_config = kwargs.get('data_config')
if not isinstance(data_config, DataConfig):
kwargs['data_config'] = DataConfig(**data_config)
peft_config = kwargs.get('peft_config', None)
if peft_config is not None and not isinstance(peft_config, PeftConfig):
kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
return cls(**kwargs)
@classmethod
def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
path = Path(path)
parser = yaml.YAML(typ='safe', pure=True)
parser.indent(mapping=2, offset=2, sequence=4)
parser.default_flow_style = False
kwargs = parser.load(path)
return cls.from_dict(**kwargs)
def _load_datasets(
data_dir: str,
data_format: str,
data_files: dict[NamedSplit, str],
num_proc: Optional[int],
) -> DatasetDict:
if data_format == '.jsonl':
dataset_dct = load_dataset(
data_dir,
data_files=data_files,
split=None,
num_proc=num_proc,
)
else:
raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
return dataset_dct
class DataManager(object):
def __init__(self, data_dir: str, data_config: DataConfig):
self._num_proc = data_config.num_proc
self._dataset_dct = _load_datasets(
data_dir,
data_config.data_format,
data_config.data_files,
self._num_proc,
)
def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
return self._dataset_dct.get(split, None)
def get_dataset(
self,
split: NamedSplit,
process_fn: Callable[[dict[str, Any]], dict[str, Any]],
batched: bool = True,
remove_orig_columns: bool = True,
) -> Optional[Dataset]:
orig_dataset = self._get_dataset(split)
if orig_dataset is None:
return
if remove_orig_columns:
remove_columns = orig_dataset.column_names
else:
remove_columns = None
return orig_dataset.map(
process_fn,
batched=batched,
remove_columns=remove_columns,
num_proc=self._num_proc,
)
def process_message(message):
if 'tools' in message and message['role'] == 'system':
for tool in message['tools']:
parameters = tool['function']['parameters']['properties']
tool['function']['parameters']['properties'] = \
{k: v for k, v in parameters.items() if
v is not None}
elif 'tools' in message:
del message['tools']
return message
def process_batch(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
combine: bool,
) -> dict[str, list]:
batched_conv = batch['messages']
batched_input_ids = []
batched_labels = []
for conv in batched_conv:
input_ids = [151331, 151333]
loss_masks = [False, False]
if combine:
new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
input_ids = new_input_ids
loss_masks = [False] * len(input_ids)
last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
for j in range(last_assistant_index + 1, len(input_ids)):
loss_masks[j] = True
else:
for message in conv:
message = process_message(message)
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
input_ids += new_input_ids
loss_masks += [loss_mask_val] * len(new_input_ids)
input_ids.append(151336) # EOS for chat
loss_masks = [False, *loss_masks]
labels = []
for input_id, mask in zip(input_ids, loss_masks):
if mask:
labels.append(input_id)
else:
labels.append(-100)
max_length = max_input_length + max_output_length + 1
batched_input_ids.append(input_ids[:max_length])
batched_labels.append(labels[:max_length])
del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
torch.cuda.empty_cache()
return {'input_ids': batched_input_ids, 'labels': batched_labels}
def process_batch_eval(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
combine: bool,
) -> dict[str, list]:
batched_conv = batch['messages']
batched_input_ids = []
batched_output_ids = []
for conv in batched_conv:
if combine:
new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
input_ids = new_input_ids
last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
output_prompt, output_ids = (
input_ids[:1],
input_ids[last_assistant_index:],
)
output_ids.append(151336)
batched_input_ids.append(
input_ids[:max_input_length] + output_prompt[:1]
)
batched_output_ids.append(output_ids[:max_output_length])
else:
input_ids = [151331, 151333]
for message in conv:
if len(input_ids) >= max_input_length:
break
else:
message = process_message(message)
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
if message['role'] == 'assistant':
output_prompt, output_ids = (
new_input_ids[:1],
new_input_ids[1:],
)
output_ids.append(151336)
batched_input_ids.append(
input_ids[:max_input_length] + output_prompt[:1]
)
batched_output_ids.append(output_ids[:max_output_length])
input_ids += new_input_ids
del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
torch.cuda.empty_cache()
return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
def load_tokenizer_and_model(
model_dir: str,
peft_config: Optional[PeftConfig] = None,
):
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
if peft_config is not None:
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
empty_init=False,
use_cache=False,
torch_dtype=torch.bfloat16 # Must use BFloat 16
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
empty_init=False,
use_cache=False,
torch_dtype=torch.bfloat16
)
return tokenizer, model
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
batched_pred_ids, batched_label_ids = eval_preds
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
pred_txt = tokenizer.decode(pred_ids).strip()
label_txt = tokenizer.decode(label_ids).strip()
pred_tokens = list(jieba.cut(pred_txt))
label_tokens = list(jieba.cut(label_txt))
rouge = Rouge()
scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
for k, v in scores[0].items():
metrics_dct[k].append(round(v['f'] * 100, 4))
metrics_dct['bleu-4'].append(
sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
return {k: np.mean(v) for k, v in metrics_dct.items()}
@app.command()
def main(
data_dir: Annotated[str, typer.Argument(help='')],
model_dir: Annotated[
str,
typer.Argument(
help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
),
],
config_file: Annotated[str, typer.Argument(help='')],
auto_resume_from_checkpoint: str = typer.Argument(
default='',
help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
),
):
ft_config = FinetuningConfig.from_file(config_file)
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
train_dataset = data_manager.get_dataset(
Split.TRAIN,
functools.partial(
process_batch,
tokenizer=tokenizer,
combine=ft_config.combine,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
print('train_dataset:', train_dataset)
val_dataset = data_manager.get_dataset(
Split.VALIDATION,
functools.partial(
process_batch_eval,
tokenizer=tokenizer,
combine=ft_config.combine,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
if val_dataset is not None:
print('val_dataset:', val_dataset)
test_dataset = data_manager.get_dataset(
Split.TEST,
functools.partial(
process_batch_eval,
tokenizer=tokenizer,
combine=ft_config.combine,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
if test_dataset is not None:
print('test_dataset:', test_dataset)
# model.gradient_checkpointing_enable()
model.enable_input_require_grads()
ft_config.training_args.generation_config.pad_token_id = (
151329
)
ft_config.training_args.generation_config.eos_token_id = [
151329, 151336, 151338
]
trainer = Seq2SeqTrainer(
model=model,
args=ft_config.training_args,
data_collator=DataCollatorForSeq2Seq(
tokenizer=tokenizer,
padding='longest',
return_tensors='pt',
),
train_dataset=train_dataset,
eval_dataset=val_dataset,
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
)
# trainer.train(resume_from_checkpoint="/home/zhangjunyi/hs_test/finetune_demo/output_1026/checkpoint-20")
if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
trainer.train()
else:
output_dir = ft_config.training_args.output_dir
dirlist = os.listdir(output_dir)
checkpoint_sn = 0
for checkpoint_str in dirlist:
if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
if checkpoint > checkpoint_sn:
checkpoint_sn = checkpoint
if auto_resume_from_checkpoint.upper() == "YES":
if checkpoint_sn > 0:
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
trainer.train(resume_from_checkpoint=checkpoint_directory)
else:
trainer.train()
else:
if auto_resume_from_checkpoint.isdigit():
if int(auto_resume_from_checkpoint) > 0:
checkpoint_sn = int(auto_resume_from_checkpoint)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
trainer.train(resume_from_checkpoint=checkpoint_directory)
else:
print(auto_resume_from_checkpoint,
"The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")
if test_dataset is not None:
trainer.predict(test_dataset)
if __name__ == '__main__':
app()
Expected behavior
I expect the training to resume from the last checkpoint, continuing from the saved state without reinitializing the model weights, optimizer states, or any training progress by using resume_from_checkpoint. When training with lora, I successfully saved several checkpoints and could resume training from these checkpoints without any issues. However, with xlora, although I am able to save checkpoints during training, I encounter issues when trying to resume training from these checkpoints. The model fails to load properly, preventing the continuation of training from the saved state. Here is bug information:
Loading checkpoint shards: 0%| | 0/10 [00:00<?, ?it/s]
Loading checkpoint shards: 10%|█ | 1/10 [00:00<00:01, 6.86it/s]
Loading checkpoint shards: 20%|██ | 2/10 [00:00<00:01, 6.89it/s]
Loading checkpoint shards: 30%|███ | 3/10 [00:00<00:01, 6.90it/s]
Loading checkpoint shards: 40%|████ | 4/10 [00:00<00:00, 6.43it/s]
Loading checkpoint shards: 50%|█████ | 5/10 [00:00<00:00, 6.61it/s]
Loading checkpoint shards: 60%|██████ | 6/10 [00:00<00:00, 6.72it/s]
Loading checkpoint shards: 70%|███████ | 7/10 [00:01<00:00, 6.80it/s]
Loading checkpoint shards: 80%|████████ | 8/10 [00:01<00:00, 6.85it/s]
Loading checkpoint shards: 90%|█████████ | 9/10 [00:01<00:00, 6.88it/s]
Loading checkpoint shards: 100%|██████████| 10/10 [00:01<00:00, 6.96it/s]
Loading checkpoint shards: 100%|██████████| 10/10 [00:01<00:00, 6.82it/s]
0%| | 0/2 [00:00<?, ?it/s]
50%|█████ | 1/2 [00:06<00:06, 6.33s/it]
100%|██████████| 2/2 [00:06<00:00, 3.22s/it]
Froze 160 adapters.
LoRA -> xLoRA complete: Swapped 40 LoRA layers (out of 971 modules).
trainable params: 67,145,732 || all params: 9,472,667,652 || trainable%: 0.7088
Map: 0%| | 0/14803 [00:00<?, ? examples/s]
Map: 7%|▋ | 1000/14803 [00:03<00:42, 327.62 examples/s]
Map: 14%|█▎ | 2000/14803 [00:05<00:36, 347.77 examples/s]
Map: 20%|██ | 3000/14803 [00:08<00:33, 356.42 examples/s]
Map: 27%|██▋ | 4000/14803 [00:11<00:29, 361.64 examples/s]
Map: 34%|███▍ | 5000/14803 [00:13<00:26, 363.21 examples/s]
Map: 41%|████ | 6000/14803 [00:16<00:24, 363.04 examples/s]
Map: 47%|████▋ | 7000/14803 [00:19<00:21, 363.56 examples/s]
Map: 54%|█████▍ | 8000/14803 [00:21<00:16, 413.31 examples/s]
Map: 61%|██████ | 9000/14803 [00:22<00:11, 504.23 examples/s]
Map: 68%|██████▊ | 10000/14803 [00:23<00:08, 597.65 examples/s]
Map: 74%|███████▍ | 11000/14803 [00:24<00:05, 681.07 examples/s]
Map: 81%|████████ | 12000/14803 [00:25<00:03, 754.44 examples/s]
Map: 88%|████████▊ | 13000/14803 [00:26<00:02, 809.81 examples/s]
Map: 95%|█████████▍| 14000/14803 [00:27<00:00, 856.43 examples/s]
Map: 100%|██████████| 14803/14803 [00:28<00:00, 890.57 examples/s]
Map: 100%|██████████| 14803/14803 [00:28<00:00, 528.34 examples/s]
train_dataset: Dataset({
features: ['input_ids', 'labels'],
num_rows: 14803
})
Map: 0%| | 0/2 [00:00<?, ? examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00, 187.78 examples/s]
val_dataset: Dataset({
features: ['input_ids', 'output_ids'],
num_rows: 2
})
Map: 0%| | 0/2 [00:00<?, ? examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00, 189.77 examples/s]
test_dataset: Dataset({
features: ['input_ids', 'output_ids'],
num_rows: 2
})
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs
resume checkpoint from checkpoint-20
Loading model from ./output_new/checkpoint-20.
Multiple active adapters detected will only consider the first adapter
[2024-10-15 18:47:30,968] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/transformers/trainer.py:3098: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
[rank0]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank0]: │ /home/zhangjunyi/hs_test/finetune_demo/finetune.py:615 in main │
[rank0]: │ │
[rank0]: │ 612 │ │ │ │ model.enable_input_require_grads() │
[rank0]: │ 613 │ │ │ │ checkpoint_directory = os.path.join(output_dir, "check │
[rank0]: │ 614 │ │ │ │ print("resume checkpoint from checkpoint-" + str(check │
[rank0]: │ ❱ 615 │ │ │ │ trainer.train(resume_from_checkpoint=checkpoint_direct │
[rank0]: │ 616 │ │ │ else: │
[rank0]: │ 617 │ │ │ │ trainer.train() │
[rank0]: │ 618 │ │ else: │
[rank0]: │ │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/transformers/trainer │
[rank0]: │ .py:1938 in train │
[rank0]: │ │
[rank0]: │ 1935 │ │ │ finally: │
[rank0]: │ 1936 │ │ │ │ hf_hub_utils.enable_progress_bars() │
[rank0]: │ 1937 │ │ else: │
[rank0]: │ ❱ 1938 │ │ │ return inner_training_loop( │
[rank0]: │ 1939 │ │ │ │ args=args, │
[rank0]: │ 1940 │ │ │ │ resume_from_checkpoint=resume_from_checkpoint, │
[rank0]: │ 1941 │ │ │ │ trial=trial, │
[rank0]: │ │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/transformers/trainer │
[rank0]: │ .py:2126 in _inner_training_loop │
[rank0]: │ │
[rank0]: │ 2123 │ │ │ │ self._load_from_checkpoint(resume_from_checkpoint, se │
[rank0]: │ 2124 │ │ │
[rank0]: │ 2125 │ │ # Check if saved optimizer or scheduler states exist │
[rank0]: │ ❱ 2126 │ │ self._load_optimizer_and_scheduler(resume_from_checkpoint) │
[rank0]: │ 2127 │ │ │
[rank0]: │ 2128 │ │ # important: at this point: │
[rank0]: │ 2129 │ │ # self.model is the Transformers Model │
[rank0]: │ │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/transformers/trainer │
[rank0]: │ .py:3097 in _load_optimizer_and_scheduler │
[rank0]: │ │
[rank0]: │ 3094 │ │ │ │ │ │ │ **_get_fsdp_ckpt_kwargs(), │
[rank0]: │ 3095 │ │ │ │ │ │ ) │
[rank0]: │ 3096 │ │ │ │ │ else: │
[rank0]: │ ❱ 3097 │ │ │ │ │ │ self.optimizer.load_state_dict( │
[rank0]: │ 3098 │ │ │ │ │ │ │ torch.load(os.path.join(checkpoint, OPTIM │
[rank0]: │ 3099 │ │ │ │ │ │ ) │
[rank0]: │ 3100 │ │ │ │ with warnings.catch_warnings(record=True) as caught_w │
[rank0]: │ │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/accelerate/optimizer │
[rank0]: │ .py:107 in load_state_dict │
[rank0]: │ │
[rank0]: │ 104 │ def load_state_dict(self, state_dict): │
[rank0]: │ 105 │ │ if self.accelerator_state.distributed_type == DistributedType. │
[rank0]: │ 106 │ │ │ xm.send_cpu_data_to_device(state_dict, self.accelerator_st │
[rank0]: │ ❱ 107 │ │ self.optimizer.load_state_dict(state_dict) │
[rank0]: │ 108 │ │
[rank0]: │ 109 │ def state_dict(self): │
[rank0]: │ 110 │ │ return self.optimizer.state_dict() │
[rank0]: │ │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/_compile.py:31 │
[rank0]: │ in inner │
[rank0]: │ │
[rank0]: │ 28 │ │ │ │ disable_fn = torch._dynamo.disable(fn, recursive) │
[rank0]: │ 29 │ │ │ │ fn.__dynamo_disable = disable_fn │
[rank0]: │ 30 │ │ │ │
[rank0]: │ ❱ 31 │ │ │ return disable_fn(*args, **kwargs) │
[rank0]: │ 32 │ │ │
[rank0]: │ 33 │ │ return inner │
[rank0]: │ 34 │ else: │
[rank0]: │ │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/_dynamo/eval_f │
[rank0]: │ rame.py:600 in _fn │
[rank0]: │ │
[rank0]: │ 597 │ │ def _fn(*args, **kwargs): │
[rank0]: │ 598 │ │ │ prior = set_eval_frame(callback) │
[rank0]: │ 599 │ │ │ try: │
[rank0]: │ ❱ 600 │ │ │ │ return fn(*args, **kwargs) │
[rank0]: │ 601 │ │ │ finally: │
[rank0]: │ 602 │ │ │ │ set_eval_frame(prior) │
[rank0]: │ 603 │
[rank0]: │ │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/optim/optimize │
[rank0]: │ r.py:854 in load_state_dict │
[rank0]: │ │
[rank0]: │ 851 │ │ param_lens = (len(g["params"]) for g in groups) │
[rank0]: │ 852 │ │ saved_lens = (len(g["params"]) for g in saved_groups) │
[rank0]: │ 853 │ │ if any(p_len != s_len for p_len, s_len in zip(param_lens, sav │
[rank0]: │ ❱ 854 │ │ │ raise ValueError( │
[rank0]: │ 855 │ │ │ │ "loaded state dict contains a parameter group " │
[rank0]: │ 856 │ │ │ │ "that doesn't match the size of optimizer's group" │
[rank0]: │ 857 │ │ │ ) │
[rank0]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank0]: ValueError: loaded state dict contains a parameter group that doesn't match the
[rank0]: size of optimizer's group
E1015 18:47:35.719000 139827737793152 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 4075482) of binary: /home/zhangjunyi/anaconda3/bin/python
Traceback (most recent call last):
File "/home/zhangjunyi/anaconda3/bin/torchrun", line 8, in <module>
sys.exit(main())
^^^^^^
File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/run.py", line 901, in main
run(args)
File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/run.py", line 892, in run
elastic_launch(
File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
finetune.py FAILED
I would greatly appreciate any guidance on resolving this issue with xlora checkpoint restoration. If anyone has encountered a similar problem or has insights into specific settings or steps to enable successful checkpoint recovery for xlora, your advice would be invaluable. Additionally, if any maintainers or community members familiar with xlora could offer support, that would be extremely helpful. Many thanks