Some questions about cls_weight in rein_dinov2_mask2former.py
Dear author,
Thanks for sharing your code. However, when I trained the head on the REFUGE2 dataset, I ran into a problem here:
class_weight=[1.0] * num_classes + [0.1]
I noticed you set num_classes = 19, but I don't understand why class_weight needs to be a list of length 19 + 1.
I also set reduce_zero_label = True for the REFUGE2 dataset with classes=('background', 'Optic Cup', 'Optic Disc'), and set num_classes = 2 for the model, but my segmentation masks only contain 'background' and 'Optic Cup'.
When I set num_classes = 3 and class_weight = [0.1, 1, 2], I get an index-out-of-bounds error. I checked my labels:
class_weight: tensor([0.1000, 1.0000, 2.0000], device='cuda:0')
label:tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
I don't know why my labels contain the value 3; I had already converted all pixels to values 0 to 2 before training.
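The 3s in the printed tensor are most likely not pixel values. A minimal sketch, assuming the head follows the mmdet Mask2FormerHead convention: the classification loss is computed per query, not per pixel, and every query that is not matched to a ground-truth mask gets the extra label num_classes ("no object"). With num_classes = 3 that label is 3, so class_weight needs num_classes + 1 entries, and a 3-entry list fails when index 3 is looked up. (The tensor above has 200 entries, which would fit, for example, 2 images with 100 queries each; batch size and query count are assumptions.) The helper names below are hypothetical, and plain Python lists stand in for tensors.

```python
from typing import Dict, List


def query_labels(num_queries: int, num_classes: int,
                 matched: Dict[int, int]) -> List[int]:
    """Per-query class targets: matched queries take their ground-truth
    class; every other query gets the extra "no object" label num_classes."""
    labels = [num_classes] * num_queries
    for query_idx, gt_class in matched.items():
        labels[query_idx] = gt_class
    return labels


def per_query_weights(labels: List[int], class_weight: List[float]) -> List[float]:
    """Look up the loss weight of each query's target class, the way a
    weighted cross-entropy indexes its weight vector."""
    return [class_weight[label] for label in labels]


num_classes = 3
# Hypothetical matching: queries 1, 4 and 7 matched to classes 0, 1 and 2.
labels = query_labels(num_queries=10, num_classes=num_classes,
                      matched={1: 0, 4: 1, 7: 2})
print(labels)  # [3, 0, 3, 3, 1, 3, 3, 2, 3, 3]

try:
    per_query_weights(labels, [0.1, 1.0, 2.0])  # only num_classes entries
except IndexError as err:
    print("IndexError:", err)  # label 3 has no weight entry

# One extra entry for "no object", as in [1.0] * num_classes + [0.1]
weights = per_query_weights(labels, [1.0] * num_classes + [0.1])
print(weights)
```

In this reading, the first num_classes entries of class_weight are ordinary class weights and the last one down-weights the many unmatched queries.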
If you set reduce_zero_label=True, the code will perform the following actions:
- Replace all occurrences of 0 with 255.
- Replace 1 with 0, 2 with 1, and so on.
- When num_classes=2, the model will only predict classes 0 and 1.
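The remapping above can be sketched in plain Python. This is a simplified illustration, not the mmseg source: it assumes 255 is the ignore index and works on a flat list of pixel values, and the function name is hypothetical. The second part suggests why, with reduce_zero_label=True, num_classes=2 and classes=('background', 'Optic Cup', 'Optic Disc'), the masks seem to contain only 'background' and 'Optic Cup': after the shift, index 0 means Optic Cup and index 1 means Optic Disc, but if the names are still read from the three-entry tuple, those indices are displayed as 'background' and 'Optic Cup'.

```python
from typing import List, Sequence

IGNORE_INDEX = 255  # assumed ignore value


def reduce_zero_label(mask: Sequence[int], ignore_index: int = IGNORE_INDEX) -> List[int]:
    """Map 0 -> ignore_index and shift every other label down by one.
    Pixels that are already ignore_index stay ignored."""
    return [ignore_index if v in (0, ignore_index) else v - 1 for v in mask]


# REFUGE2-style labels: 0 = background, 1 = Optic Cup, 2 = Optic Disc
mask = [0, 1, 2, 2, 0, 255]
reduced = reduce_zero_label(mask)
print(reduced)  # [255, 0, 1, 1, 255, 255]

# With num_classes=2 the model only predicts indices 0 and 1.
# If the class names still start with 'background', index 0 is shown
# as 'background' even though it now stands for Optic Cup.
wrong_names = ('background', 'Optic Cup', 'Optic Disc')
right_names = ('Optic Cup', 'Optic Disc')
for idx in (0, 1):
    print(idx, wrong_names[idx], '->', right_names[idx])
```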
My recommended configuration is as follows:
- If 0 represents the background rather than unlabeled regions, set reduce_zero_label=False and num_classes=3.
- Otherwise, if 0 represents unlabeled regions, set reduce_zero_label=True and num_classes=2. Also, define CLASSES=['Optic Cup', 'Optic Disc'].
- In both cases, keep class_weight=[1.0] * num_classes + [0.1].
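The two recommended setups can be summarized with a small helper. This is only an illustration: the function refuge2_settings and the returned dict are hypothetical and are not an mmseg or Rein API; in a real config these values are spread across the dataset settings and the decode head. Assuming the mmdet Mask2Former convention, the trailing 0.1 in class_weight is the weight of the extra "no object" class assigned to unmatched queries, which is why the list has num_classes + 1 entries in both setups.

```python
from typing import Dict, List, Tuple, Union

Setting = Union[bool, int, Tuple[str, ...], List[float]]


def refuge2_settings(zero_is_background: bool) -> Dict[str, Setting]:
    """Return the recommended settings for REFUGE2-style labels
    (0 = background or unlabeled, 1 = Optic Cup, 2 = Optic Disc)."""
    if zero_is_background:
        # 0 is a real class: keep it and predict all three classes.
        reduce_zero_label = False
        classes = ('background', 'Optic Cup', 'Optic Disc')
    else:
        # 0 is unlabeled: ignore it and shift 1, 2 down to 0, 1.
        reduce_zero_label = True
        classes = ('Optic Cup', 'Optic Disc')
    num_classes = len(classes)
    return {
        'reduce_zero_label': reduce_zero_label,
        'num_classes': num_classes,
        'classes': classes,
        # One weight per class plus one for "no object".
        'class_weight': [1.0] * num_classes + [0.1],
    }


for flag in (True, False):
    print(refuge2_settings(flag))
```

Deriving num_classes from the class-name tuple keeps the names, the number of predicted classes and the length of class_weight consistent with each other.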