ZhengPeng7/BiRefNet

RuntimeError: CUDA error: device-side assert triggered

Closed this issue · 8 comments

When I train on my own dataset, it runs for some epochs and is then interrupted by the error in the title.

How many epochs? More than one?
Besides, this error most likely occurs in the BCE loss here. You can print the inputs of the BCE losses to check them.

About 20 to 30 epochs.

Then you may want to check the inputs of the BCE losses and the GT data. Make sure they are all between 0 and 1.
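A minimal sketch of such a check, assuming PyTorch tensors. The helper name `check_unit_range` and the tensor names in the usage note are hypothetical, not part of BiRefNet. The idea is to call it on both arguments right before the BCE loss, so that an out-of-range value shows up as a readable Python error instead of a device-side assert.

```python
import torch


def check_unit_range(name: str, t: torch.Tensor) -> None:
    """Raise ValueError if `t` contains NaN/Inf or values outside [0, 1].

    Hypothetical debugging helper; assumes `t` is a non-empty float tensor.
    """
    if not torch.isfinite(t).all():
        raise ValueError(f"{name} contains NaN or Inf")
    lo, hi = t.min().item(), t.max().item()
    if lo < 0.0 or hi > 1.0:
        raise ValueError(f"{name} out of [0, 1]: min={lo:.4f}, max={hi:.4f}")
```

For example, `check_unit_range("gdt_pred", _gdt_pred)` and `check_unit_range("gdt_label", _gdt_label)` could be placed just before the loss call while debugging, and removed afterwards, since `.item()` forces a GPU sync on every call.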

Hi, have you found the problem? Is it caused by something in my code?

@ZhengPeng7 I got the same error message when training with a custom private dataset (~1500 images).

Below is the log output and traceback. Do you have any idea?

My environment

  • OS: Ubuntu 20.04
  • GPU: RTX 3090
  • python 3.10.14
  • python packages
    • torch==2.0.1+cu118
    • torchvision==0.15.2+cu118

The command I ran is CUDA_LAUNCH_BLOCKING=1 ./train.sh
Because CUDA was set to blocking mode, the traceback should be meaningful (it should point to where the error actually happened).
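If the script is launched from a notebook or an IDE rather than a shell, the same variable can be set from Python. This is only a sketch: it assumes the assignment runs before CUDA is initialized, and placing it before `import torch` is the conservative way to ensure that.

```python
import os

# Set this before CUDA is initialized; doing it before importing torch
# is the safest ordering.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402  (deliberately imported after the variable is set)

print("CUDA_LAUNCH_BLOCKING =", os.environ["CUDA_LAUNCH_BLOCKING"])
```

Blocking launches slow training down, so this is meant for debugging runs only.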

...
2024-06-26 20:25:25,932 INFO Epoch[22/30] Iter[420/754]. Training Losses, loss_pix: 10.148
2024-06-26 20:25:40,913 INFO Epoch[22/30] Iter[440/754]. Training Losses, loss_pix: 12.462
...
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [79,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [80,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [81,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [82,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [83,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [84,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [85,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [86,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [87,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [88,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [89,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [90,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [91,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [92,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [93,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [94,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [95,0,0] Assertion `input_val >= zero && input_val <= one` failed.
../aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [42,0,0], thread: [95,0,0] Assertion `input_val >= zero && input_val <= one` failed.
Traceback (most recent call last):
  File "BiRefNet/train.py", line 427, in <module>
    main()
  File "BiRefNet/train.py", line 404, in main
    train_loss = trainer.train_epoch(epoch)
  File "BiRefNet/train.py", line 329, in train_epoch
    self._train_batch(batch)
  File "BiRefNet/train.py", line 222, in _train_batch
    loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt
  File "/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 619, in forward
    return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  File "/lib/python3.10/site-packages/torch/nn/functional.py", line 3098, in binary_cross_entropy
    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Dump of configuration values from config.py

{
    "IoU_finetune_last_epochs": -20,
    "SDPA_enabled": false,
    "allow_tf32": true,
    "auxiliary_classification": false,
    "batch_size": 2,
    "batch_size_valid": 1,
    "bb": "swin_v1_t",
    "compile": true,
    "compile_mode": "default",
    "cudnn_benchmark": false,
    "cxt": [
        192,
        384,
        768
    ],
    "cxt_num": 3,
    "data_root_dir": "PIIs",
    "dec_att": "ASPPDeformable",
    "dec_blk": "BasicDecBlk",
    "dec_channels_inter": "fixed",
    "dec_ipt": true,
    "dec_ipt_split": true,
    "device": 0,
    "enable_clearml": true,
    "ender": "",
    "freeze_bb": false,
    "lambda_adv_d": 0.0,
    "lambda_adv_g": 0.0,
    "lambdas_cls": {
        "ce": 5.0
    },
    "lambdas_pix_last": {
        "bce": 30,
        "cnt": 0,
        "iou": 0.5,
        "iou_patch": 0.0,
        "mse": 0,
        "reg": 0,
        "ssim": 10,
        "structure": 0,
        "triplet": 0
    },
    "lat_blk": "BasicLatBlk",
    "lateral_channels_in_collection": [
        1536,
        768,
        384,
        192
    ],
    "load_all": true,
    "lr": 7.071067811865476e-06,
    "lr_decay_epochs": [
        100000.0
    ],
    "lr_decay_rate": 0.5,
    "model": "BiRefNet",
    "ms_supervision": true,
    "mul_scl_ipt": "cat",
    "num_workers": 8,
    "only_S_MAE": false,
    "optimizer": "AdamW",
    "out_ref": true,
    "precisionHigh": true,
    "preproc_methods": [
        "flip",
        "enhance",
        "rotate",
        "pepper"
    ],
    "progressive_ref": "",
    "prompt4loc": false,
    "rand_seed": 7,
    "refine": "",
    "refine_iteration": 1,
    "save_last": 10,
    "save_step": 2,
    "scale": "",
    "size": 1024,
    "squeeze_block": "BasicDecBlk_x1",
    "sys_home_dir": "PIIs",
    "task": "fingerprint",
    "training_set": "synthetic_v2_train",
    "use_fp16": false,
    "val_step": 2,
    "verbose_eval": true,
    "weights": {
        "pvt_v2_b0": "PIIs/weights/pvt_v2_b0.pth",
        "pvt_v2_b1": "PIIs/weights/pvt_v2_b1.pth",
        "pvt_v2_b2": "PIIs/weights/pvt_v2_b2.pth",
        "pvt_v2_b5": "PIIs/weights/pvt_v2_b5.pth",
        "swin_v1_b": "PIIs/weights/swin_base_patch4_window12_384_22kto1k.pth",
        "swin_v1_l": "PIIs/weights/swin_large_patch4_window12_384_22kto1k.pth",
        "swin_v1_s": "/PIIs/weights/swin_small_patch4_window7_224_22kto1k_finetune.pth",
        "swin_v1_t": "/PIIs/weights/swin_tiny_patch4_window7_224_22kto1k_finetune.pth"
    },
    "weights_root_dir": "PIIs"
}

@ZhengPeng7 Update:

After a closer look at the code, I think it's my bad!

I refactored the code to use the same block for training whether or not amp fp16 is used. However, after the refactoring I forgot to apply _gdt_label = _gdt_label.sigmoid() before calculating the BCE loss when not using fp16.

Your code correctly applies sigmoid() before calculating the loss.
So it's my fault: I introduced the error during refactoring.
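To illustrate why a missing sigmoid() triggers this assertion, here is a small CPU-only sketch with stand-in tensors (`logits` and `target` are made-up names and shapes, not the actual BiRefNet tensors). The assertion in the log above is on `input_val`, that is, on the first argument of `binary_cross_entropy`. On CPU the same violation surfaces as an ordinary RuntimeError, which is easier to debug.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 1, 8, 8)   # stand-in for a raw, unbounded model output
target = torch.rand(2, 1, 8, 8)    # stand-in for a label already in [0, 1]

# Raw logits contain values outside [0, 1], so plain BCE rejects them.
try:
    F.binary_cross_entropy(logits, target)
    raw_ok = True
except RuntimeError:
    raw_ok = False

# Squashing the input with sigmoid() first keeps it inside [0, 1].
loss_prob = F.binary_cross_entropy(logits.sigmoid(), target)

# Alternative: the fused variant takes raw logits and applies sigmoid internally.
loss_logit = F.binary_cross_entropy_with_logits(logits, target)

print(raw_ok, loss_prob.item(), loss_logit.item())
```

The two losses agree up to floating-point error. `binary_cross_entropy_with_logits` is mentioned only as a general PyTorch option; whether it fits a particular training loop depends on how the predictions are used elsewhere.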

I'll rerun the training and let you know whether the error disappears.

@ZhengPeng7 The training completed successfully 🎉 . I've confirmed that the error has disappeared.

Aha, that's good! I just woke up and saw your message. Feel free to open more issues if you run into other problems.