
A TensorFlow implementation of "Co-teaching: Robust Training of Deep Neural Networks with Extremely Noisy Labels"

Primary language: Python. License: MIT.


This repository reproduces the NeurIPS'18 paper "Co-teaching: Robust Training of Deep Neural Networks with Extremely Noisy Labels" in TensorFlow.

  • TensorFlow implementation: see all *_tf.py files.
  • The original Co-teaching PyTorch implementation, adapted to PyTorch 1.1.0: see all *_th.py files. The original PyTorch implementation is provided by the paper's author, Bo Han, at [bhanML/Co-teaching].
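
For readers unfamiliar with the method, the core step of Co-teaching as described in the paper can be sketched roughly as follows: two networks each rank the samples of a mini-batch by loss, keep only the small-loss fraction, and hand that selection to the peer network for its update. This is a minimal NumPy sketch, not code taken from this repository. The names remember_rate, coteaching_select, tau, and warmup_epochs are hypothetical, and the schedule follows the paper's R(T) = 1 - min(T / T_k * tau, tau).

```python
import numpy as np


def remember_rate(epoch, tau, warmup_epochs=10):
    """Fraction of each mini-batch kept at `epoch` (hypothetical helper).

    Follows R(T) = 1 - min(T / T_k * tau, tau): everything is kept at
    epoch 0, and the kept fraction shrinks linearly to 1 - tau.
    """
    return 1.0 - tau * min(epoch / warmup_epochs, 1.0)


def coteaching_select(loss_a, loss_b, keep_fraction):
    """Pick the training indices for two peer networks A and B.

    loss_a / loss_b are the per-sample losses of networks A and B on the
    same mini-batch. Returns (indices A trains on, indices B trains on).
    """
    loss_a = np.asarray(loss_a)
    loss_b = np.asarray(loss_b)
    num_keep = max(1, int(keep_fraction * len(loss_a)))

    # Each network keeps the samples it finds easiest (smallest loss).
    small_loss_a = np.argsort(loss_a)[:num_keep]
    small_loss_b = np.argsort(loss_b)[:num_keep]

    # Cross update: A learns from B's selection and vice versa.
    return small_loss_b, small_loss_a
```

In a training loop, the returned index arrays would be used to slice the mini-batch before each network's gradient step; how the scripts in this repository organize that step may differ from this sketch.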


The code was developed and tested on macOS (python==3.7.x, CPU) and Ubuntu 18.04 (python==3.6.x, NVIDIA GeForce GTX 1080 Ti GPU with CUDA==10.0) with the following environment:

  • tensorflow==1.13.1 (>=1.8.0)
  • pytorch==1.1.0 (>=0.4.1)
  • numpy==1.14.6 (>=1.14.2)


On macOS

Install TensorFlow via:

$ pip3 install tensorflow==1.13.1

Install PyTorch via:

$ pip3 install torch torchvision

On Ubuntu

Install TensorFlow via:

$ pip3 install tensorflow==1.13.1  # CPU version
$ pip3 install tensorflow-gpu==1.13.1  # GPU version

Install PyTorch via:

# CPU version
$ pip3 install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl  
$ pip3 install https://download.pytorch.org/whl/cpu/torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl
# GPU version
$ pip3 install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp36-cp36m-linux_x86_64.whl
$ pip3 install https://download.pytorch.org/whl/cu100/torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl


Here is an example for TensorFlow:

$ python3 main_tf.py --dataset cifar10 --noise_type symmetric --noise_rate 0.5

Here is an example for PyTorch:

$ python3 main_th.py --dataset cifar10 --noise_type symmetric --noise_rate 0.5
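
To illustrate what the noise settings in these commands mean, here is a minimal NumPy sketch of how symmetric and pair-flipping label noise are commonly generated from a class transition matrix. It is an illustration under stated assumptions, not this repository's code: the function names transition_matrix and corrupt_labels are hypothetical, and the string "pair" is only a label used inside this sketch, not a confirmed value of --noise_type.

```python
import numpy as np


def transition_matrix(num_classes, noise_type, noise_rate):
    """Row c gives the probability of each noisy label for true class c."""
    if noise_type == "symmetric":
        # A label flips to any other class with equal probability.
        off_diagonal = noise_rate / (num_classes - 1)
        matrix = np.full((num_classes, num_classes), off_diagonal)
        np.fill_diagonal(matrix, 1.0 - noise_rate)
    elif noise_type == "pair":  # hypothetical name used only in this sketch
        # A label flips only to the next class (wrapping around).
        matrix = (1.0 - noise_rate) * np.eye(num_classes)
        for c in range(num_classes):
            matrix[c, (c + 1) % num_classes] = noise_rate
    else:
        raise ValueError("unknown noise type: %s" % noise_type)
    return matrix


def corrupt_labels(labels, matrix, seed=0):
    """Sample one noisy label per clean label from the transition matrix."""
    rng = np.random.RandomState(seed)
    num_classes = len(matrix)
    return np.array([rng.choice(num_classes, p=matrix[y]) for y in labels])
```

For example, corrupt_labels(clean_labels, transition_matrix(10, "symmetric", 0.5)) would replace roughly half of the labels of a 10-class dataset with a different class chosen uniformly at random.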


Performance on benchmark datasets as reported by the authors of the paper:

(Flipping, Rate)    MNIST     CIFAR-10    CIFAR-100
(Pair, 45%)         87.58%    72.85%      34.40%
(Symmetry, 50%)     91.68%    74.49%      41.23%
(Symmetry, 20%)     97.71%    82.18%      54.36%

Performance on benchmark datasets obtained with the code in this repository:

Here, th denotes PyTorch and tf denotes TensorFlow.

(Flipping, Rate)    MNIST (th -- tf)    CIFAR-10 (th -- tf)    CIFAR-100 (th -- tf)
(Pair, 45%)         88.63% -- 94.16%    72.88% -- 76.04%       34.05% -- 35.24%
(Symmetry, 50%)     92.34% -- 98.05%    74.56% -- 79.64%       41.17% -- 49.09%
(Symmetry, 20%)     97.84% -- 99.16%    82.87% -- 87.02%       54.11% -- 59.55%

The model structure and parameter settings of the TensorFlow version are almost the same as those of the PyTorch version, yet the TensorFlow version generally performs better. I think this may be caused by differences in the internal implementation of some functions between the two frameworks.