
DIVA: A Dirichlet Process Mixtures Based Incremental Deep Clustering Algorithm via Variational Auto-Encoder



License: MIT | Python 3.7+

Official implementation for paper: DIVA: A Dirichlet Process Based Incremental Deep Clustering Algorithm via Variational Auto-Encoder

A demo video showing DIVA's dynamic adaptation ability in deep clustering: Demo Video

Requirements

We use Python 3.7 and pytorch-lightning for training. Before starting training, make sure you have installed the bnpy package in your local environment; refer to here for more details.

  • python 3.7
  • bnpy 1.7.0
  • pytorch-lightning 1.9.4
  • numpy, pandas, matplotlib, seaborn, torchvision

Installation Instructions

# Install dependencies and package
pip3 install -r requirements.txt

Detailed Code Structure Overview

DIVA
  |- dataset                    # folder for saving datasets
  |    |- reuters10k.py         # dataset instance of reuters10k that follows torchvision formatting
  |    |- reuters10k.mat        # origin data of reuters10k
  |- pretrained                 # folder for saving pretrained example model on MNIST
  |    |- dpmm                  # folder for saving DPMM cluster module
  |    |- diva_vae.ckpt         # checkpoint file of trained DIVA VAE part on MNIST with 100 epochs and ACC 0.91
  |    |- pretrained.ipynb      # example file how to load pretrained model
  |- diva.py                    # diva implementations for image and text; train manager
  |- main_mnist.ipynb           # main entry point of diva training on MNIST, including evaluation plots.
  |- main_stl10.ipynb           # main entry point of diva training on STL-10.
  |- main_imagenet50.ipynb      # main entry point of diva training on ImageNet-50.
  |- feature_extraction.ipynb   # script that uses a pretrained ResNet-50 to extract features from STL-10.

Dataset Notes

Since training on the raw images of STL-10 and ImageNet-50 is quite difficult, we use feature extractors to obtain low-dimensional encodings of these datasets. For STL-10, we use the pretrained ResNet-50 provided by torchvision; just follow the script feature_extraction.ipynb to obtain the features used in our study. For ImageNet-50, we use MoCo to extract features; for more details, refer to here and here.

Load pretrained DPMM clustering module

# load the DPMM module saved by bnpy
import bnpy
import numpy as np
import torch

dpmm_model = bnpy.ioutil.ModelReader.load_model_at_prefix('path/to/your/bn_model/folder/dpmm', prefix="Best")

# function for getting the cluster parameters (mean and variance per component)
def calc_cluster_component_params(bnp_model):
    comp_mu = [torch.Tensor(bnp_model.obsModel.get_mean_for_comp(i))
               for i in np.arange(0, bnp_model.obsModel.K)]
    comp_var = [torch.Tensor(np.sum(bnp_model.obsModel.get_covar_mat_for_comp(i), axis=0))
               for i in np.arange(0, bnp_model.obsModel.K)]
    return comp_mu, comp_var
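Once the per-component means and variances are extracted, they can be used, for example, to assign a latent code to its most likely cluster. The sketch below is a hypothetical helper (assign_cluster is not part of this repo) that assumes diagonal-Gaussian components and uses toy parameters in place of a loaded DPMM model:

```python
import math
import torch

def assign_cluster(z, comp_mu, comp_var):
    # Diagonal-Gaussian log-likelihood of latent z under each component;
    # returns the index of the most likely cluster.
    log_probs = []
    for mu, var in zip(comp_mu, comp_var):
        ll = -0.5 * torch.sum((z - mu) ** 2 / var + torch.log(2 * math.pi * var))
        log_probs.append(ll)
    return int(torch.argmax(torch.stack(log_probs)))

# toy example with two hypothetical 2-d components
comp_mu = [torch.tensor([0.0, 0.0]), torch.tensor([5.0, 5.0])]
comp_var = [torch.ones(2), torch.ones(2)]
print(assign_cluster(torch.tensor([4.8, 5.1]), comp_mu, comp_var))  # 1
```

In DIVA itself the component parameters come from calc_cluster_component_params applied to the loaded bnpy model rather than from hand-written tensors.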

Citation

If you would like to refer to our work, please use the following BibTeX-formatted citation:

@misc{bing2023diva,
      title={DIVA: A Dirichlet Process Based Incremental Deep Clustering Algorithm via Variational Auto-Encoder}, 
      author={Zhenshan Bing and Yuan Meng and Yuqi Yun and Hang Su and Xiaojie Su and Kai Huang and Alois Knoll},
      year={2023},
      eprint={2305.14067},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}