Code for Gradient-Based Training of Gaussian Mixture Models for High-Dimensional Streaming Data
Tested with Python 3.7+ and tensorflow-gpu 1.14 (Linux and Windows)
Additional packages besides Python3 standard lib: numpy, matplotlib, scipy
Can all be installed via pip, e.g. python3 -m pip install numpy
- Install python3
- Install tensorflow (with pip)
- Install other dependencies (with pip)
- Start run with default parameters
python GMM.py
or python3 emAlgos.py --taskEpochs 1 --nrTasks 1
The module 'experimentdataset' provides download facilities for all datasets used in the paper. The dataset is selected with the command-line parameter '--dataset_file', which all programs in this archive accept. Valid arguments include: MNIST, FashionMNIST, NotMNIST, Devanagari, SVHN, Fruits, ISOLET (see the experimentdataset documentation for the full list). If a dataset is not already cached locally, it is downloaded automatically and cached, which can take a while for some datasets like Fruits or SVHN. For some datasets, a Kaggle account is required. MNIST and Devanagari are already included (cached) in this archive.
The main file for gradient-based GMM training is GMM.py. To change default parameters, pass command-line parameters. Use python3 GMM.py --help
to get an overview.
For example, to change the number of used Gaussian components: python3 GMM.py --K 100
Or work with FashionMNIST: python3 GMM.py --K 100 --dataset_file FashionMNIST
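To run several configurations in a row, the invocations above can be scripted. The following is a minimal sketch and not part of this archive: the helper names build_gmm_command and run_sweep are hypothetical, and only the flags --K and --dataset_file shown above are used.

```python
import subprocess
import sys


def build_gmm_command(k, dataset=None, script="GMM.py"):
    """Build the argument list for one GMM.py run (hypothetical helper)."""
    # sys.executable is the interpreter running this script, so the same
    # Python installation (and its packages) is used for the child process.
    cmd = [sys.executable, script, "--K", str(k)]
    if dataset is not None:
        cmd += ["--dataset_file", dataset]
    return cmd


def run_sweep(ks, dataset=None):
    """Run GMM.py once per value of K, stopping at the first failed run."""
    for k in ks:
        cmd = build_gmm_command(k, dataset)
        print("Running:", " ".join(cmd))
        # check=True raises CalledProcessError if GMM.py exits with an error.
        subprocess.run(cmd, check=True)
```

Called from the directory that contains GMM.py, run_sweep([25, 64, 100], dataset="FashionMNIST") would start three runs one after another.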
The main file for EM training is emAlgos.py, which can perform EM and sEM with a variety of options; see python3 emAlgos.py --help
For example, to change the number of used Gaussian components: python3 emAlgos.py --mode sEM --taskEpochs 1 --nrTasks 1 --n 100
Or work with FashionMNIST: python3 emAlgos.py --mode sEM --n 100 --taskEpochs 1 --nrTasks 1 --dataset_file FashionMNIST
GMM.py and emAlgos.py write the current GMM weights, centroids and precision matrices to pis.npy, mus.npy and sigmas.npy, respectively.
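The saved parameters can also be inspected directly with numpy. This is a minimal sketch under stated assumptions: the helpers load_gmm_params and describe are hypothetical, the files are assumed to be named mus.npy, pis.npy and sigmas.npy, and the array layouts are not documented here, so the sketch only reports shapes.

```python
import os

import numpy as np


def load_gmm_params(directory="."):
    """Load the saved GMM parameters from `directory` (hypothetical helper).

    Assumes the file names mus.npy, pis.npy and sigmas.npy.
    """
    params = {}
    for name in ("pis", "mus", "sigmas"):
        path = os.path.join(directory, name + ".npy")
        params[name] = np.load(path)
    return params


def describe(params):
    """Return a dict mapping each parameter name to its array shape."""
    return {name: arr.shape for name, arr in params.items()}
```

After a training run, print(describe(load_gmm_params())) lists the shape of each array; check these before relying on a particular axis order.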
Visualize centroids and pis with python3 vis.py
which generates a file mus.png that can be displayed.
Standard invocation to visualize centroids: python3 vis.py
See python3 vis.py --help
for options.
Both GMM.py and emAlgos.py generate .json files which contain all relevant information about an experiment. Most notably, they contain the log-likelihoods at various points in time, measured on all relevant datasets. The name of the log file is derived from the arguments '--tmp_dir' and '--exp_id'.
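Since the key names inside these log files are not listed here, a reasonable first step is to load one file and look at its top-level structure. The following is a minimal sketch: load_log and top_level_summary are hypothetical helpers, and nothing is assumed about the log contents beyond being valid JSON.

```python
import json


def load_log(path):
    """Read one experiment log stored as JSON."""
    with open(path, "r") as f:
        return json.load(f)


def top_level_summary(log):
    """Map each top-level key to the type name of its value."""
    if not isinstance(log, dict):
        # The root may be a list or a scalar; report its type instead.
        return {"<root>": type(log).__name__}
    return {key: type(value).__name__ for key, value in log.items()}
```

Calling top_level_summary(load_log(path)) on one of the generated files shows which entries are lists or nested dictionaries, which helps locate the log-likelihood values before plotting them.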
The files sem.bash and gmm.bash contain some sample invocations both for sEM and GMM from the streaming experiments in the paper.
Under Linux, you can execute these files by typing, e.g., source sem.bash
Log files from these experiments will be placed in the subdirectory 'ExpDist'.
The files sem_inc.bash and gmm_inc.bash contain some sample invocations both for sEM and GMM from the concept drift experiments in the paper.
Under Linux, you can execute these files by typing, e.g., source sem_inc.bash
Log files from these experiments will be placed in the subdirectory 'ExpDist'.