yoshitomo-matsubara/torchdistill

[BUG] fp16 causes AssertionError: No inf checks were recorded for this optimizer

jsrdcht opened this issue · 4 comments

Describe the bug
I modified examples/legacy/image_classification.py to work with Hugging Face accelerate and ran into the following error:

Traceback (most recent call last):
  File "examples/legacy/image_classification_accelerate.py", line 217, in <module>
    main(argparser.parse_args())
  File "examples/legacy/image_classification_accelerate.py", line 198, in main
    train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
  File "examples/legacy/image_classification_accelerate.py", line 129, in train
    train_one_epoch(training_box, device, epoch, log_freq)
  File "examples/legacy/image_classification_accelerate.py", line 71, in train_one_epoch
    training_box.update_params(loss)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torchdistill/core/distillation.py", line 316, in update_params
    self.optimizer.step()
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
    self.scaler.step(self.optimizer, closure)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 339, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.

To Reproduce
Provide

  1. Exact command to run your code
    accelerate launch examples/legacy/image_classification_accelerate.py --config /workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml
  2. Whether or not you made any changes in Python code (if so, how you made the changes?)
    I enabled the fp16 multi-GPU option in the accelerate configuration file. My main experiment config file is for the AT algorithm.
    I made some modifications to the image_classification script, mainly following the author's changes to text_classification.py. I did not make any custom changes beyond that; I simply mirrored the approach of text_classification.py with minimal modifications, which ultimately led to this error.
  3. YAML config file
datasets:
  ilsvrc2012:
    name: &dataset_name 'ilsvrc2012'
    type: 'ImageFolder'
    root: &root_dir !join ['/workspace/sync/imagenet-1k']
    splits:
      train:
        dataset_id: &imagenet_train !join [*dataset_name, '/train']
        params:
          root: !join [*root_dir, '/train']
          transform_params:
            - type: 'RandomResizedCrop'
              params:
                size: &input_size [224, 224]
            - type: 'RandomHorizontalFlip'
              params:
                p: 0.5
            - &totensor
              type: 'ToTensor'
              params:
            - &normalize
              type: 'Normalize'
              params:
                mean: [0.485, 0.456, 0.406]
                std: [0.229, 0.224, 0.225]
      val:
        dataset_id: &imagenet_val !join [*dataset_name, '/val']
        params:
          root: !join [*root_dir, '/val']
          transform_params:
            - type: 'Resize'
              params:
                size: 256
            - type: 'CenterCrop'
              params:
                size: *input_size
            - *totensor
            - *normalize

models:
  teacher_model:
    name: &teacher_model_name 'maskedvit_base_patch16_224'
    params:
      num_classes: 1000
      pretrained: True
      mask_ratio: 0.0
    experiment: &teacher_experiment !join [*dataset_name, '-', *teacher_model_name]
    ckpt: !join ['./resource/ckpt/ilsvrc2012/teacher/', *teacher_experiment, '.pt']
  student_model:
    name: &student_model_name 'maskedvit_base_patch16_224'
    params:
      num_classes: 1000
      pretrained: False
      mask_ratio: 0.5
    experiment: &student_experiment !join [*dataset_name, '-', *student_model_name, '_from_', *teacher_model_name]
    ckpt: !join ['./imagenet/mask_distillation/', *student_experiment, '.pt']

train:
  log_freq: 1000
  num_epochs: 100
  train_data_loader:
    dataset_id: *imagenet_train
    random_sample: True
    batch_size: 64
    num_workers: 16
    cache_output:
  val_data_loader:
    dataset_id: *imagenet_val
    random_sample: False
    batch_size: 128
    num_workers: 16
  teacher:
    sequential: []
    forward_hook:
      input: []
      output: ['mask_filter']
    wrapper: 'DataParallel'
    requires_grad: False
  student:
    adaptations:
    sequential: []
    frozen_modules: []
    forward_hook:
      input: []
      output: ['mask_filter']
    wrapper: 'DistributedDataParallel'
    requires_grad: True
  optimizer:
    type: 'SGD'
    grad_accum_step: 16
    max_grad_norm: 5.0
    module_wise_params:
      - params: ['mask_token', 'cls_token', 'pos_embed']
        is_teacher: None
        module: None
        weight_decay: 0.0
    params:
      lr: 0.001
      momentum: 0.9
      weight_decay: 0.0001
      
  scheduler:
    type: 'MultiStepLR'
    params:
      milestones: [30, 60, 90]
      gamma: 0.1
  criterion:
    type: 'GeneralizedCustomLoss'
    org_term:
      criterion:
        type: 'CrossEntropyLoss'
        params:
          reduction: 'mean'
      factor: 1.0
    sub_terms:
      GenerativeKDLoss:
        criterion:
          type: 'GenerativeKDLoss'
          params:
            student_module_io: 'output'
            student_module_path: 'mask_filter'
            teacher_module_io: 'output'
            teacher_module_path: 'mask_filter'
        factor: 1.0

test:
  test_data_loader:
    dataset_id: *imagenet_val
    random_sample: False
    batch_size: 1
    num_workers: 16

  4. Log file
(pytorch_1) root@baa8ef5448b2:/workspace/sync/torchdistill# accelerate launch examples/legacy/image_classification_accelerate.py --config /workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml
2023/08/15 02:49:09     INFO    __main__        Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09     INFO    torch.distributed.distributed_c10d      Added key: store_based_barrier_key:1 to store for rank: 0
2023/08/15 02:49:09     INFO    __main__        Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09     INFO    torch.distributed.distributed_c10d      Added key: store_based_barrier_key:1 to store for rank: 1
2023/08/15 02:49:09     INFO    __main__        Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09     INFO    __main__        Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09     INFO    torch.distributed.distributed_c10d      Added key: store_based_barrier_key:1 to store for rank: 2
2023/08/15 02:49:09     INFO    torch.distributed.distributed_c10d      Added key: store_based_barrier_key:1 to store for rank: 3
2023/08/15 02:49:09     INFO    torch.distributed.distributed_c10d      Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09     INFO    torch.distributed.distributed_c10d      Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09     INFO    torch.distributed.distributed_c10d      Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09     INFO    torch.distributed.distributed_c10d      Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09     INFO    __main__        Distributed environment: MULTI_GPU  Backend: nccl
Num processes: 4
Process index: 0
Local process index: 0
Device: cuda:0

Mixed precision type: fp16

2023/08/15 02:49:09     INFO    torchdistill.datasets.util      Loading train data
2023/08/15 02:49:12     INFO    torchdistill.datasets.util      dataset_id `ilsvrc2012/train`: 2.874385356903076 sec
2023/08/15 02:49:12     INFO    torchdistill.datasets.util      Loading val data
2023/08/15 02:49:12     INFO    torchdistill.datasets.util      dataset_id `ilsvrc2012/val`: 0.12787175178527832 sec
2023/08/15 02:49:15     INFO    timm.models._builder    Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2023/08/15 02:49:16     INFO    timm.models._hub        [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2023/08/15 02:49:16     INFO    torchdistill.common.main_util   ckpt file is not found at `./resource/ckpt/ilsvrc2012/teacher/ilsvrc2012-maskedvit_base_patch16_224.pt`
2023/08/15 02:49:18     INFO    torchdistill.common.main_util   ckpt file is not found at `./imagenet/mask_distillation/ilsvrc2012-maskedvit_base_patch16_224_from_maskedvit_base_patch16_224.pt`
2023/08/15 02:49:18     INFO    __main__        Start training
2023/08/15 02:49:18     INFO    torchdistill.models.util        [teacher model]
2023/08/15 02:49:18     INFO    torchdistill.models.util        Using the original teacher model
2023/08/15 02:49:18     INFO    torchdistill.models.util        [student model]
2023/08/15 02:49:18     INFO    torchdistill.models.util        Using the original student model
2023/08/15 02:49:18     INFO    torchdistill.core.distillation  Loss = 1.0 * OrgLoss + 1.0 * GenerativeKDLoss(
  (cross_entropy_loss): CrossEntropyLoss()
  (SmoothL1Loss): SmoothL1Loss()
)
2023/08/15 02:49:18     INFO    torchdistill.core.distillation  Freezing the whole teacher model
2023/08/15 02:49:18     INFO    torchdistill.common.module_util `None` of `None` could not be reached in `DataParallel`
/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/state.py:802: FutureWarning: The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use `AcceleratorState.mixed_precision == 'fp16'` instead.
  warnings.warn(
/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/state.py:802: FutureWarning: The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use `AcceleratorState.mixed_precision == 'fp16'` instead.
  warnings.warn(
/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/state.py:802: FutureWarning: The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use `AcceleratorState.mixed_precision == 'fp16'` instead.
  warnings.warn(
/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/state.py:802: FutureWarning: The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use `AcceleratorState.mixed_precision == 'fp16'` instead.
  warnings.warn(
2023/08/15 02:49:24     INFO    torchdistill.misc.log   Epoch: [0]  [   0/5005]  eta: 8:39:24  lr: 0.001  img/s: 21.99282017795937  loss: 0.4513 (0.4513)  time: 6.2267  data: 3.3162  max mem: 8400
2023/08/15 02:49:24     INFO    torch.nn.parallel.distributed   Reducer buckets have been rebuilt in this iteration.
2023/08/15 02:49:24     INFO    torch.nn.parallel.distributed   Reducer buckets have been rebuilt in this iteration.
Traceback (most recent call last):
  File "examples/legacy/image_classification_accelerate.py", line 217, in <module>
Traceback (most recent call last):
  File "examples/legacy/image_classification_accelerate.py", line 217, in <module>
    main(argparser.parse_args())
  File "examples/legacy/image_classification_accelerate.py", line 198, in main
    train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
  File "examples/legacy/image_classification_accelerate.py", line 129, in train
    train_one_epoch(training_box, device, epoch, log_freq)
  File "examples/legacy/image_classification_accelerate.py", line 71, in train_one_epoch
    main(argparser.parse_args())
  File "examples/legacy/image_classification_accelerate.py", line 198, in main
    training_box.update_params(loss)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torchdistill/core/distillation.py", line 316, in update_params
    train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
  File "examples/legacy/image_classification_accelerate.py", line 129, in train
    self.optimizer.step()    
train_one_epoch(training_box, device, epoch, log_freq)  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step

  File "examples/legacy/image_classification_accelerate.py", line 71, in train_one_epoch
    training_box.update_params(loss)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torchdistill/core/distillation.py", line 316, in update_params
    self.optimizer.step()
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
        self.scaler.step(self.optimizer, closure)self.scaler.step(self.optimizer, closure)

  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 339, in step
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 339, in step
        assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."

AssertionErrorAssertionError: : No inf checks were recorded for this optimizer.No inf checks were recorded for this optimizer.

Traceback (most recent call last):
  File "examples/legacy/image_classification_accelerate.py", line 217, in <module>
    main(argparser.parse_args())
  File "examples/legacy/image_classification_accelerate.py", line 198, in main
    train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
  File "examples/legacy/image_classification_accelerate.py", line 129, in train
    train_one_epoch(training_box, device, epoch, log_freq)
  File "examples/legacy/image_classification_accelerate.py", line 71, in train_one_epoch
    training_box.update_params(loss)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torchdistill/core/distillation.py", line 316, in update_params
    self.optimizer.step()
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
    self.scaler.step(self.optimizer, closure)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 339, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
Traceback (most recent call last):
  File "examples/legacy/image_classification_accelerate.py", line 217, in <module>
    main(argparser.parse_args())
  File "examples/legacy/image_classification_accelerate.py", line 198, in main
    train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
  File "examples/legacy/image_classification_accelerate.py", line 129, in train
    train_one_epoch(training_box, device, epoch, log_freq)
  File "examples/legacy/image_classification_accelerate.py", line 71, in train_one_epoch
    training_box.update_params(loss)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torchdistill/core/distillation.py", line 316, in update_params
    self.optimizer.step()
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
    self.scaler.step(self.optimizer, closure)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 339, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 3701268) of binary: /root/miniconda3/envs/pytorch_1/bin/python
Traceback (most recent call last):
  File "/root/miniconda3/envs/pytorch_1/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/commands/launch.py", line 970, in launch_command
    multi_gpu_launcher(args)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/commands/launch.py", line 646, in multi_gpu_launcher
    distrib_run.run(args)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
examples/legacy/image_classification_accelerate.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-08-15_02:49:37
  host      : baa8ef5448b2
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3701269)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2023-08-15_02:49:37
  host      : baa8ef5448b2
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 3701270)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2023-08-15_02:49:37
  host      : baa8ef5448b2
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 3701271)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-08-15_02:49:37
  host      : baa8ef5448b2
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3701268)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Expected behavior
A clear and concise description of what you expected to happen.

Environment (please complete the following information):

  • OS: Ubuntu 22.04 LTS
  • Python ver.3.8
  • torchdistill ver. v0.3.3
(pytorch_1) root@baa8ef5448b2:/workspace/sync/torchdistill# conda list
# packages in environment at /root/miniconda3/envs/pytorch_1:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main    defaults
_openmp_mutex             5.1                       1_gnu    defaults
accelerate                0.21.0                   pypi_0    pypi
blas                      1.0                         mkl    defaults
brotlipy                  0.7.0           py38h27cfd23_1003    defaults
bzip2                     1.0.8                h7b6447c_0    defaults
ca-certificates           2023.05.30           h06a4308_0    defaults
certifi                   2023.7.22        py38h06a4308_0    defaults
cffi                      1.15.1           py38h5eee18b_3    defaults
charset-normalizer        2.0.4              pyhd3eb1b0_0    defaults
contourpy                 1.1.0                    pypi_0    pypi
cryptography              41.0.2           py38h22a60cf_0    defaults
cuda-cudart               11.7.99                       0    nvidia
cuda-cupti                11.7.101                      0    nvidia
cuda-libraries            11.7.1                        0    nvidia
cuda-nvrtc                11.7.99                       0    nvidia
cuda-nvtx                 11.7.91                       0    nvidia
cuda-runtime              11.7.1                        0    nvidia
cycler                    0.11.0                   pypi_0    pypi
cython                    3.0.0                    pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.12.2                   pypi_0    pypi
fonttools                 4.42.0                   pypi_0    pypi
freetype                  2.12.1               h4a9f257_0    defaults
fsspec                    2023.6.0                 pypi_0    pypi
future                    0.18.3           py38h06a4308_0    defaults
giflib                    5.2.1                h5eee18b_3    defaults
gmp                       6.2.1                h295c915_3    defaults
gnutls                    3.6.15               he1e5248_0    defaults
huggingface-hub           0.16.4                   pypi_0    pypi
idna                      3.4              py38h06a4308_0    defaults
importlib-resources       6.0.1                    pypi_0    pypi
intel-openmp              2023.1.0         hdb19cb5_46305    defaults
jpeg                      9e                   h5eee18b_1    defaults
kiwisolver                1.4.4                    pypi_0    pypi
lame                      3.100                h7b6447c_0    defaults
lcms2                     2.12                 h3be6417_0    defaults
ld_impl_linux-64          2.38                 h1181459_1    defaults
lerc                      3.0                  h295c915_0    defaults
libcublas                 11.10.3.66                    0    nvidia
libcufft                  10.7.2.124           h4fbf590_0    nvidia
libcufile                 1.7.1.12                      0    nvidia
libcurand                 10.3.3.129                    0    nvidia
libcusolver               11.4.0.1                      0    nvidia
libcusparse               11.7.4.91                     0    nvidia
libdeflate                1.17                 h5eee18b_0    defaults
libffi                    3.4.4                h6a678d5_0    defaults
libgcc-ng                 11.2.0               h1234567_1    defaults
libgfortran-ng            11.2.0               h00389a5_1    defaults
libgfortran5              11.2.0               h1234567_1    defaults
libgomp                   11.2.0               h1234567_1    defaults
libiconv                  1.16                 h7f8727e_2    defaults
libidn2                   2.3.4                h5eee18b_0    defaults
libnpp                    11.7.4.75                     0    nvidia
libnvjpeg                 11.8.0.2                      0    nvidia
libopenblas               0.3.21               h043d6bf_0    defaults
libpng                    1.6.39               h5eee18b_0    defaults
libprotobuf               3.20.3               he621ea3_0    defaults
libstdcxx-ng              11.2.0               h1234567_1    defaults
libtasn1                  4.19.0               h5eee18b_0    defaults
libtiff                   4.5.0                h6a678d5_2    defaults
libunistring              0.9.10               h27cfd23_0    defaults
libwebp                   1.2.4                h11a3e52_1    defaults
libwebp-base              1.2.4                h5eee18b_1    defaults
lz4-c                     1.9.4                h6a678d5_0    defaults
matplotlib                3.7.2                    pypi_0    pypi
mkl                       2023.1.0         h213fc3f_46343    defaults
mkl-service               2.4.0            py38h5eee18b_1    defaults
mkl_fft                   1.3.6            py38h417a72b_1    defaults
mkl_random                1.2.2            py38h417a72b_1    defaults
ncurses                   6.4                  h6a678d5_0    defaults
nettle                    3.7.3                hbbd107a_1    defaults
ninja                     1.10.2               h06a4308_5    defaults
ninja-base                1.10.2               hd09550d_5    defaults
numpy                     1.24.3           py38hf6e8229_1    defaults
numpy-base                1.24.3           py38h060ed82_1    defaults
openh264                  2.1.1                h4ff587b_0    defaults
openssl                   3.0.10               h7f8727e_0    defaults
packaging                 23.1                     pypi_0    pypi
pillow                    9.4.0            py38h6a678d5_0    defaults
pip                       23.2.1           py38h06a4308_0    defaults
psutil                    5.9.5                    pypi_0    pypi
pycocotools               2.0.6                    pypi_0    pypi
pycparser                 2.21               pyhd3eb1b0_0    defaults
pyopenssl                 23.2.0           py38h06a4308_0    defaults
pyparsing                 3.0.9                    pypi_0    pypi
pysocks                   1.7.1            py38h06a4308_0    defaults
python                    3.8.17               h955ad1f_0    defaults
python-dateutil           2.8.2                    pypi_0    pypi
pytorch                   1.13.0          py3.8_cuda11.7_cudnn8.5.0_0    pytorch
pytorch-cuda              11.7                 h778d358_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pyyaml                    6.0              py38h5eee18b_1    defaults
readline                  8.2                  h5eee18b_0    defaults
requests                  2.31.0           py38h06a4308_0    defaults
safetensors               0.3.2                    pypi_0    pypi
scipy                     1.10.1                   pypi_0    pypi
setuptools                68.0.0           py38h06a4308_0    defaults
six                       1.16.0                   pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0    defaults
tbb                       2021.8.0             hdb19cb5_0    defaults
timm                      0.9.5                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0    defaults
torchaudio                0.13.0               py38_cu117    pytorch
torchdistill              0.3.3                    pypi_0    pypi
torchvision               0.14.0               py38_cu117    pytorch
tqdm                      4.66.1                   pypi_0    pypi
typing-extensions         4.7.1            py38h06a4308_0    defaults
typing_extensions         4.7.1            py38h06a4308_0    defaults
urllib3                   1.26.16          py38h06a4308_0    defaults
wheel                     0.38.4           py38h06a4308_0    defaults
xz                        5.4.2                h5eee18b_0    defaults
yaml                      0.2.5                h7b6447c_0    defaults
zipp                      3.16.2                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_0    defaults
zstd                      1.5.5                hc292b87_0    defaults

Additional context
Add any other context about the problem here.

Hi @jsrdcht

Since you made changes in code and did not share the actual code, I cannot confirm that this is a bug from torchdistill.

If you're trying to introduce new components and are still in a trial-and-error phase, please use Discussions instead and provide your modified code. You also still have an open, unanswered Discussion on this topic.

From the error message, my guess is that you didn't pass accelerator when instantiating distillation_box or training_box, so the following lines inside distillation_box or training_box were never executed:
https://github.com/yoshitomo-matsubara/torchdistill/blob/v0.3.3/torchdistill/core/distillation.py#L306-L307
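For context, a rough paraphrase of what those lines do (check the link above for the exact code; this is only an approximation of v0.3.3):

# Approximate paraphrase of the backward path in torchdistill v0.3.3 update_params:
# when an accelerator is passed in, the backward pass goes through
# accelerator.backward(loss) so the loss is scaled by the same GradScaler that
# later steps the optimizer; otherwise a plain loss.backward() is used.
if self.accelerator is not None:
    self.accelerator.backward(loss)
else:
    loss.backward()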

Here's my code. It's not for the reason you guessed; I debugged it and confirmed that that line of code is executed.

import argparse
import datetime
import logging
import os
import time

import torch
import vit
import custom_loss

from accelerate import Accelerator, DistributedType

from torch import distributed as dist
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

from torchdistill.common import file_util, yaml_util, module_util
from torchdistill.common.constant import def_logger
from torchdistill.common.main_util import is_main_process, init_distributed_mode, load_ckpt, save_ckpt, set_seed, setup_for_distributed
from torchdistill.core.distillation import get_distillation_box
from torchdistill.core.training import get_training_box
from torchdistill.datasets import util
from torchdistill.eval.classification import compute_accuracy
from torchdistill.misc.log import setup_log_file, SmoothedValue, MetricLogger
from torchdistill.models.official import get_image_classification_model
from torchdistill.models.registry import get_model

logger = def_logger.getChild(__name__)


def get_argparser():
    parser = argparse.ArgumentParser(description='Knowledge distillation for image classification models')
    parser.add_argument('--config', required=True, help='yaml file path')
    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('--log', help='log file path')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--seed', type=int, help='seed in random number generator')
    parser.add_argument('-test_only', action='store_true', help='only test the models')
    parser.add_argument('-student_only', action='store_true', help='test the student model only')
    parser.add_argument('-log_config', action='store_true', help='log config')
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('-adjust_lr', action='store_true',
                        help='multiply learning rate by number of distributed processes (world_size)')
    return parser


def load_model(model_config, device, distributed):
    model = get_image_classification_model(model_config, distributed)
    if model is None:
        repo_or_dir = model_config.get('repo_or_dir', None)
        model = get_model(model_config['name'], repo_or_dir, **model_config['params'])

    ckpt_file_path = model_config['ckpt']
    load_ckpt(ckpt_file_path, model=model, strict=True)
    return model.to(device)


def train_one_epoch(training_box, device, epoch, log_freq):
    metric_logger = MetricLogger(delimiter='  ')
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', SmoothedValue(window_size=10, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    for sample_batch, targets, supp_dict in \
            metric_logger.log_every(training_box.train_data_loader, log_freq, header):
        start_time = time.time()
        sample_batch, targets = sample_batch.to(device), targets.to(device)
        loss = training_box(sample_batch, targets, supp_dict)
        training_box.update_params(loss)
        batch_size = sample_batch.shape[0]
        metric_logger.update(loss=loss.item(), lr=training_box.optimizer.param_groups[0]['lr'])
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
        if (torch.isnan(loss) or torch.isinf(loss)) and is_main_process():
            raise ValueError('The training loop was broken due to loss = {}'.format(loss))


@torch.inference_mode()
def evaluate(model, data_loader, device, device_ids, distributed, log_freq=1000, title=None, header='Test:'):
    model.to(device)
    if distributed:
        model = DistributedDataParallel(model, device_ids=device_ids)
    elif device.type.startswith('cuda'):
        model = DataParallel(model, device_ids=device_ids)

    if title is not None:
        logger.info(title)

    model.eval()
    metric_logger = MetricLogger(delimiter='  ')
    for image, target in metric_logger.log_every(data_loader, log_freq, header):
        image = image.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(image)
        acc1, acc5 = compute_accuracy(output, target, topk=(1, 5))
        # FIXME need to take into account that the datasets
        # could have been padded in distributed setup
        batch_size = image.shape[0]
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    top1_accuracy = metric_logger.acc1.global_avg
    top5_accuracy = metric_logger.acc5.global_avg
    logger.info(' * Acc@1 {:.4f}\tAcc@5 {:.4f}\n'.format(top1_accuracy, top5_accuracy))
    return metric_logger.acc1.global_avg


def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator):
    logger.info('Start training')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1
    training_box = get_training_box(student_model, dataset_dict, train_config,
                                    device, device_ids, distributed, lr_factor, accelerator) if teacher_model is None \
        else get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                  device, device_ids, distributed, lr_factor, accelerator = accelerator)
    best_val_top1_accuracy = 0.0
    optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_top1_accuracy, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_top1_accuracy = evaluate(student_model, training_box.val_data_loader, device, device_ids, distributed,
                                     log_freq=log_freq, header='Validation:')
        if val_top1_accuracy > best_val_top1_accuracy and is_main_process():
            logger.info('Best top-1 accuracy: {:.4f} -> {:.4f}'.format(best_val_top1_accuracy, val_top1_accuracy))
            logger.info('Updating ckpt at {}'.format(ckpt_file_path))
            best_val_top1_accuracy = val_top1_accuracy
            if distributed is False and accelerator is not None:
                student_model_without_ddp = accelerator.unwrap_model(student_model)
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_top1_accuracy, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()


def main(args):
    log_file_path = args.log
    if is_main_process() and log_file_path is not None:
        setup_log_file(os.path.expanduser(log_file_path))

    logger.info(args)
    cudnn.benchmark = True
    set_seed(args.seed)

    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))

    # distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator(mixed_precision='fp16')
    distributed = accelerator.state.distributed_type == DistributedType.MULTI_GPU
    device_ids = [accelerator.device.index]
    if distributed:
        setup_for_distributed(is_main_process())

    logger.info(accelerator.state)
    device = accelerator.device

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)


    
    # device = torch.device(args.device)
    dataset_dict = util.get_all_datasets(config['datasets'])
    models_config = config['models']
    teacher_model_config = models_config.get('teacher_model', None)
    teacher_model =\
        load_model(teacher_model_config, device, distributed) if teacher_model_config is not None else None
    student_model_config =\
        models_config['student_model'] if 'student_model' in models_config else models_config['model']
    ckpt_file_path = student_model_config['ckpt']
    student_model = load_model(student_model_config, device, distributed)
    if accelerator is not None:
        student_model, teacher_model = accelerator.prepare(student_model, teacher_model)
        for name, dataset in dataset_dict.items():
            dataset_dict[name] = accelerator.prepare(dataset)
    if args.log_config:
        logger.info(config)


    if not args.test_only:
        train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
        student_model_without_ddp =\
            student_model.module if module_util.check_if_wrapped(student_model) else student_model
        load_ckpt(student_model_config['ckpt'], model=student_model_without_ddp, strict=True)

    test_config = config['test']
    test_data_loader_config = test_config['test_data_loader']
    test_data_loader = util.build_data_loader(dataset_dict[test_data_loader_config['dataset_id']],
                                              test_data_loader_config, distributed)
    log_freq = test_config.get('log_freq', 1000)
    if not args.student_only and teacher_model is not None:
        evaluate(teacher_model, test_data_loader, device, device_ids, distributed, log_freq=log_freq,
                 title='[Teacher: {}]'.format(teacher_model_config['name']))
    evaluate(student_model, test_data_loader, device, device_ids, distributed, log_freq=log_freq,
             title='[Student: {}]'.format(student_model_config['name']))


if __name__ == '__main__':
    argparser = get_argparser()
    main(argparser.parse_args())
    module_wise_params:
      - params: ['mask_token', 'cls_token', 'pos_embed']
        is_teacher: None
        module: None
        weight_decay: 0.0

Your module_wise_params entry looks broken; use is_teacher: False instead of None (or omit is_teacher, since its default value is False).

See https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/legacy/sample/pascal_voc2012/ce/deeplabv3_resnet101.yaml#L102-L109 for the format
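For illustration only, a hedged sketch of what well-formed entries might look like for your case (the module paths 'mask_token', 'cls_token', and 'pos_embed' are assumptions about your maskedvit implementation; adjust them to attribute paths that actually exist on the student model):

module_wise_params:
  - module: 'mask_token'   # attribute path resolvable on the target model
    is_teacher: False      # or omit it; False is the default
    params:
      weight_decay: 0.0
  - module: 'cls_token'
    is_teacher: False
    params:
      weight_decay: 0.0
  - module: 'pos_embed'
    is_teacher: False
    params:
      weight_decay: 0.0

Note that 'params' here is a dict of per-group optimizer options, not a list of parameter names.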

Again, please use Discussions for this kind of question since it doesn't look like a bug from torchdistill

The issues you pointed out were indeed what caused the strange errors. I also made some changes to the source code to adapt it to my configuration.

Here are some problems I ran into in the source code:

def pre_process(self, epoch=None, **kwargs):
    clear_io_dict(self.teacher_io_dict)
    clear_io_dict(self.student_io_dict)
    self.teacher_model.eval()
    self.student_model.train()
    if self.distributed and self.accelerator is None:
        # batch_sampler.sampler is only valid for DDP without accelerator
        self.train_data_loader.batch_sampler.sampler.set_epoch(epoch)

    # Set up accelerator if necessary
    if self.accelerator is not None:
        if self.teacher_updatable:
            self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
                                         self.train_data_loader, self.val_data_loader)
        else:
            # self.teacher_model = self.teacher_model.to(self.accelerator.device)
            # if self.accelerator.state.use_fp16:
            #     self.teacher_model = self.teacher_model.half()

            # Since fp16 is handled by accelerate, we have to wrap the teacher model too;
            # otherwise its inputs can't be cast to fp16.
            self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
                                         self.train_data_loader, self.val_data_loader)
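For comparison, a minimal alternative sketch for the frozen-teacher case (illustration only, with placeholder names; the forward call would live in the training step, not in pre_process): skip accelerator.prepare() for the frozen teacher and run its forward pass under the accelerator's autocast context so its activations still match the fp16 student activations.

# Hedged alternative for a frozen teacher (sketch only; teacher_outputs and
# sample_batch are placeholder names)
self.teacher_model = self.teacher_model.to(self.accelerator.device)
with self.accelerator.autocast():
    teacher_outputs = self.teacher_model(sample_batch)

The second excerpt is from the optimizer setup, where module_wise_params is parsed: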
module_wise_params_configs = optim_config.get('module_wise_params', list())
if len(module_wise_params_configs) > 0:
    trainable_module_list = list()
    for module_wise_params_config in module_wise_params_configs:
        module_wise_params_dict = dict()
        if isinstance(module_wise_params_config.get('params', None), dict):
            module_wise_params_dict.update(module_wise_params_config['params'])

        if 'lr' in module_wise_params_dict:
            module_wise_params_dict['lr'] *= self.lr_factor

        target_model = \
            self.teacher_model if module_wise_params_config.get('is_teacher', False) else self.student_model
        module = get_module(target_model, module_wise_params_config['module'])
        # support for nn.Parameter()
        module_wise_params_dict['params'] = module.parameters() if isinstance(module, nn.Module) else [module]
        trainable_module_list.append(module_wise_params_dict)
else:
    trainable_module_list = nn.ModuleList([self.student_model])
    if self.teacher_updatable:
        logger.info('Note that you are training some/all of the modules in the teacher model')
        trainable_module_list.append(self.teacher_model)
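
Side note for anyone hitting the same assertion: it typically fires when none of the parameters tracked by the optimizer had gradients at GradScaler.step() time, which is consistent with a broken module_wise_params entry selecting the wrong (or no) parameters. A hedged, generic sanity check (not part of torchdistill; the helper name is mine):

import torch

def check_optimizer_params(optimizer):
    # Hypothetical debugging helper: count the tensors the optimizer actually tracks
    # and how many of them received gradients. If no tracked tensor has a .grad at
    # scaler.step() time, GradScaler records no inf checks and raises
    # "No inf checks were recorded for this optimizer."
    tensors = [p for group in optimizer.param_groups for p in group['params']
               if isinstance(p, torch.Tensor)]
    with_grad = [p for p in tensors if p.grad is not None]
    print(f'tracked tensors: {len(tensors)}, with gradients: {len(with_grad)}')

Calling it on training_box.optimizer right after the first backward pass (before optimizer.step()) makes it obvious whether the module_wise_params entries selected the parameters you expected.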