[BUG] fp16 causes AssertionError: No inf checks were recorded for this optimizer
jsrdcht opened this issue · 4 comments
Describe the bug
I modified examples/legacy/image_classification.py to adapt it to Hugging Face accelerate, and ran into the following error:
Traceback (most recent call last):
File "examples/legacy/image_classification_accelerate.py", line 217, in <module>
main(argparser.parse_args())
File "examples/legacy/image_classification_accelerate.py", line 198, in main
train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
File "examples/legacy/image_classification_accelerate.py", line 129, in train
train_one_epoch(training_box, device, epoch, log_freq)
File "examples/legacy/image_classification_accelerate.py", line 71, in train_one_epoch
training_box.update_params(loss)
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torchdistill/core/distillation.py", line 316, in update_params
self.optimizer.step()
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
self.scaler.step(self.optimizer, closure)
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 339, in step
assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
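For context on what this assertion means: inside GradScaler.step(), PyTorch unscales the gradients and records one inf/nan check per device that actually holds a gradient. If the optimizer reaches step() with no parameters carrying gradients (for example, because a param group was built from an unresolvable module path), nothing is recorded and the assertion fires. A simplified, pure-Python sketch of that bookkeeping (the function name and dict-based params are illustrative, not the real torch internals):

```python
# Simplified sketch of GradScaler's unscale bookkeeping (illustrative only;
# the real logic lives in torch/cuda/amp/grad_scaler.py). One "found_inf"
# flag is recorded per device that owns at least one gradient.
def unscale_grads(param_groups):
    found_inf_per_device = {}
    for group in param_groups:
        for p in group["params"]:
            if p.get("grad") is not None:
                found_inf_per_device[p["device"]] = False
    return found_inf_per_device

# A param group with no gradients records nothing, so
# `assert len(optimizer_state["found_inf_per_device"]) > 0` fails in step():
state = unscale_grads([{"params": []}])
assert len(state) == 0
```

In other words, the assertion usually signals that scaler.step() was reached with an optimizer whose parameters received no gradients, rather than an fp16 overflow.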
To Reproduce
Provide
- Exact command to run your code
accelerate launch examples/legacy/image_classification_accelerate.py --config /workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml
- Whether or not you made any changes in Python code (if so, how you made the changes?)
I enabled the fp16 multi-gpu option in the accelerate configuration file. My main experiment config is for the AT algorithm.
I modified the image_classification file mainly by following the author's changes to text_classification.py. I did not make any custom changes beyond that; I followed the approach of text_classification.py with minimal modifications, and this error is the result.
- YAML config file
datasets:
ilsvrc2012:
name: &dataset_name 'ilsvrc2012'
type: 'ImageFolder'
root: &root_dir !join ['/workspace/sync/imagenet-1k']
splits:
train:
dataset_id: &imagenet_train !join [*dataset_name, '/train']
params:
root: !join [*root_dir, '/train']
transform_params:
- type: 'RandomResizedCrop'
params:
size: &input_size [224, 224]
- type: 'RandomHorizontalFlip'
params:
p: 0.5
- &totensor
type: 'ToTensor'
params:
- &normalize
type: 'Normalize'
params:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
val:
dataset_id: &imagenet_val !join [*dataset_name, '/val']
params:
root: !join [*root_dir, '/val']
transform_params:
- type: 'Resize'
params:
size: 256
- type: 'CenterCrop'
params:
size: *input_size
- *totensor
- *normalize
models:
teacher_model:
name: &teacher_model_name 'maskedvit_base_patch16_224'
params:
num_classes: 1000
pretrained: True
mask_ratio: 0.0
experiment: &teacher_experiment !join [*dataset_name, '-', *teacher_model_name]
ckpt: !join ['./resource/ckpt/ilsvrc2012/teacher/', *teacher_experiment, '.pt']
student_model:
name: &student_model_name 'maskedvit_base_patch16_224'
params:
num_classes: 1000
pretrained: False
mask_ratio: 0.5
experiment: &student_experiment !join [*dataset_name, '-', *student_model_name, '_from_', *teacher_model_name]
ckpt: !join ['./imagenet/mask_distillation/', *student_experiment, '.pt']
train:
log_freq: 1000
num_epochs: 100
train_data_loader:
dataset_id: *imagenet_train
random_sample: True
batch_size: 64
num_workers: 16
cache_output:
val_data_loader:
dataset_id: *imagenet_val
random_sample: False
batch_size: 128
num_workers: 16
teacher:
sequential: []
forward_hook:
input: []
output: ['mask_filter']
wrapper: 'DataParallel'
requires_grad: False
student:
adaptations:
sequential: []
frozen_modules: []
forward_hook:
input: []
output: ['mask_filter']
wrapper: 'DistributedDataParallel'
requires_grad: True
optimizer:
type: 'SGD'
grad_accum_step: 16
max_grad_norm: 5.0
module_wise_params:
- params: ['mask_token', 'cls_token', 'pos_embed']
is_teacher: None
module: None
weight_decay: 0.0
params:
lr: 0.001
momentum: 0.9
weight_decay: 0.0001
scheduler:
type: 'MultiStepLR'
params:
milestones: [30, 60, 90]
gamma: 0.1
criterion:
type: 'GeneralizedCustomLoss'
org_term:
criterion:
type: 'CrossEntropyLoss'
params:
reduction: 'mean'
factor: 1.0
sub_terms:
GenerativeKDLoss:
criterion:
type: 'GenerativeKDLoss'
params:
student_module_io: 'output'
student_module_path: 'mask_filter'
teacher_module_io: 'output'
teacher_module_path: 'mask_filter'
factor: 1.0
test:
test_data_loader:
dataset_id: *imagenet_val
random_sample: False
batch_size: 1
num_workers: 16
- Log file
(pytorch_1) root@baa8ef5448b2:/workspace/sync/torchdistill# accelerate launch examples/legacy/image_classification_accelerate.py --config /workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml
2023/08/15 02:49:09 INFO __main__ Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Added key: store_based_barrier_key:1 to store for rank: 0
2023/08/15 02:49:09 INFO __main__ Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Added key: store_based_barrier_key:1 to store for rank: 1
2023/08/15 02:49:09 INFO __main__ Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09 INFO __main__ Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Added key: store_based_barrier_key:1 to store for rank: 2
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Added key: store_based_barrier_key:1 to store for rank: 3
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09 INFO __main__ Distributed environment: MULTI_GPU Backend: nccl
Num processes: 4
Process index: 0
Local process index: 0
Device: cuda:0
Mixed precision type: fp16
2023/08/15 02:49:09 INFO torchdistill.datasets.util Loading train data
2023/08/15 02:49:12 INFO torchdistill.datasets.util dataset_id `ilsvrc2012/train`: 2.874385356903076 sec
2023/08/15 02:49:12 INFO torchdistill.datasets.util Loading val data
2023/08/15 02:49:12 INFO torchdistill.datasets.util dataset_id `ilsvrc2012/val`: 0.12787175178527832 sec
2023/08/15 02:49:15 INFO timm.models._builder Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2023/08/15 02:49:16 INFO timm.models._hub [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2023/08/15 02:49:16 INFO torchdistill.common.main_util ckpt file is not found at `./resource/ckpt/ilsvrc2012/teacher/ilsvrc2012-maskedvit_base_patch16_224.pt`
2023/08/15 02:49:18 INFO torchdistill.common.main_util ckpt file is not found at `./imagenet/mask_distillation/ilsvrc2012-maskedvit_base_patch16_224_from_maskedvit_base_patch16_224.pt`
2023/08/15 02:49:18 INFO __main__ Start training
2023/08/15 02:49:18 INFO torchdistill.models.util [teacher model]
2023/08/15 02:49:18 INFO torchdistill.models.util Using the original teacher model
2023/08/15 02:49:18 INFO torchdistill.models.util [student model]
2023/08/15 02:49:18 INFO torchdistill.models.util Using the original student model
2023/08/15 02:49:18 INFO torchdistill.core.distillation Loss = 1.0 * OrgLoss + 1.0 * GenerativeKDLoss(
(cross_entropy_loss): CrossEntropyLoss()
(SmoothL1Loss): SmoothL1Loss()
)
2023/08/15 02:49:18 INFO torchdistill.core.distillation Freezing the whole teacher model
2023/08/15 02:49:18 INFO torchdistill.common.module_util `None` of `None` could not be reached in `DataParallel`
/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/state.py:802: FutureWarning: The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use `AcceleratorState.mixed_precision == 'fp16'` instead.
warnings.warn(
2023/08/15 02:49:24 INFO torchdistill.misc.log Epoch: [0] [ 0/5005] eta: 8:39:24 lr: 0.001 img/s: 21.99282017795937 loss: 0.4513 (0.4513) time: 6.2267 data: 3.3162 max mem: 8400
2023/08/15 02:49:24 INFO torch.nn.parallel.distributed Reducer buckets have been rebuilt in this iteration.
2023/08/15 02:49:24 INFO torch.nn.parallel.distributed Reducer buckets have been rebuilt in this iteration.
Traceback (most recent call last):
File "examples/legacy/image_classification_accelerate.py", line 217, in <module>
main(argparser.parse_args())
File "examples/legacy/image_classification_accelerate.py", line 198, in main
train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
File "examples/legacy/image_classification_accelerate.py", line 129, in train
train_one_epoch(training_box, device, epoch, log_freq)
File "examples/legacy/image_classification_accelerate.py", line 71, in train_one_epoch
training_box.update_params(loss)
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torchdistill/core/distillation.py", line 316, in update_params
self.optimizer.step()
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
self.scaler.step(self.optimizer, closure)
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 339, in step
assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
(The same traceback was printed by the remaining ranks.)
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 3701268) of binary: /root/miniconda3/envs/pytorch_1/bin/python
Traceback (most recent call last):
File "/root/miniconda3/envs/pytorch_1/bin/accelerate", line 8, in <module>
sys.exit(main())
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
args.func(args)
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/commands/launch.py", line 970, in launch_command
multi_gpu_launcher(args)
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/commands/launch.py", line 646, in multi_gpu_launcher
distrib_run.run(args)
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/distributed/run.py", line 753, in run
elastic_launch(
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
examples/legacy/image_classification_accelerate.py FAILED
------------------------------------------------------------
Failures:
[1]:
time : 2023-08-15_02:49:37
host : baa8ef5448b2
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 3701269)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
time : 2023-08-15_02:49:37
host : baa8ef5448b2
rank : 2 (local_rank: 2)
exitcode : 1 (pid: 3701270)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
time : 2023-08-15_02:49:37
host : baa8ef5448b2
rank : 3 (local_rank: 3)
exitcode : 1 (pid: 3701271)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2023-08-15_02:49:37
host : baa8ef5448b2
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 3701268)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
Environment (please complete the following information):
- OS: Ubuntu 22.04 LTS
- Python ver.3.8
- torchdistill ver. v0.3.3
(pytorch_1) root@baa8ef5448b2:/workspace/sync/torchdistill# conda list
# packages in environment at /root/miniconda3/envs/pytorch_1:
#
# Name Version Build Channel
_libgcc_mutex 0.1 main defaults
_openmp_mutex 5.1 1_gnu defaults
accelerate 0.21.0 pypi_0 pypi
blas 1.0 mkl defaults
brotlipy 0.7.0 py38h27cfd23_1003 defaults
bzip2 1.0.8 h7b6447c_0 defaults
ca-certificates 2023.05.30 h06a4308_0 defaults
certifi 2023.7.22 py38h06a4308_0 defaults
cffi 1.15.1 py38h5eee18b_3 defaults
charset-normalizer 2.0.4 pyhd3eb1b0_0 defaults
contourpy 1.1.0 pypi_0 pypi
cryptography 41.0.2 py38h22a60cf_0 defaults
cuda-cudart 11.7.99 0 nvidia
cuda-cupti 11.7.101 0 nvidia
cuda-libraries 11.7.1 0 nvidia
cuda-nvrtc 11.7.99 0 nvidia
cuda-nvtx 11.7.91 0 nvidia
cuda-runtime 11.7.1 0 nvidia
cycler 0.11.0 pypi_0 pypi
cython 3.0.0 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
filelock 3.12.2 pypi_0 pypi
fonttools 4.42.0 pypi_0 pypi
freetype 2.12.1 h4a9f257_0 defaults
fsspec 2023.6.0 pypi_0 pypi
future 0.18.3 py38h06a4308_0 defaults
giflib 5.2.1 h5eee18b_3 defaults
gmp 6.2.1 h295c915_3 defaults
gnutls 3.6.15 he1e5248_0 defaults
huggingface-hub 0.16.4 pypi_0 pypi
idna 3.4 py38h06a4308_0 defaults
importlib-resources 6.0.1 pypi_0 pypi
intel-openmp 2023.1.0 hdb19cb5_46305 defaults
jpeg 9e h5eee18b_1 defaults
kiwisolver 1.4.4 pypi_0 pypi
lame 3.100 h7b6447c_0 defaults
lcms2 2.12 h3be6417_0 defaults
ld_impl_linux-64 2.38 h1181459_1 defaults
lerc 3.0 h295c915_0 defaults
libcublas 11.10.3.66 0 nvidia
libcufft 10.7.2.124 h4fbf590_0 nvidia
libcufile 1.7.1.12 0 nvidia
libcurand 10.3.3.129 0 nvidia
libcusolver 11.4.0.1 0 nvidia
libcusparse 11.7.4.91 0 nvidia
libdeflate 1.17 h5eee18b_0 defaults
libffi 3.4.4 h6a678d5_0 defaults
libgcc-ng 11.2.0 h1234567_1 defaults
libgfortran-ng 11.2.0 h00389a5_1 defaults
libgfortran5 11.2.0 h1234567_1 defaults
libgomp 11.2.0 h1234567_1 defaults
libiconv 1.16 h7f8727e_2 defaults
libidn2 2.3.4 h5eee18b_0 defaults
libnpp 11.7.4.75 0 nvidia
libnvjpeg 11.8.0.2 0 nvidia
libopenblas 0.3.21 h043d6bf_0 defaults
libpng 1.6.39 h5eee18b_0 defaults
libprotobuf 3.20.3 he621ea3_0 defaults
libstdcxx-ng 11.2.0 h1234567_1 defaults
libtasn1 4.19.0 h5eee18b_0 defaults
libtiff 4.5.0 h6a678d5_2 defaults
libunistring 0.9.10 h27cfd23_0 defaults
libwebp 1.2.4 h11a3e52_1 defaults
libwebp-base 1.2.4 h5eee18b_1 defaults
lz4-c 1.9.4 h6a678d5_0 defaults
matplotlib 3.7.2 pypi_0 pypi
mkl 2023.1.0 h213fc3f_46343 defaults
mkl-service 2.4.0 py38h5eee18b_1 defaults
mkl_fft 1.3.6 py38h417a72b_1 defaults
mkl_random 1.2.2 py38h417a72b_1 defaults
ncurses 6.4 h6a678d5_0 defaults
nettle 3.7.3 hbbd107a_1 defaults
ninja 1.10.2 h06a4308_5 defaults
ninja-base 1.10.2 hd09550d_5 defaults
numpy 1.24.3 py38hf6e8229_1 defaults
numpy-base 1.24.3 py38h060ed82_1 defaults
openh264 2.1.1 h4ff587b_0 defaults
openssl 3.0.10 h7f8727e_0 defaults
packaging 23.1 pypi_0 pypi
pillow 9.4.0 py38h6a678d5_0 defaults
pip 23.2.1 py38h06a4308_0 defaults
psutil 5.9.5 pypi_0 pypi
pycocotools 2.0.6 pypi_0 pypi
pycparser 2.21 pyhd3eb1b0_0 defaults
pyopenssl 23.2.0 py38h06a4308_0 defaults
pyparsing 3.0.9 pypi_0 pypi
pysocks 1.7.1 py38h06a4308_0 defaults
python 3.8.17 h955ad1f_0 defaults
python-dateutil 2.8.2 pypi_0 pypi
pytorch 1.13.0 py3.8_cuda11.7_cudnn8.5.0_0 pytorch
pytorch-cuda 11.7 h778d358_5 pytorch
pytorch-mutex 1.0 cuda pytorch
pyyaml 6.0 py38h5eee18b_1 defaults
readline 8.2 h5eee18b_0 defaults
requests 2.31.0 py38h06a4308_0 defaults
safetensors 0.3.2 pypi_0 pypi
scipy 1.10.1 pypi_0 pypi
setuptools 68.0.0 py38h06a4308_0 defaults
six 1.16.0 pypi_0 pypi
sqlite 3.41.2 h5eee18b_0 defaults
tbb 2021.8.0 hdb19cb5_0 defaults
timm 0.9.5 pypi_0 pypi
tk 8.6.12 h1ccaba5_0 defaults
torchaudio 0.13.0 py38_cu117 pytorch
torchdistill 0.3.3 pypi_0 pypi
torchvision 0.14.0 py38_cu117 pytorch
tqdm 4.66.1 pypi_0 pypi
typing-extensions 4.7.1 py38h06a4308_0 defaults
typing_extensions 4.7.1 py38h06a4308_0 defaults
urllib3 1.26.16 py38h06a4308_0 defaults
wheel 0.38.4 py38h06a4308_0 defaults
xz 5.4.2 h5eee18b_0 defaults
yaml 0.2.5 h7b6447c_0 defaults
zipp 3.16.2 pypi_0 pypi
zlib 1.2.13 h5eee18b_0 defaults
zstd 1.5.5 hc292b87_0 defaults
Hi @jsrdcht
Since you made changes in code and did not share the actual code, I cannot confirm that this is a bug from torchdistill.
If you're trying to introduce new components and are still in the trial-and-error phase, please use Discussions instead and provide your modified code. You have also left this discussion unanswered for this topic.
From the error message, my guess is that you didn't pass accelerator when instantiating distillation_box or training_box, so the following lines were never executed inside distillation_box or training_box:
https://github.com/yoshitomo-matsubara/torchdistill/blob/v0.3.3/torchdistill/core/distillation.py#L306-L307
Here's my code. It's not for the reason you guessed; I debugged it and confirmed that line of code is executed.
import argparse
import datetime
import logging
import os
import time
import torch
import vit
import custom_loss
from accelerate import Accelerator, DistributedType
from torch import distributed as dist
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from torchdistill.common import file_util, yaml_util, module_util
from torchdistill.common.constant import def_logger
from torchdistill.common.main_util import is_main_process, init_distributed_mode, load_ckpt, save_ckpt, set_seed, setup_for_distributed
from torchdistill.core.distillation import get_distillation_box
from torchdistill.core.training import get_training_box
from torchdistill.datasets import util
from torchdistill.eval.classification import compute_accuracy
from torchdistill.misc.log import setup_log_file, SmoothedValue, MetricLogger
from torchdistill.models.official import get_image_classification_model
from torchdistill.models.registry import get_model
logger = def_logger.getChild(__name__)
def get_argparser():
parser = argparse.ArgumentParser(description='Knowledge distillation for image classification models')
parser.add_argument('--config', required=True, help='yaml file path')
parser.add_argument('--device', default='cuda', help='device')
parser.add_argument('--log', help='log file path')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
parser.add_argument('--seed', type=int, help='seed in random number generator')
parser.add_argument('-test_only', action='store_true', help='only test the models')
parser.add_argument('-student_only', action='store_true', help='test the student model only')
parser.add_argument('-log_config', action='store_true', help='log config')
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
parser.add_argument('-adjust_lr', action='store_true',
help='multiply learning rate by number of distributed processes (world_size)')
return parser
def load_model(model_config, device, distributed):
model = get_image_classification_model(model_config, distributed)
if model is None:
repo_or_dir = model_config.get('repo_or_dir', None)
model = get_model(model_config['name'], repo_or_dir, **model_config['params'])
ckpt_file_path = model_config['ckpt']
load_ckpt(ckpt_file_path, model=model, strict=True)
return model.to(device)
def train_one_epoch(training_box, device, epoch, log_freq):
metric_logger = MetricLogger(delimiter=' ')
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value}'))
metric_logger.add_meter('img/s', SmoothedValue(window_size=10, fmt='{value}'))
header = 'Epoch: [{}]'.format(epoch)
for sample_batch, targets, supp_dict in \
metric_logger.log_every(training_box.train_data_loader, log_freq, header):
start_time = time.time()
sample_batch, targets = sample_batch.to(device), targets.to(device)
loss = training_box(sample_batch, targets, supp_dict)
training_box.update_params(loss)
batch_size = sample_batch.shape[0]
metric_logger.update(loss=loss.item(), lr=training_box.optimizer.param_groups[0]['lr'])
metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
if (torch.isnan(loss) or torch.isinf(loss)) and is_main_process():
raise ValueError('The training loop was broken due to loss = {}'.format(loss))
@torch.inference_mode()
def evaluate(model, data_loader, device, device_ids, distributed, log_freq=1000, title=None, header='Test:'):
model.to(device)
if distributed:
model = DistributedDataParallel(model, device_ids=device_ids)
elif device.type.startswith('cuda'):
model = DataParallel(model, device_ids=device_ids)
if title is not None:
logger.info(title)
model.eval()
metric_logger = MetricLogger(delimiter=' ')
for image, target in metric_logger.log_every(data_loader, log_freq, header):
image = image.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
output = model(image)
acc1, acc5 = compute_accuracy(output, target, topk=(1, 5))
# FIXME need to take into account that the datasets
# could have been padded in distributed setup
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
top1_accuracy = metric_logger.acc1.global_avg
top5_accuracy = metric_logger.acc5.global_avg
logger.info(' * Acc@1 {:.4f}\tAcc@5 {:.4f}\n'.format(top1_accuracy, top5_accuracy))
return metric_logger.acc1.global_avg
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator):
logger.info('Start training')
train_config = config['train']
lr_factor = args.world_size if distributed and args.adjust_lr else 1
training_box = get_training_box(student_model, dataset_dict, train_config,
device, device_ids, distributed, lr_factor, accelerator) if teacher_model is None \
else get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
device, device_ids, distributed, lr_factor, accelerator=accelerator)
best_val_top1_accuracy = 0.0
optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
if file_util.check_if_exists(ckpt_file_path):
best_val_top1_accuracy, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)
log_freq = train_config['log_freq']
student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
start_time = time.time()
for epoch in range(args.start_epoch, training_box.num_epochs):
training_box.pre_process(epoch=epoch)
train_one_epoch(training_box, device, epoch, log_freq)
val_top1_accuracy = evaluate(student_model, training_box.val_data_loader, device, device_ids, distributed,
log_freq=log_freq, header='Validation:')
if val_top1_accuracy > best_val_top1_accuracy and is_main_process():
logger.info('Best top-1 accuracy: {:.4f} -> {:.4f}'.format(best_val_top1_accuracy, val_top1_accuracy))
logger.info('Updating ckpt at {}'.format(ckpt_file_path))
best_val_top1_accuracy = val_top1_accuracy
if distributed is False and accelerator is not None:
student_model_without_ddp = accelerator.unwrap_model(student_model)
save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
best_val_top1_accuracy, config, args, ckpt_file_path)
training_box.post_process()
if distributed:
dist.barrier()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('Training time {}'.format(total_time_str))
training_box.clean_modules()
def main(args):
log_file_path = args.log
if is_main_process() and log_file_path is not None:
setup_log_file(os.path.expanduser(log_file_path))
logger.info(args)
cudnn.benchmark = True
set_seed(args.seed)
config = yaml_util.load_yaml_file(os.path.expanduser(args.config))
# distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator(mixed_precision='fp16')
distributed = accelerator.state.distributed_type == DistributedType.MULTI_GPU
device_ids = [accelerator.device.index]
if distributed:
setup_for_distributed(is_main_process())
logger.info(accelerator.state)
device = accelerator.device
# Setup logging, we only want one process per machine to log things on the screen.
# accelerator.is_local_main_process is only True for one process per machine.
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
# device = torch.device(args.device)
dataset_dict = util.get_all_datasets(config['datasets'])
models_config = config['models']
teacher_model_config = models_config.get('teacher_model', None)
teacher_model =\
load_model(teacher_model_config, device, distributed) if teacher_model_config is not None else None
student_model_config =\
models_config['student_model'] if 'student_model' in models_config else models_config['model']
ckpt_file_path = student_model_config['ckpt']
student_model = load_model(student_model_config, device, distributed)
if accelerator is not None:
student_model, teacher_model = accelerator.prepare(student_model, teacher_model)
for name, dataset in dataset_dict.items():
dataset_dict[name] = accelerator.prepare(dataset)
if args.log_config:
logger.info(config)
if not args.test_only:
train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
student_model_without_ddp =\
student_model.module if module_util.check_if_wrapped(student_model) else student_model
load_ckpt(student_model_config['ckpt'], model=student_model_without_ddp, strict=True)
test_config = config['test']
test_data_loader_config = test_config['test_data_loader']
test_data_loader = util.build_data_loader(dataset_dict[test_data_loader_config['dataset_id']],
test_data_loader_config, distributed)
log_freq = test_config.get('log_freq', 1000)
if not args.student_only and teacher_model is not None:
evaluate(teacher_model, test_data_loader, device, device_ids, distributed, log_freq=log_freq,
title='[Teacher: {}]'.format(teacher_model_config['name']))
evaluate(student_model, test_data_loader, device, device_ids, distributed, log_freq=log_freq,
title='[Student: {}]'.format(student_model_config['name']))
if __name__ == '__main__':
argparser = get_argparser()
main(argparser.parse_args())
module_wise_params:
- params: ['mask_token', 'cls_token', 'pos_embed']
is_teacher: None
module: None
weight_decay: 0.0
Your module_wise_params entry looks broken. Use is_teacher: False instead of None (or omit is_teacher entirely, as the default value is False).
See https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/legacy/sample/pascal_voc2012/ce/deeplabv3_resnet101.yaml#L102-L109 for the format
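For reference, a corrected entry might look like the following (a sketch only; the module paths mask_token, cls_token, and pos_embed are taken from your broken entry and are assumed to be attribute paths that get_module can resolve in your student model):

```yaml
module_wise_params:
  - module: 'mask_token'   # one entry per resolvable module/parameter path
    is_teacher: False      # or omit; False is the default
    params:
      weight_decay: 0.0
  - module: 'cls_token'
    params:
      weight_decay: 0.0
  - module: 'pos_embed'
    params:
      weight_decay: 0.0
```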
Again, please use Discussions for this kind of question since it doesn't look like a bug from torchdistill
The issues you marked were indeed causing the strange errors. I also made some changes in the source code to adapt it to my configuration.
Here are some problems that exist in the source code:
def pre_process(self, epoch=None, **kwargs):
clear_io_dict(self.teacher_io_dict)
clear_io_dict(self.student_io_dict)
self.teacher_model.eval()
self.student_model.train()
if self.distributed and self.accelerator is None: # batch_sampler.sampler is only valid for ddp without accelerator
self.train_data_loader.batch_sampler.sampler.set_epoch(epoch)
# Set up accelerator if necessary
if self.accelerator is not None:
if self.teacher_updatable:
self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
self.train_data_loader, self.val_data_loader)
else:
# self.teacher_model = self.teacher_model.to(self.accelerator.device)
# if self.accelerator.state.use_fp16:
# self.teacher_model = self.teacher_model.half()
# since fp16 is handled by accelerate, we have to wrap the teacher model; otherwise the input can't be cast to fp16
self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
self.train_data_loader, self.val_data_loader)
module_wise_params_configs = optim_config.get('module_wise_params', list())
if len(module_wise_params_configs) > 0:
trainable_module_list = list()
for module_wise_params_config in module_wise_params_configs:
module_wise_params_dict = dict()
if isinstance(module_wise_params_config.get('params', None), dict):
module_wise_params_dict.update(module_wise_params_config['params'])
if 'lr' in module_wise_params_dict:
module_wise_params_dict['lr'] *= self.lr_factor
target_model = \
self.teacher_model if module_wise_params_config.get('is_teacher', False) else self.student_model
module = get_module(target_model, module_wise_params_config['module'])
# support for nn.Parameter()
module_wise_params_dict['params'] = module.parameters() if isinstance(module, nn.Module) else [module]
trainable_module_list.append(module_wise_params_dict)
else:
trainable_module_list = nn.ModuleList([self.student_model])
if self.teacher_updatable:
logger.info('Note that you are training some/all of the modules in the teacher model')
trainable_module_list.append(self.teacher_model)
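To see how the broken YAML entry flows through the loop above, here is a pure-Python rehearsal of those branches with the values from the config (plain dicts stand in for the config objects; no torchdistill imports):

```python
# The broken YAML entry: `params` is a list (not a dict), and
# `is_teacher`/`module` are None instead of a bool and a path string.
cfg = {"params": ["mask_token", "cls_token", "pos_embed"],
       "is_teacher": None, "module": None}

module_wise_params_dict = {}
# The isinstance check only merges `params` when it is a dict,
# so the list of names is silently dropped:
if isinstance(cfg.get("params", None), dict):
    module_wise_params_dict.update(cfg["params"])
assert module_wise_params_dict == {}

# `is_teacher: None` is falsy, so the non-teacher branch is taken, and
# `module: None` cannot be resolved as a path -- matching the warning in
# the log: "`None` of `None` could not be reached"
assert not cfg.get("is_teacher", False)
assert cfg["module"] is None
```

The net effect is an optimizer param group built from an unresolvable module, which lines up with the "No inf checks were recorded" assertion at step time.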