Interpretable weighting framework

This is the code for the paper: Investigating the Weighting Mechanism Using an Interpretable Weighting Framework

Setup

The required environment is as follows:

  • Linux
  • Python 3.8
  • PyTorch 1.9.0
  • torchvision 0.10.0

Running the interpretable weighting framework on benchmark datasets (CIFAR-10 and CIFAR-100)

Here are two training examples, one on imbalanced data and one on noisy data.

ResNet32 on CIFAR10-LT with an imbalance factor of 10:

python main.py --dataset cifar10 --imbalanced_factor 10
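To make the imbalance factor concrete: a common way to build a long-tailed CIFAR-10 (assumed here; the construction in this repository may differ) keeps n_max images in the head class and shrinks the count exponentially, n_i = n_max * (1 / factor)^(i / (C - 1)), so the tail class keeps n_max / factor images. The minimal sketch below computes those per-class counts and subsamples an index list to match. The helpers long_tailed_counts and subsample_indices are hypothetical and are not part of this code base.

```python
from typing import List, Sequence


def long_tailed_counts(n_max: int, num_classes: int, imbalance_factor: float) -> List[int]:
    """Per-class counts decaying exponentially from n_max to n_max / imbalance_factor.

    Assumed exponential long-tail profile; not necessarily the repo's own recipe.
    """
    counts = []
    for i in range(num_classes):
        # The exponent goes from 0 (head class) to 1 (tail class).
        frac = (1.0 / imbalance_factor) ** (i / (num_classes - 1))
        counts.append(int(n_max * frac))
    return counts


def subsample_indices(labels: Sequence[int], counts: Sequence[int]) -> List[int]:
    """Hypothetical helper: keep the first counts[c] examples of each class c."""
    chosen = []
    for c, n in enumerate(counts):
        idx = [i for i, y in enumerate(labels) if y == c]
        chosen.extend(idx[:n])
    return chosen


# CIFAR-10 has 5000 training images per class.
counts = long_tailed_counts(n_max=5000, num_classes=10, imbalance_factor=10)
print(counts)
```

With an imbalance factor of 10, the most frequent class keeps 5000 images and the rarest keeps 500, so the ratio between the head and tail classes equals the factor.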

ResNet32 on noisy CIFAR10 with 20% pair-flip noise:

python main.py --dataset cifar10 --corruption_type flip2 --corruption_ratio 0.2
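Pair-flip noise is commonly defined as moving each label y to its "paired" class (y + 1) mod C with probability equal to the corruption ratio. The sketch below assumes that definition; the flip2 option in this code may pair classes differently, so treat it as an illustration of the noise model rather than of the repository's implementation. The function name pair_flip_labels is hypothetical.

```python
import random
from typing import List


def pair_flip_labels(labels: List[int], num_classes: int, ratio: float, seed: int = 0) -> List[int]:
    """Flip each label y to (y + 1) % num_classes with probability `ratio` (assumed definition)."""
    rng = random.Random(seed)  # fixed seed so the corruption is reproducible
    noisy = []
    for y in labels:
        if rng.random() < ratio:
            noisy.append((y + 1) % num_classes)  # move to the paired "next" class
        else:
            noisy.append(y)  # keep the clean label
    return noisy


clean = [i % 10 for i in range(10000)]
noisy = pair_flip_labels(clean, num_classes=10, ratio=0.2)
flipped = sum(c != n for c, n in zip(clean, noisy)) / len(clean)
print(f"fraction of flipped labels: {flipped:.3f}")
```

Because every corrupted label lands on the same neighbouring class, pair-flip noise is structured rather than uniform, which is what makes it harder for a weighting scheme to separate clean from noisy samples.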

The default sample weighting network in the code is a Neural Regression Tree (NRT) with pruning. An MLP can be used as the sample weighting network instead. Both networks are defined in "model.py".
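For readers new to loss-to-weight networks, the sketch below shows only the forward pass of a one-hidden-layer MLP that maps a per-sample training loss to a weight in (0, 1), loosely modelled on the Meta-Weight-Net style of weighting network. It is not the MLP in model.py: the layer size, initialization, and the names TinyWeightingMLP and weighted_mean_loss are hypothetical, it is written in plain Python instead of PyTorch so it runs without dependencies, and it leaves out how the weighting network is trained. The NRT is not sketched here.

```python
import math
import random
from typing import List


class TinyWeightingMLP:
    """Hypothetical loss-to-weight MLP: 1 -> hidden (ReLU) -> 1 (sigmoid)."""

    def __init__(self, hidden: int = 100, seed: int = 0):
        rng = random.Random(seed)
        self.w1 = [rng.gauss(0.0, 0.1) for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.w2 = [rng.gauss(0.0, 0.1) for _ in range(hidden)]
        self.b2 = 0.0

    def __call__(self, loss: float) -> float:
        # Hidden layer: ReLU(w1 * loss + b1) for each hidden unit.
        h = [max(0.0, w * loss + b) for w, b in zip(self.w1, self.b1)]
        # Output unit followed by a sigmoid, so the weight lies in (0, 1).
        z = sum(w * a for w, a in zip(self.w2, h)) + self.b2
        return 1.0 / (1.0 + math.exp(-z))


def weighted_mean_loss(losses: List[float], net: TinyWeightingMLP) -> float:
    """Mean of per-sample losses, each scaled by the weight the network assigns to it."""
    weights = [net(loss_value) for loss_value in losses]
    return sum(w * v for w, v in zip(weights, losses)) / len(losses)


net = TinyWeightingMLP()
losses = [0.1, 0.5, 2.3]
print([round(net(v), 3) for v in losses], weighted_mean_loss(losses, net))
```

Swapping the MLP for a tree-structured model such as the NRT changes how the loss-to-weight mapping is represented, while the way the weights are used in the training objective stays the same.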