w1oves/Rein

Some questions about cls_weight in rein_dinov2_mask2former.py


Dear author,
Thanks for sharing your code. When I trained the head on the REFUGE2 dataset, I ran into a problem with this line:
class_weight=[1.0] * num_classes + [0.1]
I noticed you set num_classes = 19, but I don't understand why class_weight needs to be a list of length 19 + 1.

I set reduce_zero_label = True for the REFUGE2 dataset, with classes=('background', 'Optic Cup', 'Optic Disc'), and num_classes = 2 for the model, but I only get 'background' and 'Optic Cup' in my segmentation mask.
When I set num_classes = 3 and class_weight = [0.1, 1, 2], I get an index-out-of-bounds error. I checked my labels:
class_weight: tensor([0.1000, 1.0000, 2.0000], device='cuda:0')
label:tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
I don't know why my label contains the value 3. I converted all pixels to the range 0~2 before training.

If you set reduce_zero_label=True, the code will perform the following actions:

  1. Replace all occurrences of 0 with 255.
  2. Replace 1 with 0, 2 with 1, and so on.
  3. With num_classes=2, the model then predicts only classes 0 and 1, which correspond to the original labels 1 and 2.
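The remapping described above can be sketched in a few lines of NumPy. This is only an illustration of the three steps, not the repository's actual loading code. The function name reduce_zero_label and the ignore_index value of 255 are assumptions taken from the description in this thread.

```python
import numpy as np


def reduce_zero_label(seg_map, ignore_index=255):
    """Illustrative remap: 0 -> ignore_index, then shift every other label down by 1."""
    seg = seg_map.astype(np.uint8).copy()
    # Step 1: the old label 0 becomes the ignore value.
    seg[seg == 0] = ignore_index
    # Step 2: shift the remaining labels down (1 -> 0, 2 -> 1, ...).
    seg = seg - 1
    # The ignore value was shifted to 254 as well, so restore it.
    seg[seg == ignore_index - 1] = ignore_index
    return seg


# Hypothetical REFUGE2-style mask: 0 = background, 1 = Optic Cup, 2 = Optic Disc.
mask = np.array([[0, 1, 2], [2, 0, 1]], dtype=np.uint8)
reduced = reduce_zero_label(mask)
print(reduced.tolist())  # [[255, 0, 1], [1, 255, 0]]
```

After the remap only labels 0 and 1 remain as trainable classes, which is why num_classes=2 goes together with reduce_zero_label=True and why the old background is ignored rather than predicted.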

My recommended configuration is as follows:

  • If 0 represents the background and not unlabeled regions, you should set reduce_zero_label=False and num_classes=3.
  • Otherwise, if 0 represents unlabeled regions, you should set reduce_zero_label=True and num_classes=2. Also, define CLASSES=['Optic Cup', 'Optic Disc'].
  • In both cases, keep class_weight=[1.0] * num_classes + [0.1].
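A likely reason for the extra entry, sketched below in plain Python: in Mask2Former-style heads the classification target is assigned per query, and queries that are not matched to any ground-truth segment receive the label num_classes ("no object"). Under that assumption, the 3s in the label tensor above are not pixel values but no-object queries, and a class_weight of length num_classes has no slot for them. The query labels here are made up for illustration.

```python
num_classes = 3  # background, Optic Cup, Optic Disc (reduce_zero_label=False)

# One weight per real class, plus a trailing weight for the "no object" label.
class_weight = [1.0] * num_classes + [0.1]

# Assumption: unmatched queries are labelled with num_classes.
no_object_label = num_classes

# Hypothetical per-query labels, shaped like the tensor in the question.
query_labels = [3, 1, 3, 2, 0, 3]

per_query_weight = [class_weight[label] for label in query_labels]
print(per_query_weight)  # [0.1, 1.0, 0.1, 1.0, 1.0, 0.1]

# A list with only num_classes entries cannot be indexed with label 3.
too_short = [0.1, 1.0, 2.0]
try:
    [too_short[label] for label in query_labels]
    lookup_failed = False
except IndexError:
    lookup_failed = True
print(lookup_failed)  # True
```

The small trailing weight keeps the many no-object queries from dominating the classification loss. To weight the real classes differently, change the first num_classes entries and leave the last one in place, for example [0.1, 1.0, 2.0, 0.1].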