TPAMI2023: CMW-Net: Learning a Class-Aware Sample Weighting Mapping for Robust Deep Learning (Official PyTorch implementation)
======================================================================================================================================================
This is the code for the paper "CMW-Net: Learning a Class-Aware Sample Weighting Mapping for Robust Deep Learning" by Jun Shu, Xiang Yuan, Deyu Meng, and Zongben Xu (official site, arXiv version).
- Overview
- Prerequisites
- Experiments
- Citation
- Acknowledgments
Overview

Modern deep neural networks (DNNs) easily overfit to biased training data containing corrupted labels or class imbalance. Sample re-weighting methods are popularly used to alleviate this data bias issue. Most current methods, however, require manually pre-specifying the weighting scheme as well as its additional hyper-parameters, relying on the characteristics of the investigated problem and training data. This makes them hard to apply in general practical scenarios, due to the significant complexity and inter-class variation of data bias situations. To address this issue, we propose a meta-model capable of adaptively learning an explicit weighting scheme directly from data. Specifically, by viewing each training class as a separate learning task, our method aims to extract an explicit weighting function, with sample loss and task/class feature as input and sample weight as output, expecting to impose adaptively varying weighting schemes on different sample classes based on their own intrinsic bias characteristics. The architecture of the CMW-Net meta-model is shown below:
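For intuition, here is a minimal PyTorch sketch of such a class-aware weighting meta-model: one small MLP branch per class-scale group maps each sample's loss to a weight in [0, 1], and the size of the sample's class routes it to a branch. The class name, hidden size, number of groups, and thresholds below are illustrative assumptions, not the exact architecture from the paper.

```python
import torch
import torch.nn as nn

class CMWNet(nn.Module):
    """Sketch of the CMW-Net meta-model: maps each sample's training
    loss, together with a task/class feature (here: the size of the
    sample's class), to a weight in [0, 1]."""

    def __init__(self, hidden=100, num_groups=3):
        super().__init__()
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.Linear(1, hidden), nn.ReLU(),
                nn.Linear(hidden, 1), nn.Sigmoid(),
            )
            for _ in range(num_groups)
        )

    @staticmethod
    def scale_group(class_sizes):
        # Route classes to small / moderate / large scale levels.
        # The thresholds here are illustrative placeholders.
        g = torch.zeros_like(class_sizes, dtype=torch.long)
        g[class_sizes > 100] = 1
        g[class_sizes > 1000] = 2
        return g

    def forward(self, loss, class_sizes):
        # loss: (B, 1) per-sample losses; class_sizes: (B,) size of
        # each sample's class, used to pick its scale-level branch.
        groups = self.scale_group(class_sizes)
        weights = torch.zeros_like(loss)
        for k, branch in enumerate(self.branches):
            mask = groups == k
            if mask.any():
                weights[mask] = branch(loss[mask])
        return weights
```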
Prerequisites
- Python 3.7
- PyTorch >= 1.5.0
- Torchvision >= 0.4.0
- sklearn
- torchnet
Experiments
Synthetic and real data experiments substantiate the capability of our method to achieve proper weighting schemes in various data bias cases. The task-transferability of the learned weighting scheme is also substantiated: a performance gain can be readily achieved over previous state-of-the-art methods without additional hyper-parameter tuning or meta gradient descent steps. The general applicability of our method to multiple robust deep learning issues has also been validated. We provide the running scripts in the corresponding code directories. Detailed descriptions and the main results are shown below.
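As background for the scripts below, CMW-Net is meta-learned with an MW-Net-style bilevel scheme: a virtual classifier step under the current weighting, a meta step on a small clean set that back-propagates through the virtual step into the meta-model, and an actual classifier step under the refreshed weights. The following is a minimal sketch under simplifying assumptions (a single linear classifier with plain SGD, and the illustrative `CMWNet` sketched above), not the repository's implementation.

```python
import torch
import torch.nn.functional as F

# Usage sketch (shapes only):
#   D, C = 32, 10
#   W = torch.randn(D, C, requires_grad=True)
#   b = torch.zeros(C, requires_grad=True)
#   cmwnet = CMWNet()
#   meta_opt = torch.optim.Adam(cmwnet.parameters(), lr=1e-3)
#   class_sizes = torch.tensor([...])  # (C,) per-class sample counts

def meta_iteration(W, b, cmwnet, meta_opt, train_x, train_y,
                   meta_x, meta_y, class_sizes, lr=0.1):
    # 1) Virtual step: one SGD step of the classifier under the current
    #    weighting, kept differentiable w.r.t. the meta-model.
    losses = F.cross_entropy(train_x @ W + b, train_y,
                             reduction="none").unsqueeze(1)
    weights = cmwnet(losses, class_sizes[train_y])
    gW, gb = torch.autograd.grad((weights * losses).mean(), (W, b),
                                 create_graph=True)
    W_hat, b_hat = W - lr * gW, b - lr * gb

    # 2) Meta step: the clean meta-batch loss of the virtual classifier
    #    is back-propagated through the virtual step into CMW-Net.
    meta_loss = F.cross_entropy(meta_x @ W_hat + b_hat, meta_y)
    meta_opt.zero_grad()
    meta_loss.backward()
    meta_opt.step()

    # 3) Actual step: re-weight with the updated CMW-Net and take a
    #    plain SGD step on the classifier.
    losses = F.cross_entropy(train_x @ W + b, train_y,
                             reduction="none").unsqueeze(1)
    with torch.no_grad():
        weights = cmwnet(losses, class_sizes[train_y])
    gW, gb = torch.autograd.grad((weights * losses).mean(), (W, b))
    with torch.no_grad():
        W -= lr * gW
        b -= lr * gb
```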
Class Imbalance Experiments
You can reproduce the results of the class imbalance experiments (TABLE 1 in the paper) by
cd section4/Class_Imbalance
bash table1.sh
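For reference, the long-tailed (LT) versions of CIFAR are conventionally built by exponentially decaying the per-class sample counts so that the ratio of the largest to the smallest class equals the imbalance factor. A small sketch (the function name is illustrative):

```python
import numpy as np

def longtail_class_counts(n_max=5000, num_classes=10, imbalance_factor=100):
    """Per-class sample counts that decay exponentially so that
    n_max / n_min equals the imbalance factor (the usual construction
    of CIFAR-10-LT / CIFAR-100-LT)."""
    mu = (1.0 / imbalance_factor) ** (1.0 / (num_classes - 1))
    return [int(n_max * mu ** i) for i in range(num_classes)]

# longtail_class_counts(5000, 10, 100) -> [5000, 2997, ..., 50]
```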
The main results are shown below:
Test error (%) on CIFAR-10-LT (first six columns) and CIFAR-100-LT (last six columns); column headers give the imbalance factor.

Method | 200 | 100 | 50 | 20 | 10 | 1 | 200 | 100 | 50 | 20 | 10 | 1
---|---|---|---|---|---|---|---|---|---|---|---|---
ERM | 34.32 | 29.64 | 25.19 | 17.77 | 13.61 | 7.53 | 65.16 | 61.68 | 56.15 | 48.86 | 44.29 | 29.50 |
Focal loss | 34.71 | 29.62 | 23.29 | 17.24 | 13.34 | 6.97 | 64.38 | 61.59 | 55.68 | 48.05 | 44.22 | 28.85 |
CB loss | 31.11 | 27.63 | 21.95 | 15.64 | 13.23 | 7.53 | 64.44 | 61.23 | 55.21 | 48.06 | 42.43 | 29.37 |
LDAM loss | - | 26.65 | - | - | 13.04 | - | 60.40 | - | - | - | 43.09 | - |
L2RW | 33.49 | 25.84 | 21.07 | 16.90 | 14.81 | 10.75 | 66.62 | 59.77 | 55.56 | 48.36 | 46.27 | 35.89 |
MW-Net | 32.80 | 26.43 | 20.90 | 15.55 | 12.45 | 7.19 | 63.38 | 58.39 | 54.34 | 46.96 | 41.09 | 29.90 |
MCW with CE loss | 29.34 | 23.59 | 19.49 | 13.54 | 11.15 | 7.21 | 60.69 | 56.65 | 51.47 | 44.38 | 40.42 | - |
CMW-Net with CE loss | 27.80 | 21.15 | 17.26 | 12.45 | 10.97 | 8.30 | 60.85 | 55.25 | 49.73 | 43.06 | 39.41 | 30.81 |
MCW with LDAM loss | 25.10 | 20.00 | 17.77 | 15.63 | 12.60 | 10.29 | 60.47 | 55.92 | 50.84 | 47.62 | 42.00 | - |
CMW-Net with LDAM loss | 25.57 | 19.95 | 17.66 | 13.08 | 11.42 | 7.04 | 59.81 | 55.87 | 51.14 | 45.26 | 40.32 | 29.19 |
SADE | 19.37 | 16.78 | 14.81 | 11.78 | 9.88 | 7.72 | 54.78 | 50.20 | 46.12 | 40.06 | 36.40 | 28.08 |
CMW-Net with SADE | 19.11 | 16.04 | 13.54 | 10.25 | 9.39 | 5.39 | 54.59 | 49.50 | 46.01 | 39.42 | 34.78 | 27.50 |
For details, see Section 4.1 of the main paper.
Feature-independent Label Noise Experiment
You can reproduce the results of the feature-independent label noise experiments (TABLE 2 and TABLE 3 in the paper) by
cd section4/Feature-independent_Label_Noise
bash table2.sh
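For reference, feature-independent noise is conventionally injected as follows: symmetric noise flips a selected label uniformly to another class, while asymmetric noise flips between semantically similar classes (e.g., truck→automobile on CIFAR-10). A sketch (the function name and details are illustrative):

```python
import numpy as np

def inject_label_noise(labels, noise_rate, num_classes, mode="sym", seed=0):
    """Corrupt labels for the synthetic noise experiments (a sketch)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    noisy = labels.copy()
    flip = rng.random(len(labels)) < noise_rate
    if mode == "sym":
        # Draw a label different from the true one, uniformly at random.
        offs = rng.integers(1, num_classes, size=int(flip.sum()))
        noisy[flip] = (labels[flip] + offs) % num_classes
    else:
        # Usual CIFAR-10 pairwise map: truck->automobile, bird->airplane,
        # cat<->dog, deer->horse (class indices 9->1, 2->0, 3<->5, 4->7).
        pair = {9: 1, 2: 0, 3: 5, 5: 3, 4: 7}
        for src, dst in pair.items():
            noisy[flip & (labels == src)] = dst
    return noisy
```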
The main results are shown below:
Test accuracy (%) under symmetric and asymmetric label noise (column headers give the noise rate):

Dataset | Method | Sym. 0.2 | Sym. 0.4 | Sym. 0.6 | Sym. 0.8 | Asym. 0.2 | Asym. 0.4 | Asym. 0.6 | Asym. 0.8
---|---|---|---|---|---|---|---|---|---
CIFAR-10 | ERM | 86.98 ± 0.12 | 77.52 ± 0.41 | 73.63 ± 0.85 | 53.82 ± 1.04 | 83.60 ± 0.24 | 77.85 ± 0.98 | 69.69 ± 0.72 | 55.20 ± 0.28
 | Forward | 87.99 ± 0.36 | 83.25 ± 0.38 | 74.96 ± 0.65 | 54.64 ± 0.44 | 91.34 ± 0.28 | 89.87 ± 0.61 | 87.24 ± 0.96 | 81.07 ± 1.92
 | GCE | 89.99 ± 0.16 | 87.31 ± 0.53 | 82.15 ± 0.47 | 57.36 ± 2.08 | 89.75 ± 1.53 | 87.75 ± 0.36 | 67.21 ± 3.64 | 57.46 ± 0.31
 | M-correction | 93.80 ± 0.23 | 92.53 ± 0.11 | 90.30 ± 0.34 | 86.80 ± 0.11 | 92.15 ± 0.18 | 91.76 ± 0.57 | 87.59 ± 0.33 | 67.78 ± 1.22
 | DivideMix | 95.70 ± 0.31 | 95.00 ± 0.17 | 94.23 ± 0.23 | 92.90 ± 0.31 | 93.96 ± 0.21 | 91.80 ± 0.78 | 80.14 ± 0.45 | 59.23 ± 0.38
 | L2RW | 89.45 ± 0.62 | 87.18 ± 0.84 | 81.57 ± 0.66 | 58.59 ± 1.84 | 90.46 ± 0.56 | 89.76 ± 0.53 | 88.22 ± 0.71 | 85.17 ± 0.31
 | MW-Net | 90.46 ± 0.52 | 86.53 ± 0.57 | 82.98 ± 0.34 | 64.41 ± 0.92 | 92.69 ± 0.24 | 90.17 ± 0.11 | 68.55 ± 0.76 | 58.29 ± 1.33
 | CMW-Net | 91.09 ± 0.54 | 86.91 ± 0.37 | 83.33 ± 0.55 | 64.80 ± 0.72 | 93.02 ± 0.25 | 92.70 ± 0.32 | 91.28 ± 0.40 | 87.50 ± 0.26
 | CMW-Net-SL | 96.20 ± 0.33 | 95.29 ± 0.14 | 94.51 ± 0.32 | 92.10 ± 0.76 | 95.48 ± 0.29 | 94.51 ± 0.52 | 94.18 ± 0.21 | 93.07 ± 0.24
CIFAR-100 | ERM | 60.38 ± 0.75 | 46.92 ± 0.51 | 31.82 ± 1.16 | 8.29 ± 3.24 | 61.05 ± 0.11 | 50.30 ± 1.11 | 37.34 ± 1.80 | 12.46 ± 0.43
 | Forward | 63.71 ± 0.49 | 49.34 ± 0.60 | 37.90 ± 0.76 | 9.57 ± 1.01 | 64.97 ± 0.47 | 52.37 ± 0.71 | 44.58 ± 0.60 | 15.84 ± 0.62
 | GCE | 68.02 ± 1.05 | 64.18 ± 0.30 | 54.46 ± 0.31 | 15.61 ± 0.97 | 66.15 ± 0.44 | 56.85 ± 0.72 | 40.58 ± 0.47 | 15.82 ± 0.63
 | M-correction | 73.90 ± 0.14 | 70.10 ± 0.14 | 59.50 ± 0.35 | 48.20 ± 0.23 | 71.85 ± 0.19 | 70.83 ± 0.48 | 60.51 ± 0.52 | 16.06 ± 0.33
 | DivideMix | 76.90 ± 0.21 | 75.20 ± 0.12 | 72.00 ± 0.33 | 59.60 ± 0.21 | 76.12 ± 0.44 | 73.47 ± 0.63 | 45.83 ± 0.83 | 16.98 ± 0.40
 | L2RW | 65.32 ± 0.42 | 55.75 ± 0.81 | 41.16 ± 0.85 | 16.80 ± 0.22 | 65.93 ± 0.17 | 62.48 ± 0.56 | 51.66 ± 0.49 | 12.40 ± 0.61
 | MW-Net | 69.93 ± 0.40 | 65.29 ± 0.43 | 55.59 ± 1.07 | 27.63 ± 0.56 | 69.80 ± 0.34 | 64.88 ± 0.63 | 56.89 ± 0.95 | 17.05 ± 0.52
 | CMW-Net | 70.11 ± 0.19 | 65.84 ± 0.50 | 56.93 ± 0.38 | 28.36 ± 0.67 | 71.07 ± 0.56 | 66.15 ± 0.51 | 58.21 ± 0.78 | 17.41 ± 0.16
 | CMW-Net-SL | 77.84 ± 0.12 | 76.25 ± 0.67 | 72.61 ± 0.92 | 55.21 ± 0.31 | 77.73 ± 0.37 | 75.69 ± 0.68 | 61.54 ± 0.72 | 18.34 ± 0.21
Comparison with state-of-the-art methods (test accuracy, %; column headers give the noise rate):

Dataset | Method | Sym. 0.2 | Sym. 0.5 | Sym. 0.8 | Sym. 0.9 | Asym. 0.4
---|---|---|---|---|---|---
CIFAR-10 | DivideMix | 95.7 | 94.4 | 92.9 | 75.4 | 92.1
 | ELR+ | 94.6 | 93.8 | 93.1 | 75.2 | 92.7
 | REED | 95.7 | 95.4 | 94.1 | 93.5 | -
 | AugDesc | 96.2 | 95.1 | 93.6 | 91.8 | 94.3
 | C2D | 96.2 | 95.1 | 94.3 | 93.4 | 90.8
 | Two-step | 96.2 | 95.3 | 93.7 | 92.7 | 92.4
 | CMW-Net-SL | 96.2 | 95.1 | 92.1 | 48.0 | 94.5
 | CMW-Net-SL+ | 96.6 | 96.2 | 95.4 | 93.7 | 96.0
CIFAR-100 | DivideMix | 77.3 | 74.6 | 60.2 | 31.5 | 72.1
 | ELR+ | 77.5 | 72.4 | 58.2 | 30.8 | 76.5
 | REED | 76.5 | 72.2 | 66.5 | 59.4 | -
 | AugDesc | 79.2 | 77.0 | 66.1 | 40.9 | 76.8
 | C2D | 78.3 | 76.1 | 67.4 | 58.5 | 75.1
 | Two-step | 79.1 | 78.2 | 70.1 | 53.2 | 65.5
 | CMW-Net-SL | 77.84 | 76.2 | 55.2 | 21.2 | 75.7
 | CMW-Net-SL+ | 80.2 | 78.2 | 71.1 | 64.6 | 77.2
For details, see Section 4.2 of the main paper.
Feature-dependent Label Noise Experiment
You can reproduce the results of TABLE 4 in the paper by
cd section4/Feature-dependent_Label_Noise
bash table4.sh
The main results are shown below:
Dataset | Noise | ERM | LRT | GCE | MW-Net | PLC | CMW-Net | CMW-Net-SL
---|---|---|---|---|---|---|---|---
CIFAR-10 | Type-I (35%) | 78.11 ± 0.74 | 80.98 ± 0.80 | 80.65 ± 0.39 | 82.20 ± 0.40 | 82.80 ± 0.27 | 82.27 ± 0.33 | 84.23 ± 0.17
 | Type-I (70%) | 41.98 ± 1.96 | 41.52 ± 4.53 | 36.52 ± 1.62 | 38.85 ± 0.67 | 42.74 ± 2.14 | 42.23 ± 0.69 | 44.19 ± 0.69
 | Type-II (35%) | 76.65 ± 0.57 | 80.74 ± 0.25 | 77.60 ± 0.88 | 81.28 ± 0.56 | 81.54 ± 0.47 | 81.69 ± 0.57 | 83.12 ± 0.40
 | Type-II (70%) | 45.57 ± 1.12 | 81.08 ± 0.35 | 40.30 ± 1.46 | 42.15 ± 1.07 | 46.04 ± 2.20 | 46.30 ± 0.77 | 48.26 ± 0.88
 | Type-III (35%) | 76.89 ± 0.79 | 76.89 ± 0.79 | 79.18 ± 0.61 | 81.57 ± 0.73 | 81.50 ± 0.50 | 81.52 ± 0.38 | 83.10 ± 0.34
 | Type-III (70%) | 43.32 ± 1.00 | 44.47 ± 1.23 | 37.10 ± 0.59 | 42.43 ± 1.27 | 45.05 ± 1.13 | 43.76 ± 0.96 | 45.15 ± 0.91
CIFAR-100 | Type-I (35%) | 57.68 ± 0.29 | 56.74 ± 0.34 | 58.37 ± 0.18 | 62.10 ± 0.50 | 60.01 ± 0.43 | 62.43 ± 0.38 | 64.01 ± 0.11
 | Type-I (70%) | 39.32 ± 0.43 | 45.29 ± 0.43 | 40.01 ± 0.71 | 44.71 ± 0.49 | 45.92 ± 0.61 | 46.68 ± 0.64 | 47.62 ± 0.44
 | Type-II (35%) | 57.83 ± 0.25 | 57.25 ± 0.68 | 58.11 ± 1.05 | 63.78 ± 0.24 | 63.68 ± 0.29 | 64.08 ± 0.26 | 64.13 ± 0.19
 | Type-II (70%) | 39.30 ± 0.32 | 43.71 ± 0.51 | 37.75 ± 0.46 | 44.61 ± 0.41 | 45.03 ± 0.50 | 50.01 ± 0.51 | 51.99 ± 0.35
 | Type-III (35%) | 56.07 ± 0.79 | 56.57 ± 0.30 | 57.51 ± 1.16 | 62.53 ± 0.33 | 63.68 ± 0.29 | 63.21 ± 0.23 | 64.47 ± 0.15
 | Type-III (70%) | 40.01 ± 0.18 | 44.41 ± 0.19 | 40.53 ± 0.60 | 45.17 ± 0.77 | 44.45 ± 0.62 | 47.38 ± 0.65 | 48.78 ± 0.62
You can reproduce the results of TABLE 5 in the paper by
cd section4/Feature-dependent_Label_Noise
bash table5.sh
The main results are shown below:
Dataset | Noise | ERM | LRT | GCE | MW-Net | PLC | CMW-Net | CMW-Net-SL
---|---|---|---|---|---|---|---|---
CIFAR-10 | Type-I + Symmetric | 75.26 ± 0.32 | 75.97 ± 0.27 | 78.08 ± 0.66 | 76.39 ± 0.42 | 79.04 ± 0.50 | 78.42 ± 0.47 | 82.00 ± 0.36
 | Type-I + Asymmetric | 75.21 ± 0.64 | 76.96 ± 0.45 | 76.91 ± 0.56 | 76.54 ± 0.56 | 78.31 ± 0.41 | 77.14 ± 0.38 | 80.69 ± 0.47
 | Type-II + Symmetric | 74.92 ± 0.63 | 75.94 ± 0.58 | 75.69 ± 0.21 | 76.57 ± 0.81 | 80.08 ± 0.37 | 76.77 ± 0.63 | 80.96 ± 0.23
 | Type-II + Asymmetric | 74.28 ± 0.39 | 77.03 ± 0.62 | 75.30 ± 0.81 | 75.35 ± 0.40 | 77.63 ± 0.30 | 77.08 ± 0.52 | 80.94 ± 0.14
 | Type-III + Symmetric | 74.00 ± 0.38 | 75.66 ± 0.57 | 77.00 ± 0.12 | 76.28 ± 0.82 | 80.06 ± 0.47 | 77.16 ± 0.30 | 81.58 ± 0.55
 | Type-III + Asymmetric | 75.31 ± 0.34 | 77.19 ± 0.74 | 75.70 ± 0.91 | 75.82 ± 0.77 | 77.54 ± 0.70 | 76.49 ± 0.88 | 80.48 ± 0.48
CIFAR-100 | Type-I + Symmetric | 48.86 ± 0.56 | 45.66 ± 1.60 | 52.90 ± 0.53 | 57.70 ± 0.32 | 60.09 ± 0.15 | 59.17 ± 0.42 | 60.87 ± 0.56
 | Type-I + Asymmetric | 45.85 ± 0.93 | 52.04 ± 0.15 | 52.69 ± 1.14 | 56.61 ± 0.71 | 56.40 ± 0.34 | 57.42 ± 0.81 | 61.35 ± 0.52
 | Type-II + Symmetric | 49.32 ± 0.36 | 43.86 ± 1.31 | 53.61 ± 0.46 | 54.08 ± 0.18 | 60.01 ± 0.63 | 59.16 ± 0.18 | 61.00 ± 0.41
 | Type-II + Asymmetric | 46.50 ± 0.95 | 52.11 ± 0.46 | 51.98 ± 0.37 | 58.53 ± 0.45 | 61.43 ± 0.33 | 58.99 ± 0.91 | 61.35 ± 0.57
 | Type-III + Symmetric | 48.94 ± 0.61 | 42.79 ± 1.78 | 52.07 ± 0.35 | 55.29 ± 0.57 | 60.14 ± 0.97 | 58.48 ± 0.79 | 60.21 ± 0.48
 | Type-III + Asymmetric | 45.70 ± 0.12 | 50.31 ± 0.39 | 50.87 ± 1.12 | 58.43 ± 0.60 | 54.56 ± 1.11 | 58.83 ± 0.57 | 60.52 ± 0.53
For details, see Section 4.3 of the main paper.
Learning with Real-world Noisy Datasets
We test our method on two real-world noisy datasets, ANIMAL-10N and mini WebVision. You can reproduce the results on ANIMAL-10N (TABLE 6 in the paper) by
cd section5/ANIMAL-10N
bash table6.sh
The main results are shown below:
Method | Test Accuracy (%)
---|---
ERM | 79.4
ActiveBias | 80.5
Co-teaching | 80.2
SELFIE | 81.8
PLC | 83.4
MW-Net | 80.7
CMW-Net | 80.9
CMW-Net-SL | 84.7
You can reproduce the results on mini WebVision (TABLE 7 in the paper) by
cd section5/mini_WebVision
bash table7.sh
The main results are shown below:
Methods | WebVision top-1 | WebVision top-5 | ILSVRC12 top-1 | ILSVRC12 top-5
---|---|---|---|---
Forward | 61.12 | 82.68 | 57.36 | 82.36
MentorNet | 63.00 | 81.40 | 57.80 | 79.92
Co-teaching | 63.58 | 85.20 | 61.48 | 84.70
Iterative-CV | 65.24 | 85.34 | 61.60 | 84.98
MW-Net | 69.34 | 87.44 | 65.80 | 87.52
CMW-Net | 70.56 | 88.76 | 66.44 | 87.68
DivideMix | 77.32 | 91.64 | 75.20 | 90.84
ELR | 77.78 | 91.68 | 70.29 | 89.76
DivideMix | 76.32 | 90.65 | 74.42 | 91.21
CMW-Net-SL | 78.08 | 92.96 | 75.72 | 92.52
DivideMix with C2D | 79.42 | 92.32 | 78.57 | 93.04
CMW-Net-SL with C2D | 80.44 | 93.36 | 77.36 | 93.48
For details, see Section 5.1 of the main paper.
Webly Supervised Fine-Grained Recognition
We further run our method on the WebFG-496 benchmark, which consists of three sub-datasets: Web-aircraft, Web-bird, and Web-car. You can reproduce the results on WebFG-496 (TABLE 8 in the paper) by
cd section5/WebFG-496
bash table8.sh
The main results are shown below:
Methods | Web-Bird | Web-Aircraft | Web-Car | Average |
---|---|---|---|---|
ERM | 66.56 | 64.33 | 67.42 | 66.10 |
Decoupling | 70.56 | 75.97 | 75.00 | 73.84 |
Co-teaching | 73.85 | 72.76 | 73.10 | 73.24 |
Peer-learning | 76.48 | 74.38 | 78.52 | 76.46 |
MW-Net | 75.60 | 72.93 | 77.33 | 75.29 |
CMW-Net | 75.72 | 73.72 | 77.42 | 75.62 |
CMW-Net-SL | 77.41 | 76.48 | 79.70 | 77.86 |
For details, see Section 5.2 of the main paper.
A potential use of the meta-learned weighting scheme of CMW-Net is that it is model-agnostic and can hopefully be equipped into other learning algorithms in a plug-and-play manner. To validate this transferability, we attempt to transfer a CMW-Net meta-learned on a relatively small dataset to significantly larger-scale ones. Specifically, we use the CMW-Net trained on CIFAR-10 with feature-dependent label noise (i.e., 35% Type-I + 30% Asymmetric), as introduced in Section 4.3 of the paper, since it finely simulates real-world noise configurations. The extracted weighting function is depicted below.
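A hedged sketch of this plug-and-play use: load the meta-learned CMW-Net, freeze it, and use it only to re-weight per-sample losses in the new training pipeline. The checkpoint name and helper function below are illustrative, and `CMWNet` is the sketch from the overview, not the repository's class.

```python
import torch
import torch.nn.functional as F

# Load a meta-learned CMW-Net and freeze it; from here on it is a fixed,
# model-agnostic weighting function (checkpoint name is illustrative).
cmwnet = CMWNet()
cmwnet.load_state_dict(torch.load("cmwnet_cifar10_feat_dep_noise.pt"))
cmwnet.eval()
for p in cmwnet.parameters():
    p.requires_grad_(False)

def weighted_loss(logits, targets, class_sizes):
    # Re-weight per-sample cross-entropy losses with the frozen CMW-Net;
    # no meta gradient step is needed at transfer time.
    losses = F.cross_entropy(logits, targets, reduction="none").unsqueeze(1)
    with torch.no_grad():
        w = cmwnet(losses, class_sizes[targets])
    return (w * losses).mean()
```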
WebVision dataset
We deploy it on the full WebVision dataset. Even with a relatively concise form, our method still outperforms the second-best Heteroscedastic method by an evident margin. This further validates the potential usefulness of CMW-Net for practical large-scale problems with complicated data bias, intrinsically reducing labor and computation costs by readily specifying a proper weighting scheme for a learning algorithm. You can reproduce the performance on full WebVision (TABLE 10 in the main paper) by
cd section6/webvision
bash table10.sh
The main results are shown below:
Methods | WebVision top-1 | WebVision top-5 | ILSVRC12 top-1 | ILSVRC12 top-5
---|---|---|---|---
ERM | 69.7 | 87.0 | 62.9 | 83.6
MentorNet | 70.8 | 88.0 | 62.5 | 83.0
MentorMix | 74.3 | 90.5 | 67.5 | 87.2
HAR | 75.0 | 90.6 | 67.1 | 86.7
MILE | 76.5 | 90.9 | 68.7 | 86.4
Heteroscedastic | 76.6 | 92.1 | 68.6 | 87.1
CurriculumNet | 79.3 | 93.6 | - | -
ERM + CMW-Net-SL | 77.9 | 92.6 | 69.6 | 88.5
For details, see Section 6 of the main paper.
We evaluate the generality of our proposed adaptive sample weighting strategy on further robust learning tasks.
Partial-Label Learning
CMW-Net significantly enhances the performance of the baseline method in both test cases, showing its potential usability for partial-label learning. You can reproduce the partial-label learning results (Fig. 9 in the paper) by
cd section7/Partial-Label_Learning
bash fig9.sh
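For intuition, here is a hedged sketch of how CMW-Net weights can be inserted into a PRODEN-style partial-label loss: PRODEN maintains a soft label over each sample's candidate set and renormalizes the model's predictions on the candidates, while CMW-Net supplies per-sample weights from the losses. All names and the choice of class feature below are illustrative, not the repository's implementation.

```python
import torch
import torch.nn.functional as F

def proden_cmw_loss(logits, conf, cand_mask, cmwnet, class_sizes):
    # conf:      (B, C) current soft labels, supported on the candidates
    # cand_mask: (B, C) 0/1 mask of each sample's candidate label set
    logp = F.log_softmax(logits, dim=1)
    losses = -(conf * logp).sum(dim=1, keepdim=True)  # per-sample loss
    with torch.no_grad():
        # Class feature: size of each sample's current pseudo-class.
        pseudo = conf.argmax(dim=1)
        w = cmwnet(losses, class_sizes[pseudo])
        # PRODEN-style update: renormalize predictions on the candidates.
        p = torch.softmax(logits, dim=1) * cand_mask
        new_conf = p / p.sum(dim=1, keepdim=True).clamp_min(1e-12)
    return (w * losses).mean(), new_conf
```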
The main results are shown below.
Accuracy comparisons of PRODEN with and without the CMW-Net strategy on CIFAR-10:
Accuracy comparisons of PRODEN with and without the CMW-Net strategy on CIFAR-100:
For details, see Section 7.1 of the main paper.
Citation
If you find this code useful, please cite our paper:
@article{CMW-Net,
  title={CMW-Net: Learning a Class-Aware Sample Weighting Mapping for Robust Deep Learning},
  author={Jun Shu and Xiang Yuan and Deyu Meng and Zongben Xu},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  pages={1-15},
  year={2023}
}
Acknowledgments
We appreciate the following GitHub repos for their valuable codebases:
- https://github.com/LiJunnan1992/DivideMix
- https://github.com/ContrastToDivide/C2D
- https://github.com/Vanint/SADE-AgnosticLT
- https://github.com/abdullahjamal/Longtail_DA
- https://github.com/pxiangwu/PLC
- https://github.com/NUST-Machine-Intelligence-Laboratory/weblyFG-dataset
- https://github.com/Lvcrezia77/PRODEN