open-mmlab/mmengine

ValueError: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly.

apachemycat opened this issue · 0 comments

Prerequisite

Environment

PyTorch 2.3: if `use_reentrant` is not passed explicitly, an exception is now raised.

Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py", line 360, in <module>
main()
File "/usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py", line 356, in main
runner.train()
File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/runner.py", line 1777, in train
model = self.train_loop.run() # type: ignore
File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 287, in run
self.run_iter(data_batch)
File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 311, in run_iter
outputs = self.runner.model.train_step(
File "/usr/local/lib/python3.10/dist-packages/mmengine/model/wrappers/distributed.py", line 121, in train_step
losses = self._run_forward(data, mode='loss')
File "/usr/local/lib/python3.10/dist-packages/mmengine/model/wrappers/distributed.py", line 161, in _run_forward
results = self(**data, mode=mode)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward
else self._run_ddp_forward(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 228, in forward
return self.compute_loss(data, data_samples)
File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 272, in compute_loss
return self._compute_sequence_parallel_loss(data)
File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 262, in _compute_sequence_parallel_loss
outputs = self.llm(**data)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 1395, in forward
return self.base_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 179, in forward
return self.model.forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1205, in forward
outputs = self.model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 989, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 460, in checkpoint
raise ValueError(
ValueError: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.

Reproduces the problem - code sample

tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
padding_side='right')

model = dict(
type=SupervisedFinetune,
use_varlen_attn=use_varlen_attn,
llm=dict(
type=AutoModelForCausalLM.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
torch_dtype=torch.float16,
quantization_config=dict(
type=BitsAndBytesConfig,
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4')),
lora=dict(
type=LoraConfig,
r=64,
lora_alpha=16,
lora_dropout=0.1,
bias='none',
task_type='CAUSAL_LM'))

training_args = dict(
type=TrainingArguments,
gradient_checkpointing=True, # Reduces memory usage at a slight cost in speed
gradient_checkpointing_kwargs={"use_reentrant": False}
)
#######################################################################

# PART 3 Dataset & Dataloader

#######################################################################
alpaca_en = dict(
type=process_hf_dataset,
dataset=dict(type=load_dataset, path='json',data_files=data_files),
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=alpaca_map_fn,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
remove_unused_columns=True,
shuffle_before_pack=True,
pack_to_max_length=pack_to_max_length,
use_varlen_attn=use_varlen_attn)

sampler = SequenceParallelSampler
if sequence_parallel_size > 1 else DefaultSampler
train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=alpaca_en,
sampler=dict(type=sampler, shuffle=True),
collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################

# PART 4 Scheduler & Optimizer

#######################################################################

# optimizer

optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
dtype='float16')

# learning policy

# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501

auto_scale_lr=dict(base_batch_size=4, enable=True)
param_scheduler = [
dict(
type=LinearLR,
start_factor=1e-5,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type=CosineAnnealingLR,
eta_min=0.0,
by_epoch=True,
convert_to_iter_based=True)
]

# train, val, test setting

train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################

# PART 5 Runtime

#######################################################################

# Log the dialogue periodically during the training process, optional

# custom_hooks = [
#     dict(type=DatasetInfoHook, tokenizer=tokenizer),
#     dict(
#         type=EvaluateChatHook,
#         tokenizer=tokenizer,
#         every_n_iters=evaluation_freq,
#         evaluation_inputs=evaluation_inputs,
#         system=SYSTEM,
#         prompt_template=prompt_template)
# ]

if use_varlen_attn:
custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks

default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 10 iterations.
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per save_steps.
checkpoint=dict(
type=CheckpointHook,
by_epoch=False,
interval=save_steps,
max_keep_ckpts=save_total_limit),
# set sampler seed in distributed evrionment.
sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment

env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)

# set visualizer

visualizer = dict(type='Visualizer', vis_backends=[dict(type='TensorboardVisBackend')])

# set log level

log_level = 'DEBUG'

# load from which checkpoint

load_from = None

# whether to resume training from the loaded checkpoint

resume = True

# Default to a random seed and disable deterministic mode

randomness = dict(seed=None, deterministic=False)

# set log processor

log_processor = dict(by_epoch=False)

Reproduces the problem - command or script

NCCL_DEBUG=INFO LOGLEVEL=DEBUG NPROC_PER_NODE=1 torchrun /usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py --launcher pytorch --work-dir ${work_dir} config.py

Reproduces the problem - error message

Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py", line 360, in <module>
main()
File "/usr/local/lib/python3.10/dist-packages/xtuner/tools/train.py", line 356, in main
runner.train()
File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/runner.py", line 1777, in train
model = self.train_loop.run() # type: ignore
File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 287, in run
self.run_iter(data_batch)
File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 311, in run_iter
outputs = self.runner.model.train_step(
File "/usr/local/lib/python3.10/dist-packages/mmengine/model/wrappers/distributed.py", line 121, in train_step
losses = self._run_forward(data, mode='loss')
File "/usr/local/lib/python3.10/dist-packages/mmengine/model/wrappers/distributed.py", line 161, in _run_forward
results = self(**data, mode=mode)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward
else self._run_ddp_forward(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 228, in forward
return self.compute_loss(data, data_samples)
File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 272, in compute_loss
return self._compute_sequence_parallel_loss(data)
File "/usr/local/lib/python3.10/dist-packages/xtuner/model/sft.py", line 262, in _compute_sequence_parallel_loss
outputs = self.llm(**data)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 1395, in forward
return self.base_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 179, in forward
return self.model.forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1205, in forward
outputs = self.model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 989, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 460, in checkpoint
raise ValueError(
ValueError: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.

Additional information

no error