Memory problem with LISA fine-tuning
lovekdl opened this issue · 6 comments
I tried fine-tuning the llama-2-7b model with LoRA on an RTX 3090 (24 GB), where memory usage was only about 17 GB. However, when I used the same configuration on an A100 (80 GB), memory usage soared to over 70 GB. Is this normal, and how can I reduce memory consumption on the A100 80 GB GPU?
I encountered the same issue when fine-tuning with LISA: memory consumption on the A100 (80 GB) was significantly higher than on the RTX 3090 (24 GB).
Config
model_name_or_path=meta-llama/Llama-2-7b-hf
dataset_path=data/alpaca-gpt4
output_dir=output_models/finetuned_llama_2_7b_lora_128_batch1
exp_id=finetuned_llama_2_7b_lora_128_batch1
project_dir=$(cd "$(dirname $0)"/..; pwd)
log_dir=${project_dir}/log/${exp_id}
mkdir -p ${output_dir} ${log_dir}
use_flash_attention=0
deepspeed examples/finetune.py \
  --model_name_or_path ${model_name_or_path} \
  --dataset_path ${dataset_path} \
  --output_dir ${output_dir} --overwrite_output_dir \
  --num_train_epochs 1 \
  --learning_rate 5e-5 \
  --block_size 512 \
  --per_device_train_batch_size 1 \
  --use_lora 1 \
  --deepspeed configs/ds_config_zero2.json \
  --lora_r 128 \
  --save_aggregated_lora 1 \
  --fp16 \
  --run_name ${exp_id} \
  --validation_split_percentage 0 \
  --logging_steps 1 \
  --do_train \
  --use_flash_attention ${use_flash_attention} \
  --ddp_timeout 72000 \
  --save_steps 500000 \
  --dataloader_num_workers 1 \
  | tee ${log_dir}/train.log \
  2> ${log_dir}/train.err
GPU info
Thanks for your interest in LMFlow! I just tested the LISA script on GPUs with 48 GB of memory, and the memory consumption looks normal. We suspect the memory spike you mention is caused by DeepSpeed, which normally pre-allocates memory before training. You may try the original script.
If the problem persists, you can locate its cause by first turning off DeepSpeed offload (--deepspeed configs/ds_config_zero2_no_offload.json) and then ZeRO-2 (--deepspeed configs/ds_config_zero0_no_offload.json) to see which mechanism causes the issue.
Hope this information can be helpful 😄
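When narrowing down which mechanism holds the memory, as suggested above, it can help to log PyTorch's own counters next to what nvidia-smi reports. The following is a minimal sketch, assuming training runs in a PyTorch process on a CUDA device; `cuda_memory_snapshot` is a hypothetical helper name, not part of LMFlow or DeepSpeed.

```python
import torch

GIB = 1024 ** 3


def cuda_memory_snapshot(device=None):
    """Hypothetical helper: return CUDA memory counters in GiB, or None without a GPU.

    allocated: memory occupied by live tensors
    reserved:  memory held by PyTorch's caching allocator (live tensors + cached blocks)
    peak:      highest 'allocated' value since program start or the last peak reset
    """
    if not torch.cuda.is_available():
        return None
    return {
        "allocated_gib": torch.cuda.memory_allocated(device) / GIB,
        "reserved_gib": torch.cuda.memory_reserved(device) / GIB,
        "peak_allocated_gib": torch.cuda.max_memory_allocated(device) / GIB,
    }


snapshot = cuda_memory_snapshot()
if snapshot is None:
    print("CUDA not available; nothing to report")
else:
    print(", ".join(f"{name}={value:.2f}" for name, value in snapshot.items()))
```

Calling it every few hundred steps (for example from a TrainerCallback's on_step_end) separates two cases: if reserved memory grows while allocated memory stays flat, the growth is blocks cached by PyTorch's allocator; if allocated memory itself keeps climbing, live tensors are accumulating. nvidia-smi additionally counts the CUDA context and any memory not managed by PyTorch's allocator.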
@research4pan I now use DeepSpeed + LoRA to fine-tune llama-2-7b, and memory consumption is normal.
But when I use LISA without DeepSpeed, I still have the memory problem: GPU memory consumption increases slowly and reaches 65 GB after 3000 steps.
script:
model_name_or_path=meta-llama/Llama-2-7b-hf
dataset_path=data/alpaca-gpt4
output_dir=output_models/finetune_lisa
lisa_activated_layers=1
lisa_interval_steps=20
gradient_checkpointing=True
use_flash_attention=0
gradient_accumulation_steps=1
block_size=256
per_device_train_batch_size=1
num_gpu=$(python -c "import torch; print(torch.cuda.device_count())")
ds_config_file=configs/ds_config_zero0_no_offload.json
if [ ${num_gpu} -ge 2 ]; then
  ds_config_file=configs/ds_config_zero2_no_offload.json
fi
while [[ $# -ge 1 ]]; do
  key="$1"
  case ${key} in
    -m|--model_name_or_path)
      model_name_or_path="$2"
      shift
      ;;
    -d|--dataset_path)
      dataset_path="$2"
      shift
      ;;
    -o|--output_model_path)
      output_dir="$2"
      shift
      ;;
    --lisa_activated_layers)
      lisa_activated_layers="$2"
      shift
      ;;
    --lisa_interval_steps)
      lisa_interval_steps="$2"
      shift
      ;;
    --gradient_checkpointing)
      gradient_checkpointing="$2"
      shift
      ;;
    --deepspeed)
      ds_config_file="$2"
      shift
      ;;
    --use_flash_attention)
      use_flash_attention="$2"
      shift
      ;;
    --gradient_accumulation_steps)
      gradient_accumulation_steps="$2"
      shift
      ;;
    --block_size)
      block_size="$2"
      shift
      ;;
    --per_device_train_batch_size|--batch_size)
      per_device_train_batch_size="$2"
      shift
      ;;
    *)
      echo "error: unknown option \"${key}\"" 1>&2
      exit 1
  esac
  shift
done
exp_id=finetune
project_dir=$(cd "$(dirname $0)"/.; pwd)
log_dir=${project_dir}/log/${exp_id}
mkdir -p ${output_dir} ${log_dir}
python examples/finetune.py \
  --model_name_or_path ${model_name_or_path} \
  --dataset_path ${dataset_path} \
  --output_dir ${output_dir} --overwrite_output_dir \
  --num_train_epochs 1 \
  --learning_rate 5e-5 \
  --disable_group_texts 1 \
  --block_size ${block_size} \
  --per_device_train_batch_size ${per_device_train_batch_size} \
  --bf16 \
  --torch_dtype bfloat16 \
  --run_name finetune \
  --optim paged_adamw_32bit \
  --validation_split_percentage 0 \
  --logging_steps 5 \
  --do_train \
  --ddp_timeout 72000 \
  --save_steps 500000 \
  --dataloader_num_workers 1 \
  --gradient_checkpointing ${gradient_checkpointing} \
  --use_flash_attention ${use_flash_attention} \
  --gradient_accumulation_steps ${gradient_accumulation_steps} \
  --use_lisa 1 \
  --lisa_activated_layers ${lisa_activated_layers} \
  --lisa_interval_steps ${lisa_interval_steps} \
  | tee ${log_dir}/train.log \
  2> ${log_dir}/train.err
I think there might be a problem with the optimizer.
If I call self.freeze_all_layers() in the __init__() of class DynamicLayerActivationCallback(TrainerCallback), memory consumption stays at a normal 17 GB.
It seems that each time new layers are activated, memory consumption increases, while re-activating layers that were already activated before does not increase it.
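One possible mechanism behind this pattern (an interpretation, not confirmed in this thread; the maintainers attribute the issue to DeepSpeed) is that PyTorch optimizers allocate state lazily: AdamW creates its exp_avg and exp_avg_sq buffers for a parameter the first time that parameter has a gradient at step(), and keeps them afterwards. If one optimizer is built over all parameters and different layers are unfrozen over time, state accumulates for every layer that has ever been active. The toy model below illustrates this on CPU; the layer sizes and `activate_only` are hypothetical, and it uses torch.optim.AdamW rather than the paged_adamw_32bit optimizer from the script, which is assumed to behave similarly.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-in for a stack of decoder layers (hypothetical sizes).
model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(4)])
# A single optimizer over all parameters, created once before training.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)


def activate_only(layer_idx):
    """Freeze every layer except `layer_idx` (simplified LISA-style switching)."""
    for i, layer in enumerate(model):
        for p in layer.parameters():
            p.requires_grad_(i == layer_idx)


def moment_bytes(opt):
    """Bytes held by AdamW's exp_avg / exp_avg_sq buffers."""
    total = 0
    for state in opt.state.values():
        for key in ("exp_avg", "exp_avg_sq"):
            if key in state:
                total += state[key].numel() * state[key].element_size()
    return total


x = torch.randn(2, 8)
history = []
for idx in [0, 1, 0, 2]:  # layer 0 is activated twice
    activate_only(idx)
    optimizer.zero_grad(set_to_none=True)  # frozen params keep .grad = None
    model(x).sum().backward()
    optimizer.step()  # params whose .grad is None get no (new) state
    history.append((idx, len(optimizer.state), moment_bytes(optimizer)))

for idx, n_params, nbytes in history:
    print(f"active layer {idx}: {n_params} params with state, {nbytes} bytes of moments")
```

The number of parameters with optimizer state grows when layers 1 and 2 are activated for the first time, but not when layer 0 is activated again, which matches the behavior described above.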
This seems to be a problem related to DeepSpeed. We are currently implementing a model-parallelism version that reinitializes the optimizer state every time, which should solve this issue as well. Please stay tuned for our latest updates 😄
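Reinitializing optimizer state, as mentioned above, can be approximated in a plain single-process PyTorch setup by discarding the state of parameters that have just been frozen whenever the active layers change. This is a hedged sketch on a toy model, not LMFlow's implementation: `drop_state_of_frozen_params` and `activate_only` are hypothetical helpers, it assumes a standard torch.optim optimizer, and it would not apply directly under DeepSpeed ZeRO, which wraps and partitions optimizer state.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-in for a stack of decoder layers (hypothetical sizes).
model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(4)])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)


def activate_only(layer_idx):
    """Freeze every layer except `layer_idx` (simplified LISA-style switching)."""
    for i, layer in enumerate(model):
        for p in layer.parameters():
            p.requires_grad_(i == layer_idx)


def drop_state_of_frozen_params(opt):
    """Discard optimizer state for parameters that no longer require grad."""
    for group in opt.param_groups:
        for p in group["params"]:
            if not p.requires_grad:
                opt.state.pop(p, None)  # opt.state is keyed by parameter tensor


x = torch.randn(2, 8)
states_per_step = []
for idx in [0, 1, 0, 2]:
    activate_only(idx)
    drop_state_of_frozen_params(optimizer)  # called right after each layer switch
    optimizer.zero_grad(set_to_none=True)
    model(x).sum().backward()
    optimizer.step()
    states_per_step.append(len(optimizer.state))

print(states_per_step)  # [2, 2, 2, 2]
```

The trade-off is that a re-activated layer restarts with zeroed Adam moments and a reset step counter. On a GPU, the freed buffers go back to PyTorch's caching allocator, so nvidia-smi may not show a drop unless torch.cuda.empty_cache() is called.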
I also observed this phenomenon (memory increasing each time new layers are activated, but not when previously activated layers are activated again) on an A6000 GPU. My configuration is listed below; I am using the latest version of LMFlow as of 2024/10/16.
Package Version Editable project location
absl-py 2.1.0
accelerate 1.0.1
aiofiles 23.2.1
aiohappyeyeballs 2.4.3
aiohttp 3.10.10
aiosignal 1.3.1
annotated-types 0.7.0
anyio 4.6.2.post1
appdirs 1.4.4
async-timeout 4.0.3
attrs 24.2.0
bitsandbytes 0.44.1
blinker 1.8.2
certifi 2024.8.30
chardet 5.2.0
charset-normalizer 3.4.0
click 8.1.7
cloudpickle 3.1.0
colorama 0.4.6
contourpy 1.3.0
cpm-kernels 1.0.11
cycler 0.12.1
DataProperty 1.0.1
datasets 2.14.6
deepspeed 0.15.2
dill 0.3.4
diskcache 5.6.3
distro 1.9.0
docker-pycreds 0.4.0
docstring_parser 0.16
einops 0.8.0
eval_type_backport 0.2.0
evaluate 0.4.0
exceptiongroup 1.2.2
fastapi 0.115.2
ffmpy 0.4.0
filelock 3.16.1
flash-attn 2.6.3
Flask 3.0.3
Flask-Cors 5.0.0
fonttools 4.54.1
frozenlist 1.4.1
fsspec 2023.10.0
gguf 0.10.0
gitdb 4.0.11
GitPython 3.1.43
gradio 4.44.1
gradio_client 1.3.0
h11 0.14.0
hjson 3.1.0
httpcore 1.0.6
httptools 0.6.2
httpx 0.27.2
huggingface-hub 0.25.2
icetk 0.0.7
idna 3.10
importlib_metadata 8.5.0
importlib_resources 6.4.5
interegular 0.3.3
itsdangerous 2.2.0
Jinja2 3.1.4
jiter 0.6.1
joblib 1.4.2
jsonlines 4.0.0
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
kiwisolver 1.4.7
lark 1.2.2
llvmlite 0.43.0
lm-eval 0.3.0
lm-format-enforcer 0.10.6
lmflow 0.0.7 /home/ubuntu/data/lisa-main/LMFlow/src
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.2
mbstrdecoder 1.1.3
mdurl 0.1.2
mistral_common 1.4.4
mpi4py 4.0.1
mpmath 1.3.0
msgpack 1.1.0
msgspec 0.18.6
multidict 6.1.0
multiprocess 0.70.12.2
nest-asyncio 1.6.0
networkx 3.2.1
ninja 1.11.1.1
nltk 3.9.1
numba 0.60.0
numexpr 2.10.1
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-ml-py 12.560.30
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.6.77
nvidia-nvtx-cu12 12.1.105
openai 1.51.2
opencv-python-headless 4.10.0.84
orjson 3.10.7
outlines 0.0.46
packaging 24.1
pandas 2.2.3
partial-json-parser 0.2.1.1.post4
pathtools 0.1.2
pathvalidate 3.2.1
peft 0.13.2
pillow 10.4.0
pip 24.2
portalocker 2.10.1
prometheus_client 0.21.0
prometheus-fastapi-instrumentator 7.0.0
propcache 0.2.0
protobuf 3.18.3
psutil 6.0.0
py-cpuinfo 9.0.0
pyairports 2.1.1
pyarrow 17.0.0
pybind11 2.13.6
pycountry 24.6.1
pydantic 2.9.2
pydantic_core 2.23.4
pydub 0.25.1
Pygments 2.18.0
pyparsing 3.2.0
pytablewriter 1.2.0
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-multipart 0.0.12
pytz 2024.2
PyYAML 6.0.2
pyzmq 26.2.0
ray 2.37.0
referencing 0.35.1
regex 2024.9.11
requests 2.32.3
responses 0.18.0
rich 13.9.2
rouge-score 0.1.2
rpds-py 0.20.0
ruff 0.6.9
sacrebleu 1.5.0
safetensors 0.4.5
scikit-learn 1.2.2
scipy 1.13.1
semantic-version 2.10.0
sentencepiece 0.2.0
sentry-sdk 2.16.0
setproctitle 1.3.3
setuptools 68.2.0
shellingham 1.5.4
shtab 1.7.1
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
sqlitedict 2.1.0
starlette 0.40.0
sympy 1.13.3
tabledata 1.3.3
tcolorpy 0.1.6
threadpoolctl 3.5.0
tiktoken 0.7.0
tokenizers 0.20.1
tomlkit 0.12.0
torch 2.4.0
torchvision 0.19.0
tqdm 4.66.5
tqdm-multiprocess 0.0.11
transformers 4.45.2
triton 3.0.0
trl 0.8.0
typepy 1.3.2
typer 0.12.5
typing_extensions 4.12.2
tyro 0.8.12
tzdata 2024.2
urllib3 2.2.3
uvicorn 0.32.0
uvloop 0.21.0
vllm 0.6.3
wandb 0.14.0
watchfiles 0.24.0
websockets 12.0
Werkzeug 3.0.4
wheel 0.41.2
xformers 0.0.27.post2
xxhash 3.5.0
yarl 1.15.3
zipp 3.20.2
zstandard 0.23.0