/data_halucination_teaching

Implementation for <Iterative Teaching by Data Hallucination> in AISTATS'23


Iterative Teaching by Data Hallucination


Table of Contents:
- Getting Started
- Usage
- Citation

About The Project

This is the official GitHub repository for our work, Iterative Teaching by Data Hallucination.

We consider the problem of iterative machine teaching, where a teacher sequentially provides examples based on the status of a learner under a discrete input space (i.e., a pool of finite samples), which greatly limits the teacher’s capability. To address this issue, we study iterative teaching under a continuous input space where the input example (i.e., image) can be either generated by solving an optimization problem or drawn directly from a continuous distribution. Specifically, we propose data hallucination teaching (DHT) where the teacher can generate input data intelligently based on labels, the learner’s status and the target concept. We study a number of challenging teaching setups (e.g., linear/neural learners in omniscient and black-box settings).


Getting Started

To get a local copy up and running, follow these steps.

Data

We evaluate our data hallucination teaching framework on common image classification datasets, including MNIST, CIFAR10 and CIFAR100. Create a folder named data in the root directory; the datasets will be downloaded automatically to the repository the first time training is executed.

Installation

The steps below clone the repository and set up a conda environment with the required packages.

  1. Clone the repo
     git clone https://github.com/Zeju1997/data_halucination_teaching.git
  2. Install required packages
     conda create -n dht python=3.6
     conda activate dht
     conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 -c pytorch
     pip install imageio scikit-image scikit-learn matplotlib seaborn pyyaml easydict tensorboard tensorboardX tqdm opencv-python mathutils==2.81.2 
  3. Create the environment from the yml file
     conda env create -f environment.yml
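
After setup, a quick sanity check can confirm that the packages listed above are importable in the active environment. The following is a minimal sketch and not part of the repository: the helper name `missing_packages` is hypothetical, and the list uses the standard import names, which can differ from the package names (for example `skimage` for scikit-image, `sklearn` for scikit-learn, `cv2` for opencv-python, `yaml` for pyyaml).

```python
import importlib.util

# Import names for the packages installed in step 2. Import names can differ
# from the pip/conda package names (e.g. scikit-image -> skimage).
REQUIRED_MODULES = [
    "torch", "torchvision", "torchaudio",
    "imageio", "skimage", "sklearn", "matplotlib", "seaborn",
    "yaml", "easydict", "tensorboard", "tensorboardX", "tqdm",
    "cv2", "mathutils",
]


def missing_packages(modules):
    """Return the top-level module names that cannot be found (hypothetical helper)."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```

Run the script inside the activated `dht` environment. It only checks that each module can be located; it does not verify package versions.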


Usage

Reproducing the results from the paper is straightforward: run the desired experiment and specify a suitable dataset.

  1. Run a teaching policy with its corresponding config file. Note that the config file must match the teaching policy!
     python train.py --teaching_policy='omniscient_unrolled_mnist' --config='mnist_omniscient_unrolled'
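
One way to keep the policy and config in sync is to build both arguments from the same two pieces. The sketch below is an illustration under a loud assumption and is not part of the repository: it infers from the single example command above that policy names look like `<setting>_<dataset>` and config names like `<dataset>_<setting>`. Other policies may not follow this pattern, so check the available config files first. The helper `build_train_command` is hypothetical.

```python
import shlex


def build_train_command(dataset, setting):
    """Build the train.py argument list for one dataset/setting pair.

    Assumption (inferred only from the one example in this README):
    teaching policy = "<setting>_<dataset>", config = "<dataset>_<setting>".
    """
    policy = f"{setting}_{dataset}"
    config = f"{dataset}_{setting}"
    return [
        "python", "train.py",
        f"--teaching_policy={policy}",
        f"--config={config}",
    ]


if __name__ == "__main__":
    cmd = build_train_command("mnist", "omniscient_unrolled")
    # Print a copy-pasteable command line instead of executing it.
    print(" ".join(shlex.quote(part) for part in cmd))
```

For `mnist` and `omniscient_unrolled` this reproduces the arguments of the example command above; the returned list can also be passed to `subprocess.run` from the repository root.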


License

Distributed under the MIT License. See LICENSE.txt for more information.


Citation

@InProceedings{Qiu2023DHT,
    title={Iterative Teaching by Data Hallucination},
    author={Qiu, Zeju and Liu, Weiyang and Xiao, Tim Z and Liu, Zhen 
      and Bhatt, Umang and Luo, Yucen and Weller, Adrian and Schölkopf, Bernhard},
    booktitle = {AISTATS},
    year={2023}
}


Acknowledgments

Our repository is built upon the following projects:
