PyTorch-GAAL

This is a PyTorch implementation of Generative Adversarial Active Learning (GAAL).

Directory Structure

———— PyTorch-GAAL
 |__ data               # data directory
 |__ gan                # trains and saves the DCGAN model
 |__ oracle             # pre-trained models that act as the human oracle
 |__ plot
 |__ main.py            # generates the command to run train.py
 |__ train.py           # main training loop
 |__ utils.py           # utility functions
 |__ requirements.txt   # auto-generated dependencies file; usage: pip install -r requirements.txt
 |__ README.md

Usage

Edit the parameters (dataset, label budget, etc.) in main.py, then:

python main.py
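As a rough illustration of the launcher pattern described above (main.py generating the command that runs train.py), the following sketch turns a dictionary of settings into a command line. The flag names `--dataset` and `--budget` and the helper `build_train_command` are hypothetical and are not taken from this repository; check main.py and train.py for the real parameter names.

```python
import shlex
import sys


def build_train_command(params, script="train.py"):
    """Turn a dict of settings into an argv list for a training script.

    Flag names are derived from the dict keys and are hypothetical;
    they are not taken from this repository's train.py.
    """
    argv = [sys.executable, script]
    for name, value in params.items():
        argv += [f"--{name}", str(value)]
    return argv


# Hypothetical settings; the actual parameter names live in main.py.
params = {"dataset": "mnist", "budget": 100}
command = build_train_command(params)

# Show the command without the interpreter path.
print(shlex.join(command[1:]))
```

A launcher built this way could then start training with `subprocess.run(command, check=True)`, which raises an error if the training script exits with a non-zero status.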

My environment (for reference):

Python 3.9.7 + PyTorch 1.9.0 + torchvision 0.10.0 + cudatoolkit 10.2
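To compare a local setup against the versions listed above, a small check along these lines may help. It is only a sketch: the helper name `compare_versions` is made up for illustration, and it assumes the packages are installed under the distribution names `torch` and `torchvision`. Installed builds may also carry a local suffix such as a CUDA tag, so the script only reports versions and does not enforce an exact match.

```python
from importlib import metadata

# Versions listed in this README's environment section.
REFERENCE = {"torch": "1.9.0", "torchvision": "0.10.0"}


def compare_versions(reference):
    """Return {package: (installed_version_or_None, reference_version)}."""
    report = {}
    for package, wanted in reference.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed in this environment
        report[package] = (installed, wanted)
    return report


for package, (installed, wanted) in compare_versions(REFERENCE).items():
    print(f"{package}: installed={installed}, README lists {wanted}")
```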

Experiment results

  1. Train on MNIST, test on MNIST, classifying digits 5 and 7: Results

  2. Train on MNIST, test on USPS, classifying digits 5 and 7: Results
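Both experiments reduce a ten-class digit dataset to a two-class problem (5 versus 7). A minimal sketch of that filtering step, written in plain Python so it runs without any dataset download, could look like this. The helper `binary_subset` and the 1/0 target encoding are assumptions for illustration and may differ from what the repository does.

```python
def binary_subset(labels, positive=5, negative=7):
    """Return (indices, targets) for samples labelled with one of two digits.

    targets holds 1 for `positive` and 0 for `negative`; this encoding is
    an assumption for illustration, not necessarily the repository's.
    """
    indices, targets = [], []
    for i, label in enumerate(labels):
        if label == positive:
            indices.append(i)
            targets.append(1)
        elif label == negative:
            indices.append(i)
            targets.append(0)
    return indices, targets


labels = [5, 0, 7, 7, 3, 5]
indices, targets = binary_subset(labels)
print(indices, targets)  # [0, 2, 3, 5] [1, 0, 0, 1]
```

With a torchvision dataset, the label list would typically come from something like `dataset.targets.tolist()`, and the returned indices could be passed to `torch.utils.data.Subset` to obtain the two-digit subset.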

To be updated

  • Experiment results on CIFAR-10
  • Comparing with the SVMactive algorithm

References

[1] PyTorch DCGAN Tutorial

[2] DCGAN on CIFAR-10