This repository contains the code for the paper "Investigating the Weighting Mechanism Using an Interpretable Weighting Framework".
The required environment is as follows:
- Linux
- Python 3.8
- PyTorch 1.9.0
- torchvision 0.10.0
Here are two examples of training on imbalanced and noisy data:
ResNet32 on CIFAR10-LT with an imbalance factor of 10:
python main.py --dataset cifar10 --imbalanced_factor 10
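The imbalance factor is the ratio between the largest and smallest class sizes. The sketch below shows one common way a long-tailed CIFAR-10 split is built: per-class counts decay exponentially from the head class to the tail class. It is an illustration under assumptions, not the code in `main.py`. The exponential profile and the helper name `long_tailed_counts` are hypothetical; only the 5,000 training images per class in CIFAR-10 is a fixed fact.

```python
def long_tailed_counts(n_max, num_classes, imbalance_factor):
    """Hypothetical helper: per-class sample counts with an exponential decay.

    Class 0 keeps n_max samples and the last class keeps about
    n_max / imbalance_factor samples (assumed exponential profile).
    """
    if num_classes == 1:
        return [n_max]
    counts = []
    for i in range(num_classes):
        # Fraction of n_max kept for class i, shrinking from 1 to 1/imbalance_factor.
        ratio = (1.0 / imbalance_factor) ** (i / (num_classes - 1))
        counts.append(round(n_max * ratio))
    return counts


# CIFAR-10 has 5,000 training images per class; imbalance factor 10 as in the command above.
counts = long_tailed_counts(5000, 10, 10)
print(counts)
```

With a factor of 10 the head class keeps all 5,000 images and the tail class keeps 500; a larger factor produces a steeper tail.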
ResNet32 on noisy CIFAR10 with 20% pair-flip noise:
python main.py --dataset cifar10 --corruption_type flip2 --corruption_ratio 0.2
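Pair-flip noise moves each corrupted label to one fixed neighbouring class. The sketch below uses the common convention that class `y` flips to `(y + 1) % num_classes` with probability equal to the corruption ratio. This is an assumption for illustration: the exact mapping used by `--corruption_type flip2` in this repository is not stated here, and `pair_flip_labels` is a hypothetical helper.

```python
import random


def pair_flip_labels(labels, num_classes, noise_ratio, seed=0):
    """Hypothetical helper: pair-flip label noise.

    Each label y is replaced by (y + 1) % num_classes with probability
    noise_ratio; otherwise it is kept unchanged.
    """
    rng = random.Random(seed)  # seeded so the corruption is reproducible
    noisy = []
    for y in labels:
        if rng.random() < noise_ratio:
            noisy.append((y + 1) % num_classes)  # flip to the paired class
        else:
            noisy.append(y)
    return noisy


clean = [i % 10 for i in range(10000)]
noisy = pair_flip_labels(clean, num_classes=10, noise_ratio=0.2)
flipped = sum(a != b for a, b in zip(clean, noisy))
print(f"flipped fraction: {flipped / len(clean):.3f}")
```

Because every corrupted label lands in a single partner class, pair-flip noise is structured rather than uniform, which makes it harder to filter out than symmetric noise at the same ratio.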
The default sample weighting network in the code is a Neural Regression Tree (NRT) with pruning. A multilayer perceptron (MLP) can also be used as the sample weighting network. Both networks are defined in `model.py`.
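A common design for an MLP weighting network maps each sample's training loss to a weight in (0, 1) and then averages the weighted losses. The pure-Python sketch below illustrates only that forward pass under assumptions: the class `TinyWeightMLP`, its single ReLU hidden layer, the sigmoid output and the random initialization are all hypothetical and are not the MLP or the NRT in `model.py`. How the weighting network itself is trained is not shown.

```python
import math
import random


def stable_sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


class TinyWeightMLP:
    """Hypothetical 1-hidden-layer MLP: per-sample loss -> weight in (0, 1)."""

    def __init__(self, hidden=16, seed=0):
        rng = random.Random(seed)
        self.w1 = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.w2 = [rng.gauss(0.0, 1.0) / math.sqrt(hidden) for _ in range(hidden)]
        self.b2 = 0.0

    def __call__(self, loss):
        # ReLU hidden layer applied to the scalar loss.
        h = [max(0.0, w * loss + b) for w, b in zip(self.w1, self.b1)]
        z = sum(w * a for w, a in zip(self.w2, h)) + self.b2
        return stable_sigmoid(z)  # squash to a weight in (0, 1)


def weighted_loss(losses, weight_net):
    """Average of per-sample losses, each scaled by its predicted weight."""
    weights = [weight_net(l) for l in losses]
    total = sum(w * l for w, l in zip(weights, losses)) / len(losses)
    return total, weights


net = TinyWeightMLP(hidden=16, seed=0)
losses = [0.1, 0.5, 1.0, 2.5]
loss, weights = weighted_loss(losses, net)
print([round(w, 3) for w in weights], round(loss, 4))
```

Since every weight stays below 1, the weighted loss never exceeds the plain mean loss; in practice the weighting network is trained so that samples that look noisy or over-represented receive smaller weights.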