microsoft/DeepSpeed

DeepSpeed with trl

sagie-dekel opened this issue · 7 comments

Describe the bug
I am trying to train meta-llama/Llama-3.1-8B-Instruct with the trl DPOTrainer.
After creating the trainer and starting the training loop, I get the following error in the forward pass of the reference model (compute_ref_log_probs):

[rank0]: Traceback (most recent call last):
[rank0]:   File "/rg/kurland_prj/sagie.dekel/RLRF/base_run/optimize_HP_for_RLRF.py", line 771, in <module>
[rank0]:     main()
[rank0]:   File "/rg/kurland_prj/sagie.dekel/RLRF/base_run/optimize_HP_for_RLRF.py", line 759, in main
[rank0]:     RLRF_Pipeline(
[rank0]:   File "/rg/kurland_prj/sagie.dekel/RLRF/base_run/optimize_HP_for_RLRF.py", line 345, in RLRF_Pipeline
[rank0]:     RLRF_manager.RLRF_DPO(model_save_path=RLRF_model_save_path)
[rank0]:   File "/rg/kurland_prj/sagie.dekel/RLRF/base_run/RLRF_main.py", line 367, in RLRF_DPO
[rank0]:     self.FTTrainer.train()
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/trainer.py", line 2123, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/trainer.py", line 3579, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1371, in compute_loss
[rank0]:     loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1330, in get_batch_loss_metrics
[rank0]:     ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 950, in compute_ref_log_probs
[rank0]:     ref_model_output = self.concatenated_forward(self.ref_model, batch)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1260, in concatenated_forward
[rank0]:     outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/engine.py", line 1909, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1190, in forward
[rank0]:     outputs = self.model(
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 891, in forward
[rank0]:     inputs_embeds = self.embed_tokens(input_ids)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1779, in inner
[rank0]:     args_result = hook(self, args)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/zero/parameter_offload.py", line 285, in _pre_forward_module_hook
[rank0]:     self.pre_sub_module_forward_function(module)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/zero/parameter_offload.py", line 460, in pre_sub_module_forward_function
[rank0]:     param_coordinator.fetch_sub_module(sub_module, forward=True)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 331, in fetch_sub_module
[rank0]:     assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
[rank0]: AssertionError: {'id': 682, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 0, 'shape': (0,), 'ds_shape': (0,), 'requires_grad': False, 'grad_shape': None, 'persist': True, 'active_sub_modules': {2}, 'ds_tensor.shape': torch.Size([0])}
  0%|          | 0/168 [00:03<?, ?it/s]

I tried downgrading transformers, but the error persists.

System info (please complete the following information):

  • OS: Linux Rocky 8 (with Portable Batch System)
  • GPU count and types: one machine with 3 NVIDIA A100 GPUs.
  • Python version: 3.10.15
  • Transformers: 4.46.3
  • Torch: 2.5.1+cu118
  • CUDA (python -c 'import torch; print(torch.version.cuda)'): 11.8
  • Package versions:
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
_sysroot_linux-64_curr_repodata_hack 3                   haa98f57_10
absl-py                   2.1.0                    pypi_0    pypi
accelerate                1.2.0                    pypi_0    pypi
aiohappyeyeballs          2.4.4                    pypi_0    pypi
aiohttp                   3.11.10                  pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
annotated-types           0.6.0           py310h06a4308_0
async-timeout             5.0.1                    pypi_0    pypi
attrs                     24.2.0                   pypi_0    pypi
binutils_impl_linux-64    2.40                 h5293946_0
blas                      1.0                         mkl
brotli-python             1.0.9           py310h6a678d5_8
bzip2                     1.0.8                h5eee18b_6
ca-certificates           2024.11.26           h06a4308_0
certifi                   2024.8.30       py310h06a4308_0
charset-normalizer        3.3.2              pyhd3eb1b0_0
compilers                 0.0.0                    pypi_0    pypi
cuda                      11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-cccl                 11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-command-line-tools   11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-compiler             11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-cudart               11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-cudart-dev           11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-cuobjdump            11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-cupti                11.8.87                       0    nvidia/label/cuda-11.8.0
cuda-cuxxfilt             11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-demo-suite           11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-documentation        11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-driver-dev           11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-gdb                  11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-libraries            11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-libraries-dev        11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-memcheck             11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nsight               11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nsight-compute       11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-nvcc                 11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-nvdisasm             11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nvml-dev             11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nvprof               11.8.87                       0    nvidia/label/cuda-11.8.0
cuda-nvprune              11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nvrtc                11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-nvrtc-dev            11.8.89                       0    nvidia/label/cuda-11.8.0
cuda-nvtx                 11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-nvvp                 11.8.87                       0    nvidia/label/cuda-11.8.0
cuda-profiler-api         11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-runtime              11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-sanitizer-api        11.8.86                       0    nvidia/label/cuda-11.8.0
cuda-toolkit              11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-tools                11.8.0                        0    nvidia/label/cuda-11.8.0
cuda-visual-tools         11.8.0                        0    nvidia/label/cuda-11.8.0
datasets                  3.1.0                    pypi_0    pypi
deepspeed                 0.16.2+9ca60160          pypi_0    pypi
dill                      0.3.8                    pypi_0    pypi
einops                    0.8.0                    pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.16.1                   pypi_0    pypi
freetype                  2.12.1               h4a9f257_0
frozenlist                1.5.0                    pypi_0    pypi
fsspec                    2024.9.0                 pypi_0    pypi
gcc                       11.4.0              h602e360_13    conda-forge
gcc_impl_linux-64         11.4.0              h00c12a0_13    conda-forge
gds-tools                 1.4.0.31                      0    nvidia/label/cuda-11.8.0
giflib                    5.2.2                h5eee18b_0
gmp                       6.2.1                h295c915_3
gmpy2                     2.1.2           py310heeb90bb_0
gnutls                    3.6.15               he1e5248_0
grpcio                    1.68.1                   pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
huggingface-hub           0.26.5                   pypi_0    pypi
idna                      3.7             py310h06a4308_0
intel-openmp              2023.1.0         hdb19cb5_46306
jinja2                    3.1.4           py310h06a4308_1
jpeg                      9e                   h5eee18b_3
kernel-headers_linux-64   3.10.0              h57e8cba_10
lame                      3.100                h7b6447c_0
lcms2                     2.12                 h3be6417_0
ld_impl_linux-64          2.40                 h12ee557_0
lerc                      3.0                  h295c915_0
libaio                    0.3.113              h5eee18b_0
libcublas                 11.11.3.6                     0    nvidia/label/cuda-11.8.0
libcublas-dev             11.11.3.6                     0    nvidia/label/cuda-11.8.0
libcufft                  10.9.0.58                     0    nvidia/label/cuda-11.8.0
libcufft-dev              10.9.0.58                     0    nvidia/label/cuda-11.8.0
libcufile                 1.4.0.31                      0    nvidia/label/cuda-11.8.0
libcufile-dev             1.4.0.31                      0    nvidia/label/cuda-11.8.0
libcurand                 10.3.0.86                     0    nvidia/label/cuda-11.8.0
libcurand-dev             10.3.0.86                     0    nvidia/label/cuda-11.8.0
libcusolver               11.4.1.48                     0    nvidia/label/cuda-11.8.0
libcusolver-dev           11.4.1.48                     0    nvidia/label/cuda-11.8.0
libcusparse               11.7.5.86                     0    nvidia/label/cuda-11.8.0
libcusparse-dev           11.7.5.86                     0    nvidia/label/cuda-11.8.0
libdeflate                1.17                 h5eee18b_1
libffi                    3.4.4                h6a678d5_1
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-devel_linux-64     11.4.0             h8f596e0_113    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgomp                   14.2.0               h77fa898_1    conda-forge
libiconv                  1.16                 h5eee18b_3
libidn2                   2.3.4                h5eee18b_0
libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
libnpp                    11.8.0.86                     0    nvidia/label/cuda-11.8.0
libnpp-dev                11.8.0.86                     0    nvidia/label/cuda-11.8.0
libnvjpeg                 11.9.0.86                     0    nvidia/label/cuda-11.8.0
libnvjpeg-dev             11.9.0.86                     0    nvidia/label/cuda-11.8.0
libpng                    1.6.39               h5eee18b_0
libsanitizer              11.4.0              h5763a12_13    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libstdcxx-ng              14.2.0               h4852527_1    conda-forge
libtasn1                  4.19.0               h5eee18b_0
libtiff                   4.5.1                h6a678d5_0
libunistring              0.9.10               h27cfd23_0
libuuid                   1.41.5               h5eee18b_0
libwebp                   1.3.2                h11a3e52_0
libwebp-base              1.3.2                h5eee18b_1
llvm-openmp               14.0.6               h9e868ea_0
lz4-c                     1.9.4                h6a678d5_1
markdown                  3.7                      pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                3.0.2                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mkl                       2023.1.0         h213fc3f_46344
mkl-service               2.4.0           py310h5eee18b_1
mkl_fft                   1.3.11          py310h5eee18b_0
mkl_random                1.2.8           py310h1128e8f_0
mpc                       1.1.0                h10f8cd9_1
mpfr                      4.0.2                hb69a4c5_1
mpmath                    1.3.0           py310h06a4308_0
msgpack                   1.1.0                    pypi_0    pypi
multidict                 6.1.0                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
nettle                    3.7.3                hbbd107a_1
networkx                  3.4.2                    pypi_0    pypi
ninja                     1.11.1.2                 pypi_0    pypi
ninja-base                1.12.1               hdb19cb5_0
nsight-compute            2022.3.0.22                   0    nvidia/label/cuda-11.8.0
numpy                     1.22.4                   pypi_0    pypi
nvidia-cublas-cu11        11.11.3.6                pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu11    11.8.87                  pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.8.89                  pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu11  11.8.89                  pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu11         9.1.0.70                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu11        10.3.0.86                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu11      11.4.1.48                pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu11      11.7.5.86                pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu11          2.21.5                   pypi_0    pypi
nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu11          11.8.86                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
openh264                  2.1.1                h4ff587b_0
openjpeg                  2.5.2                he7f1fd0_0
openssl                   3.4.0                hb9d3cd8_0    conda-forge
pandas                    2.0.0                    pypi_0    pypi
pillow                    11.0.0          py310hfdbf927_0
pip                       24.2            py310h06a4308_0
propcache                 0.2.1                    pypi_0    pypi
protobuf                  5.29.1                   pypi_0    pypi
psutil                    6.1.0                    pypi_0    pypi
py-cpuinfo                9.0.0           py310h06a4308_0
pyarrow                   18.1.0                   pypi_0    pypi
pydantic                  2.8.2           py310h06a4308_0
pydantic-core             2.20.1          py310hb02cf49_0
pygments                  2.18.0                   pypi_0    pypi
pysocks                   1.7.1           py310h06a4308_0
python                    3.10.15              he870216_1
python-dateutil           2.9.0.post0              pypi_0    pypi
pytorch-cuda              11.8                 h7e8668a_6    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2024.2                   pypi_0    pypi
pyyaml                    6.0.2           py310h5eee18b_0
readline                  8.2                  h5eee18b_0
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3          py310h06a4308_1
rich                      13.9.4                   pypi_0    pypi
safetensors               0.4.5                    pypi_0    pypi
setuptools                75.6.0                   pypi_0    pypi
six                       1.17.0                   pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0
sympy                     1.13.1                   pypi_0    pypi
sysroot_linux-64          2.17                h57e8cba_10
tbb                       2021.8.0             hdb19cb5_0
tensorboard               2.18.0                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
tk                        8.6.14               h39e8969_0
tokenizers                0.20.3                   pypi_0    pypi
torch                     2.5.1+cu118              pypi_0    pypi
torchaudio                2.5.1                    pypi_0    pypi
torchvision               0.20.1                   pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
transformers              4.46.3                   pypi_0    pypi
triton                    3.1.0                    pypi_0    pypi
trl                       0.12.2                   pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
typing_extensions         4.11.0          py310h06a4308_0
tzdata                    2024.2                   pypi_0    pypi
urllib3                   2.2.3           py310h06a4308_0
werkzeug                  3.1.3                    pypi_0    pypi
wheel                     0.44.0          py310h06a4308_0
xxhash                    3.5.0                    pypi_0    pypi
xz                        5.4.6                h5eee18b_1
yaml                      0.2.5                h7b6447c_0
yarl                      1.18.3                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_1
zstd                      1.5.6                hc292b87_0

My accelerate config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: DEEPSPEED
downcast_bf16: false
gpu_ids: all
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 3
rdzv_backend: static
same_network: true
fsdp_config: {}
mixed_precision: bf16
use_cpu: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3

Can you try different offload settings? Offload only the optimizer or only the parameters, not both. How much CPU memory does the system have?

Hi @jomayeri, thanks for answering.

I already tried all the offload permutations and got the same error.
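For reference, the offload permutations can be enumerated with a short sketch like the one below. It assumes only the two device values `none` and `cpu` are of interest, and the helper name `offload_variants` is hypothetical; the remaining keys are copied from the `deepspeed_config` section of the accelerate config above.

```python
from itertools import product

# Base values copied from the deepspeed_config section of the accelerate config.
BASE = {
    "deepspeed_multinode_launcher": "standard",
    "zero3_init_flag": True,
    "zero3_save_16bit_model": True,
    "zero_stage": 3,
}


def offload_variants(devices=("none", "cpu")):
    """Return one deepspeed_config dict per (optimizer, param) offload combination."""
    variants = []
    for opt_dev, param_dev in product(devices, repeat=2):
        cfg = dict(BASE)  # shallow copy so each variant is independent
        cfg["offload_optimizer_device"] = opt_dev
        cfg["offload_param_device"] = param_dev
        variants.append(cfg)
    return variants


variants = offload_variants()
for cfg in variants:
    print(cfg["offload_optimizer_device"], cfg["offload_param_device"])
```

Each resulting dict would replace the `deepspeed_config` block in its own copy of the YAML file passed to `accelerate launch --config_file`.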

I don't know the exact amount of CPU memory, but the machine has 32 CPUs, and during the run I got the following info:

[2024-12-11 10:55:07,014] [INFO] [utils.py:789:see_memory_usage] CPU Virtual Memory:  used = 45.89 GB, percent = 4.6%

so I don't think CPU memory is the problem.
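The total CPU memory can be backed out of that log line, since it reports both the used amount and the used percentage. The sketch below does that arithmetic. The comparison against model state is a rule-of-thumb assumption that is not from this thread: about 16 bytes per parameter for mixed-precision Adam, applied to a round 8e9 parameters.

```python
# Values copied from the see_memory_usage log line above.
used_gb = 45.89
used_percent = 4.6

# Total host memory implied by "used" and "percent".
total_gb = used_gb / (used_percent / 100)

# Assumption (rule of thumb, not measured): mixed-precision Adam keeps roughly
# 16 bytes per parameter: 2 (bf16 weights) + 2 (bf16 grads) + 12 (fp32 optimizer state).
assumed_params = 8.0e9  # round figure for an 8B-parameter model
model_state_gb = assumed_params * 16 / 1e9

print(f"estimated total CPU memory: {total_gb:.0f} GB")
print(f"assumed model + optimizer state: {model_state_gb:.0f} GB")
```

Under these assumptions the host has roughly 1 TB of memory, well above the estimated training state even if all of it were offloaded to CPU, which is consistent with CPU memory not being the bottleneck.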

[rank0]: File "/home/sagie.dekel/PROGS/anaconda3/envs/deepspeed/lib/python3.10/site-packages/deepspeed-0.16.2+9ca60160-py3.10.egg/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 331, in fetch_sub_module
[rank0]: assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
[rank0]: AssertionError: {'id': 682, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 0, 'shape': (0,), 'ds_shape': (0,), 'requires_grad': False, 'grad_shape': None, 'persist': True, 'active_sub_modules': {2}, 'ds_tensor.shape': torch.Size([0])}

This indicates that the parameter was not fetched or all-gathered as required before use. This is a very strange failure for ZeRO stage 3. Are you able to share full repro steps, including the command line and datasets?
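The dictionary in the assertion message is the output of `param.ds_summary()`. Below is a minimal sketch for reading it outside DeepSpeed. The helper names `parse_ds_summary` and `empty_in_ds_bookkeeping` are hypothetical, and the interpretation is an assumption that nobody in this thread has confirmed: under ZeRO stage 3 the local `shape` of a partitioned parameter is normally `(0,)` while `ds_shape` keeps the full shape, so a `ds_numel` of 0 may mean the parameter was already empty when DeepSpeed registered it.

```python
import ast
import re

# Summary copied verbatim from the AssertionError in the traceback.
RAW_SUMMARY = (
    "{'id': 682, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 0, "
    "'shape': (0,), 'ds_shape': (0,), 'requires_grad': False, "
    "'grad_shape': None, 'persist': True, 'active_sub_modules': {2}, "
    "'ds_tensor.shape': torch.Size([0])}"
)


def parse_ds_summary(raw):
    """Turn the printed summary into a plain dict without importing torch."""
    # torch.Size([...]) is not a Python literal, so keep only the inner list.
    literal = re.sub(r"torch\.Size\((\[[^\]]*\])\)", r"\1", raw)
    summary = ast.literal_eval(literal)
    summary["ds_tensor.shape"] = tuple(summary["ds_tensor.shape"])
    return summary


def empty_in_ds_bookkeeping(summary):
    """True if the parameter is unavailable and its recorded full size is zero."""
    return summary["status"] == "NOT_AVAILABLE" and summary["ds_numel"] == 0


summary = parse_ds_summary(RAW_SUMMARY)
print(summary["id"], summary["status"], summary["ds_shape"])
print("empty in DeepSpeed bookkeeping:", empty_in_ds_bookkeeping(summary))
```

If that reading is right, it would be worth checking whether the reference model is being initialized or wrapped for ZeRO stage 3 more than once, but that is a guess to verify against a full repro.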

perfernces_dataset_from_ranker_train_queries_and_baseline_doc.csv
@tjruwase

command line:

accelerate launch --config_file=/home/sagie.dekel/.cache/huggingface/accelerate/accelerate_deepspeed_config.yaml path_to_program_file.py path_to_param_file.json

The dataset is a small CSV file containing prompts, each with an accepted and a rejected sample.

What does the program file consist of? And what's in param_file.json?

Hi again @jomayeri

What do you mean by "consist of"? It's a regular Python file that executes a DPO pipeline.
The param_file.json contains parameters for the program (e.g., the model name) and hyperparameters for the DPO trainer (e.g., learning rate or beta).