open-mmlab/mmengine

activation_checkpointing prevents weights from being updated

Closed this issue · 1 comment

Prerequisite

Environment

OrderedDict([('sys.platform', 'linux'), ('Python', '3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51) [GCC 9.4.0]'), ('CUDA available', True), ('numpy_random_seed', 2147483648), ('GPU 0,1', 'NVIDIA A100-SXM4-80GB'), ('CUDA_HOME', '/usr/local/cuda'), ('NVCC', 'Cuda compilation tools, release 11.6, V11.6.55'), ('GCC', 'gcc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0'), ('PyTorch', '1.13.1'), ('PyTorch compiling details', 'PyTorch built with:\n - GCC 9.3\n - C++ Version: 201402\n - Intel(R) Math Kernel Library Version 2019.0.5 Product Build 20190808 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)\n - OpenMP 201511 (a.k.a. OpenMP 4.5)\n - LAPACK is enabled (usually provided by MKL)\n - NNPACK is enabled\n - CPU capability usage: AVX2\n - CUDA Runtime 11.6\n - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37\n - CuDNN 8.3.2 (built against CUDA 11.5)\n - Magma 2.6.1\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.6, CUDNN_VERSION=8.3.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations 
-Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.13.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n'), ('TorchVision', '0.14.1'), ('OpenCV', '4.6.0'), ('MMEngine', '0.9.1')])

Reproduces the problem - code sample

activation_checkpointing = ['img_backbone']
# num_views, point_cloud_range and voxel_size are defined elsewhere in the
# original config (not shown in the report).
model = dict(
    type='BEVC',
    num_views=num_views,
    stop_extra_grad=False,
    data_preprocessor=dict(
        type='Det4DDataPreprocessor',
        mean=[103.530, 116.280, 123.675],
        std=[57.375, 57.120, 58.395],
        voxel=True,
        voxel_type='dynamic',
        voxel_layer=dict(
            max_num_points=-1,
            point_cloud_range=point_cloud_range,
            voxel_size=voxel_size,
            max_voxels=(-1, -1),
        ),
    ),
    img_backbone=dict(
        type='ResNet',
        scope='mmdet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,  # freezes the stem and layer1
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        with_cp=True,  # ResNet's own per-block checkpointing option
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
    ),
    # ... remaining BEVC fields were not included in the report
)

Reproduces the problem - command or script

no

Reproduces the problem - error message

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by making sure all forward function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).
Parameters which did not receive grad for rank 1: img_backbone.layer4.2.bn3.bias, img_backbone.layer4.2.bn3.weight, img_backbone.layer4.2.conv3.weight, img_backbone.layer4.2.bn2.bias, img_backbone.layer4.2.bn2.weight, img_backbone.layer4.2.conv2.weight, img_backbone.layer4.2.bn1.bias, img_backbone.layer4.2.bn1.weight, img_backbone.layer4.2.conv1.weight, img_backbone.layer4.1.bn3.bias, img_backbone.layer4.1.bn3.weight, img_backbone.layer4.1.conv3.weight, img_backbone.layer4.1.bn2.bias, img_backbone.layer4.1.bn2.weight, img_backbone.layer4.1.conv2.weight, img_backbone.layer4.1.bn1.bias, img_backbone.layer4.1.bn1.weight, img_backbone.layer4.1.conv1.weight, img_backbone.layer4.0.downsample.1.bias, img_backbone.layer4.0.downsample.1.weight, img_backbone.layer4.0.downsample.0.weight, img_backbone.layer4.0.bn3.bias, img_backbone.layer4.0.bn3.weight, img_backbone.layer4.0.conv3.weight, img_backbone.layer4.0.bn2.bias, img_backbone.layer4.0.bn2.weight, img_backbone.layer4.0.conv2.weight, img_backbone.layer4.0.bn1.bias, img_backbone.layer4.0.bn1.weight, img_backbone.layer4.0.conv1.weight, img_backbone.layer3.5.bn3.bias, img_backbone.layer3.5.bn3.weight, img_backbone.layer3.5.conv3.weight, img_backbone.layer3.5.bn2.bias, img_backbone.layer3.5.bn2.weight, img_backbone.layer3.5.conv2.weight, img_backbone.layer3.5.bn1.bias, img_backbone.layer3.5.bn1.weight, img_backbone.layer3.5.conv1.weight, img_backbone.layer3.4.bn3.bias, img_backbone.layer3.4.bn3.weight, img_backbone.layer3.4.conv3.weight, img_backbone.layer3.4.bn2.bias, img_backbone.layer3.4.bn2.weight, img_backbone.layer3.4.conv2.weight, img_backbone.layer3.4.bn1.bias, img_backbone.layer3.4.bn1.weight, img_backbone.layer3.4.conv1.weight, img_backbone.layer3.3.bn3.bias, img_backbone.layer2.0.bn3.bias, img_backbone.layer2.0.bn3.weight, img_backbone.layer2.0.conv3.weight, img_backbone.layer2.0.bn2.bias, img_backbone.layer2.0.bn2.weight, img_backbone.layer2.0.conv2.weight, img_backbone.layer2.0.bn1.bias, 
img_backbone.layer2.0.bn1.weight, img_backbone.layer2.0.conv1.weight, img_backbone.layer2.0.downsample.0.weight, img_backbone.layer2.0.downsample.1.weight, img_backbone.layer2.0.downsample.1.bias, img_backbone.layer2.1.conv1.weight, img_backbone.layer2.1.bn1.weight, img_backbone.layer2.1.bn1.bias, img_backbone.layer2.1.conv2.weight, img_backbone.layer2.1.bn2.weight, img_backbone.layer2.1.bn2.bias, img_backbone.layer2.1.conv3.weight, img_backbone.layer2.1.bn3.weight, img_backbone.layer2.1.bn3.bias, img_backbone.layer2.2.conv1.weight, img_backbone.layer2.2.bn1.weight, img_backbone.layer2.2.bn1.bias, img_backbone.layer2.2.conv2.weight, img_backbone.layer2.2.bn2.weight, img_backbone.layer2.2.bn2.bias, img_backbone.layer2.2.conv3.weight, img_backbone.layer2.2.bn3.weight, img_backbone.layer2.2.bn3.bias, img_backbone.layer2.3.conv1.weight, img_backbone.layer2.3.bn1.weight, img_backbone.layer2.3.bn1.bias, img_backbone.layer2.3.conv2.weight, img_backbone.layer2.3.bn2.weight, img_backbone.layer2.3.bn2.bias, img_backbone.layer2.3.conv3.weight, img_backbone.layer2.3.bn3.weight, img_backbone.layer2.3.bn3.bias, img_backbone.layer3.0.conv1.weight, img_backbone.layer3.0.bn1.weight, img_backbone.layer3.0.bn1.bias, img_backbone.layer3.0.conv2.weight, img_backbone.layer3.0.bn2.weight, img_backbone.layer3.0.bn2.bias, img_backbone.layer3.0.conv3.weight, img_backbone.layer3.0.bn3.weight, img_backbone.layer3.0.bn3.bias, img_backbone.layer3.0.downsample.0.weight, img_backbone.layer3.0.downsample.1.weight, img_backbone.layer3.0.downsample.1.bias, img_backbone.layer3.1.conv1.weight, img_backbone.layer3.1.bn1.weight, img_backbone.layer3.1.bn1.bias, img_backbone.layer3.1.conv2.weight, img_backbone.layer3.1.bn2.weight, img_backbone.layer3.1.bn2.bias, img_backbone.layer3.1.conv3.weight, img_backbone.layer3.1.bn3.weight, img_backbone.layer3.1.bn3.bias, img_backbone.layer3.2.conv1.weight, img_backbone.layer3.2.bn1.weight, img_backbone.layer3.2.bn1.bias, img_backbone.layer3.2.conv2.weight, 
img_backbone.layer3.2.bn2.weight, img_backbone.layer3.2.bn2.bias, img_backbone.layer3.2.conv3.weight, img_backbone.layer3.2.bn3.weight, img_backbone.layer3.2.bn3.bias, img_backbone.layer3.3.conv1.weight, img_backbone.layer3.3.bn1.weight, img_backbone.layer3.3.bn1.bias, img_backbone.layer3.3.conv2.weight, img_backbone.layer3.3.bn2.weight, img_backbone.layer3.3.bn2.bias, img_backbone.layer3.3.conv3.weight, img_backbone.layer3.3.bn3.weight
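The parameter list above is reported by DDP. Roughly the same information can be gathered in a single process by running one backward pass and collecting trainable parameters whose `.grad` is still `None`, which also shows directly whether the backbone weights can be updated at all. This is a minimal sketch: `find_params_without_grad` and the toy `TwoBranch` model are hypothetical, and on the real model the helper would be called right after `loss.backward()`, starting from gradients that are `None`.

```python
import torch
import torch.nn as nn


def find_params_without_grad(model: nn.Module):
    """Names of trainable parameters that received no gradient.

    Assumes gradients were None before the backward pass (e.g. a fresh model
    or zero_grad(set_to_none=True)); stale gradients would hide the problem.
    """
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]


class TwoBranch(nn.Module):
    """Hypothetical toy model: `unused` never contributes to the output."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(3, 1)
        self.unused = nn.Linear(3, 1)

    def forward(self, x):
        return self.used(x)


model = TwoBranch()
model(torch.randn(2, 3)).sum().backward()
print(find_params_without_grad(model))  # ['unused.weight', 'unused.bias']
```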

Additional information

Without “activation_checkpointing=['img_backbone']” everything works normally; after adding it, the error above appears.
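One plausible explanation, offered as an assumption rather than a confirmed diagnosis: if MMEngine applies activation checkpointing by wrapping the module's `forward` with `torch.utils.checkpoint.checkpoint` in its default reentrant mode (the default in PyTorch 1.13), then a checkpointed module whose tensor inputs do not require grad (raw input images never do) returns an output that is disconnected from its parameters, so those parameters receive no gradient. The toy sketch below reproduces that effect in plain PyTorch. `TinyNet`, `wrap_with_checkpoint` and `backbone_gets_grad` are hypothetical names, and `use_reentrant=False` is shown only for comparison.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyNet(nn.Module):
    """Hypothetical toy model: a 'backbone' layer followed by a 'head' layer."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head = nn.Linear(4, 1)

    def forward(self, x):
        return self.head(self.backbone(x))


def wrap_with_checkpoint(module, use_reentrant):
    """Hypothetical helper: replace module.forward with a checkpointed call."""
    original_forward = module.forward

    def checkpointed_forward(*args):
        return checkpoint(original_forward, *args, use_reentrant=use_reentrant)

    module.forward = checkpointed_forward


def backbone_gets_grad(use_reentrant):
    """Run one forward/backward pass; report whether the backbone got a gradient."""
    torch.manual_seed(0)
    net = TinyNet()
    wrap_with_checkpoint(net.backbone, use_reentrant)
    x = torch.randn(2, 4)  # like input images, x does not require grad
    loss = net(x).sum()  # still requires grad through the head's parameters
    loss.backward()
    return net.backbone.weight.grad is not None


print(backbone_gets_grad(use_reentrant=True))   # False (PyTorch also warns)
print(backbone_gets_grad(use_reentrant=False))  # True
```

Under this assumption, the loss still has a gradient path through the head, so training does not stop, but the checkpointed backbone silently stops learning, which matches both the DDP error and the title of this issue.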

When using activation_checkpointing, specify individual submodules, e.g. “activation_checkpointing=['img_backbone.layer2.0', 'img_backbone.layer4.0']”, rather than wrapping an entire module such as img_backbone.
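If many blocks should be checkpointed, the module names can be generated instead of typed by hand and the resulting list pasted into the config as activation_checkpointing. This is a minimal sketch assuming the standard ResNet-50 layout of (3, 4, 6, 3) bottleneck blocks in layer1 to layer4 and the img_backbone prefix from the config above; `bottleneck_names` is a hypothetical helper. layer3 and layer4 are picked here only as an example: with frozen_stages=1 their inputs come from the trainable layer2, whereas the input to layer2.0 comes from frozen layers, so whichever modules are chosen, a gradient check after one backward pass is a cheap way to confirm they still train.

```python
# Blocks per stage for ResNet-50 (layer1..layer4).
RESNET50_STAGE_BLOCKS = (3, 4, 6, 3)


def bottleneck_names(prefix, stages, stage_blocks=RESNET50_STAGE_BLOCKS):
    """Return dotted module names such as 'img_backbone.layer3.0'."""
    return [
        f'{prefix}.layer{stage}.{block}'
        for stage in stages
        for block in range(stage_blocks[stage - 1])
    ]


names = bottleneck_names('img_backbone', stages=(3, 4))
print(names)
```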