[ICLR 2024] SemiReward: A General Reward Model for Semi-supervised Learning

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

Siyuan Li*,1,2, Weiyang Jin*,1, Zedong Wang1,2, Fang Wu1,2, Zicheng Liu1,2, Cheng Tan1,2, Stan Z. Li†,1

1 Westlake University, 2 Zhejiang University

The Semi-supervised Reward framework (SemiReward) predicts reward scores to evaluate pseudo labels and select the high-quality ones, and it can be plugged into mainstream Semi-Supervised Learning (SSL) methods across a wide range of task types and scenarios. Results and details are reported in our paper. The implementations and models of SemiReward are based on the USB codebase. USB is a PyTorch-based Python package for SSL. It is easy to use and extend, affordable for small groups, and comprehensive for developing and evaluating SSL algorithms. USB provides implementations of 14 SSL algorithms based on consistency regularization and 15 evaluation tasks from the CV, NLP, and audio domains. More details can be found in Semi-supervised Learning.

Table of Contents
  1. Introduction
  2. News and Updates
  3. Getting Started
  4. Usage
  5. Contributing
  6. License
  7. Citation
  8. Acknowledgments
  9. Contribution and Contact

Introduction

Semi-supervised learning (SSL) has witnessed great progress with various improvements in the self-training framework with pseudo labeling. The main challenge is distinguishing high-quality pseudo labels while guarding against confirmation bias. However, existing pseudo-label selection strategies are limited to pre-defined schemes or complex hand-crafted policies specially designed for classification, and they fail to achieve high-quality labels, fast convergence, and task versatility simultaneously. To this end, we propose a Semi-supervised Reward framework (SemiReward) that predicts reward scores to evaluate pseudo labels and filter for high-quality ones, and that can be plugged into mainstream SSL methods across a wide range of task types and scenarios. To mitigate confirmation bias, SemiReward is trained online in two stages with a generator model and a subsampling strategy. On classification and regression tasks across 13 standard SSL benchmarks spanning three modalities, extensive experiments verify that SemiReward achieves significant performance gains and faster convergence over Pseudo Label, FlexMatch, and Free/SoftMatch.
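The reward-based selection step described above can be pictured with a small, framework-free sketch. It assumes a reward model is simply any callable that maps a (sample, pseudo label) pair to a score. The names filter_by_reward and cosine_similarity, the batch-mean threshold, and the "oracle" rewarder built from known labels are illustrative assumptions, not SemiReward's actual API or selection rule; the two-stage online training of the rewarder with a generator and subsampling is not shown.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine similarity between two label vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def filter_by_reward(
    samples: Sequence[object],
    pseudo_labels: Sequence[Vector],
    reward_fn: Callable[[object, Vector], float],
    threshold: Optional[float] = None,
) -> Tuple[List[int], List[float]]:
    """Score every (sample, pseudo label) pair and keep the high-reward ones.

    If no threshold is given, the batch-mean reward is used. This rule is an
    assumption made for the sketch, not necessarily SemiReward's.
    """
    rewards = [reward_fn(x, y) for x, y in zip(samples, pseudo_labels)]
    if not rewards:
        return [], []
    if threshold is None:
        threshold = sum(rewards) / len(rewards)
    kept = [i for i, r in enumerate(rewards) if r >= threshold]
    return kept, rewards


# Hypothetical stand-in for a trained rewarder: compare each pseudo label with
# a known one-hot label (only possible on toy data).
ground_truth = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]}


def oracle(idx: object, label: Vector) -> float:
    return cosine_similarity(label, ground_truth[idx])


pseudo = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]]
kept, rewards = filter_by_reward([0, 1, 2], pseudo, oracle)
print(kept)  # [0, 2]
```

In a real pipeline the oracle would be replaced by a learned network, and only the kept indices would contribute to the unsupervised loss.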

News and Updates

  • [01/16/2024] SemiReward v0.2.0 has been released, and the paper has been accepted to ICLR 2024.
  • [10/18/2023] SemiReward v0.1.0 has been released.

Getting Started

First, you need to set up USB locally. To get a local copy up and running, follow these simple steps.

Prerequisites

USB is built on PyTorch, with torchvision, torchaudio, and transformers.

To install the required packages, you can create a conda environment:

conda create --name semireward python=3.8
conda activate semireward

Then use pip to install the required packages:

pip install -r requirements.txt
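To confirm that the prerequisites listed above are importable in the active environment, a small standard-library check can help. This is a sketch, assuming the import names match the package names given in Prerequisites; check_packages is a hypothetical helper, not part of USB, and it only locates modules without importing them.

```python
import importlib.util
from typing import Dict, Iterable

# Packages named in the Prerequisites section (import names match pip names here).
REQUIRED = ("torch", "torchvision", "torchaudio", "transformers")


def check_packages(names: Iterable[str] = REQUIRED) -> Dict[str, bool]:
    """Return {package: importable?} by locating each module without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_packages().items():
        print(f"{name:12s} {'OK' if ok else 'MISSING'}")
```

Any package reported as MISSING suggests that the pip step ran in a different environment from the one you activated.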

You can then start using USB by running:

python train.py --c config/usb_cv/fixmatch/fixmatch_cifar100_200_0.yaml

Installation

USB provides a Python package, semilearn, for users who want to quickly start training and testing the supported SSL algorithms on their own data:

pip install semilearn

You can also develop your own SSL algorithm and evaluate it by cloning SemiReward (USB):

git clone https://github.com/Westlake-AI/SemiReward.git

Prepare Datasets

Detailed instructions for downloading and processing the datasets are given in Dataset Download. Please follow them to download the datasets before running or developing algorithms.

Usage

Start with Docker

The steps for training your own SemiReward model with Docker are the same as for USB.

Step 1: Check your environment

First, properly install Docker and the NVIDIA driver. To use a GPU in a Docker container, you also need to install nvidia-docker2 (Installation Guide). Then, check your CUDA version via nvidia-smi.
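If you want to read the CUDA version programmatically (for example, in a setup script), a minimal sketch is to parse the "CUDA Version: X.Y" field that nvidia-smi prints in its header. The helper names parse_cuda_version and query_cuda_version are hypothetical, and the regular expression assumes the usual header layout.

```python
import re
import shutil
import subprocess
from typing import Optional

# nvidia-smi prints a header line such as "... CUDA Version: 11.6 ...".
_CUDA_RE = re.compile(r"CUDA Version:\s*(\d+\.\d+)")


def parse_cuda_version(nvidia_smi_output: str) -> Optional[str]:
    """Extract the CUDA version string from nvidia-smi output, if present."""
    match = _CUDA_RE.search(nvidia_smi_output)
    return match.group(1) if match else None


def query_cuda_version() -> Optional[str]:
    """Run nvidia-smi (if it is on PATH) and return the reported CUDA version."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    return parse_cuda_version(result.stdout)


if __name__ == "__main__":
    print(query_cuda_version())
```

Note that nvidia-smi reports the highest CUDA version the installed driver supports, which is the relevant upper bound when choosing a base image.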

Step 2: Clone the project

git clone https://github.com/microsoft/Semi-supervised-learning.git

Step 3: Build the Docker image

Before building the image, you may need to modify the Dockerfile according to your CUDA version. The CUDA version we use is 11.6. You can change the base image tag according to this site. You also need to change the --extra-index-url according to your CUDA version in order to install the correct version of PyTorch. You can find the correct URL on the PyTorch website.
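PyTorch publishes CUDA-specific wheels under index URLs of the form https://download.pytorch.org/whl/cu<major><minor> (for example, cu116 for CUDA 11.6). A minimal sketch for deriving the candidate --extra-index-url value from a CUDA version string follows; pytorch_extra_index_url is a hypothetical helper, and not every CUDA version has a matching index, so confirm the result on the PyTorch website.

```python
def pytorch_extra_index_url(cuda_version: str) -> str:
    """Map a CUDA version such as '11.6' to a PyTorch wheel index URL.

    Assumes the cu<major><minor> naming used by download.pytorch.org; only the
    first two version components are kept (e.g. '12.1.1' -> 'cu121').
    """
    major, minor = cuda_version.split(".")[:2]
    return f"https://download.pytorch.org/whl/cu{int(major)}{int(minor)}"


print(pytorch_extra_index_url("11.6"))  # https://download.pytorch.org/whl/cu116
```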

Use this command to build the image:

cd Semi-supervised-learning && docker build -t semilearn .

Job done. You can use the image you just built for your own project. Don't forget to pass the --gpus argument to docker run (e.g., --gpus all) when you want to use a GPU in a container.
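As an illustration of starting a GPU-enabled container from the semilearn image, the sketch below only assembles a docker run argument list. The container mount point /workspace and the helper name docker_run_command are hypothetical choices, not paths defined by the USB Dockerfile; --rm, -it, --gpus, and -v are standard docker run options.

```python
import shlex
from typing import List, Optional


def docker_run_command(image: str = "semilearn",
                       host_dir: Optional[str] = None,
                       container_dir: str = "/workspace",
                       gpus: str = "all") -> List[str]:
    """Build a `docker run` argument list that exposes GPUs to the container."""
    cmd = ["docker", "run", "--rm", "-it", "--gpus", gpus]
    if host_dir is not None:
        # Bind-mount a host folder (e.g. your datasets) into the container.
        cmd += ["-v", f"{host_dir}:{container_dir}"]
    cmd.append(image)
    return cmd


print(shlex.join(docker_run_command(host_dir="/data/ssl")))
# docker run --rm -it --gpus all -v /data/ssl:/workspace semilearn
```

To actually launch the container, pass the list to subprocess.run(cmd, check=True); building the list separately keeps the command easy to inspect before running it.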

Training

Here is an example of training one of the baselines, FlexMatch, on CIFAR-100 with 200 labels. Other supported algorithms, datasets, and label settings can be trained by specifying a different config file:

python train.py --c config/usb_cv/flexmatch/flexmatch_cifar100_200_0.yaml

Here is an example of training FlexMatch with SemiReward on CIFAR-100 with 200 labels. Other baselines can be trained with SemiReward by specifying the corresponding config file:

python train.py --c config/SemiReward/usb_cv/flexmatch/flexmatch_cifar100_200_0.yaml
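The two example commands differ only in the config path. A small sketch for building or parsing such paths follows, assuming the naming scheme visible in the examples, config/[SemiReward/]<domain>/<algorithm>/<algorithm>_<dataset>_<num_labels>_<seed>.yaml, where the trailing number is taken to be a seed index and algorithm names contain no underscores. ConfigSpec, config_path, and parse_config_name are hypothetical helpers, not part of USB.

```python
from pathlib import PurePosixPath
from typing import NamedTuple


class ConfigSpec(NamedTuple):
    algorithm: str
    dataset: str
    num_labels: int
    seed: int  # assumed meaning of the trailing number in the file name


def config_path(spec: ConfigSpec, semireward: bool = False, domain: str = "usb_cv") -> str:
    """Build a config path following the layout of the two examples above."""
    root = PurePosixPath("config")
    if semireward:
        root /= "SemiReward"
    name = f"{spec.algorithm}_{spec.dataset}_{spec.num_labels}_{spec.seed}.yaml"
    return str(root / domain / spec.algorithm / name)


def parse_config_name(filename: str) -> ConfigSpec:
    """Invert the scheme: 'flexmatch_cifar100_200_0.yaml' -> ConfigSpec."""
    stem = PurePosixPath(filename).stem
    algorithm, rest = stem.split("_", 1)  # assumes no '_' in algorithm names
    dataset, num_labels, seed = rest.rsplit("_", 2)  # dataset may contain '_'
    return ConfigSpec(algorithm, dataset, int(num_labels), int(seed))


spec = ConfigSpec("flexmatch", "cifar100", 200, 0)
print(config_path(spec, semireward=True))
# config/SemiReward/usb_cv/flexmatch/flexmatch_cifar100_200_0.yaml
```

The resulting string can be passed straight to the training script, e.g. subprocess.run(["python", "train.py", "--c", config_path(spec, semireward=True)]).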

You can change SemiReward's hyperparameters through its configuration (.yaml) files, as with the other baselines. If you want to change the loss or other components that are fixed in our method, we recommend editing the algorithm file directly, for example:

semilearn/algorithms/srflexmatch/srflexmatch.py
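For orientation before editing the loss: in FixMatch-family methods such as FlexMatch, the unsupervised term is typically a cross-entropy between predictions on strongly augmented views and the pseudo labels, multiplied by a per-sample selection mask. The framework-free sketch below illustrates only that general form; it does not reproduce srflexmatch.py, and the function name and the choice to normalize by the full batch size are assumptions.

```python
import math
from typing import Sequence


def masked_pseudo_label_loss(
    strong_probs: Sequence[Sequence[float]],
    pseudo_labels: Sequence[int],
    mask: Sequence[float],
    eps: float = 1e-12,
) -> float:
    """Batch mean of mask_i * cross-entropy(strong_probs_i, pseudo_labels_i).

    Samples with mask 0 (e.g. rejected by a reward or confidence filter) add
    nothing to the sum but still count in the denominator.
    """
    if not pseudo_labels:
        return 0.0
    total = 0.0
    for probs, label, m in zip(strong_probs, pseudo_labels, mask):
        # Clamp to eps so a zero probability does not produce log(0).
        total += m * -math.log(max(probs[label], eps))
    return total / len(pseudo_labels)


loss = masked_pseudo_label_loss([[0.8, 0.2], [0.5, 0.5]], [0, 1], [1.0, 0.0])
print(round(loss, 4))  # 0.1116
```

Swapping the mask source (for example, from a confidence threshold to reward scores) changes which samples are trained on without touching the rest of the loss.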

Tips: SemiReward uses 4 GPUs for training by default. In addition, users in some areas of China may find Hugging Face blocked, so locally stored pre-trained weights must be used for the BERT and HuBERT models. Taking BERT as an example, go to line 102 of ./semilearn/datasets/collactors/nlp_collactor.py and change the model address to your local BERT folder. Line 13 of ./semilearn/nets/bert/bert.py needs to be adjusted in the same way.
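One way to make the edits above switchable is to resolve the model name to a local folder when one exists and fall back to the Hub ID otherwise; transformers' from_pretrained accepts either a Hub model ID or a local directory. The helper resolve_pretrained, the one-sub-folder-per-model layout, and the example IDs are assumptions for illustration, not code from the repository.

```python
from pathlib import Path
from typing import Optional


def resolve_pretrained(hub_id: str, local_root: Optional[str] = None) -> str:
    """Return a local weights folder if it exists, otherwise the Hub model ID.

    `local_root` is a hypothetical directory holding one sub-folder per model,
    e.g. <local_root>/bert-base-uncased containing config.json and weights.
    """
    if local_root:
        candidate = Path(local_root) / hub_id
        if candidate.is_dir():
            return str(candidate)
    return hub_id
```

At the lines mentioned in the tip, the model name passed to from_pretrained would then become something like resolve_pretrained("bert-base-uncased", "/path/to/local/models").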

Evaluation

After training, you can check the evaluation performance in the training logs, or run the evaluation script:

python eval.py --dataset cifar100 --num_classes 100 --load_path /PATH/TO/CHECKPOINT
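USB-style SSL benchmarks are commonly reported as top-1 error rates. If you export predictions and want to compute that metric yourself, a minimal sketch is shown below; it does not reproduce eval.py, and top1_error is a hypothetical helper.

```python
from typing import Sequence


def top1_error(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Top-1 error rate in percent: share of samples whose arg-max != target."""
    if len(logits) != len(targets) or not targets:
        raise ValueError("logits and targets must be non-empty and equally long")
    wrong = sum(
        max(range(len(row)), key=row.__getitem__) != t
        for row, t in zip(logits, targets)
    )
    return 100.0 * wrong / len(targets)


print(top1_error([[2.0, 0.1], [0.3, 0.9], [1.5, 1.0]], [0, 1, 1]))
```

Here the third sample is misclassified, so the error is one third of the batch, about 33.3%.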

Develop

Check the development documentation to create your own SSL algorithm!

For more examples, please refer to the Documentation.

Contributing

If you have any ideas to improve SemiReward, we welcome your contributions! Feel free to fork the repository and submit a pull request. Alternatively, you can open an issue and label it as "enhancement." Don't forget to show your support by giving the project a star! Thank you!

  1. Fork the project
  2. Create your branch (git checkout -b your_name/your_branch)
  3. Commit your changes (git commit -m 'Add some features')
  4. Push to the branch (git push origin your_name/your_branch)
  5. Open a Pull Request

License

Distributed under the MIT License. See LICENSE.txt for more information.

Citation

Please consider citing us if you find this project helpful for your project/paper:

@inproceedings{iclr2024semireward,
  title={SemiReward: A General Reward Model for Semi-supervised Learning},
  author={Siyuan Li and Weiyang Jin and Zedong Wang and Fang Wu and Zicheng Liu and Cheng Tan and Stan Z. Li},
  booktitle={International Conference on Learning Representations},
  year={2024}
}

Acknowledgments

SemiReward's implementation is mainly based on the following codebases. We gratefully thank the authors for their wonderful work:

Contribution and Contact

To request new features, ask for help, or report bugs related to SemiReward, please open a GitHub issue or pull request with the tag "new features" or "help wanted". Feel free to contact us by email if you have any questions.