Original Paper: link
- Running inside a virtual environment is recommended.
pip install -r requirements.txt
Go to the Kaggle MNIST Dataset page and download the dataset archive.
Extract the archive to get the mnist.mat data file.
unzip archive.zip
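If `unzip` is not available (for example on Windows), the same extraction step can be sketched in Python with the standard-library `zipfile` module. The helper name `extract_dataset` and its default arguments are illustrative assumptions; only `archive.zip` and `mnist.mat` come from the steps above.

```python
import zipfile
from pathlib import Path


def extract_dataset(archive="archive.zip", out_dir=".", expected="mnist.mat"):
    """Extract `archive` into `out_dir` and return the path to `expected`.

    Hypothetical helper: the name and defaults are assumptions for illustration.
    """
    out_dir = Path(out_dir)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(out_dir)
    # The data file may sit at the top level or inside a subfolder of the archive.
    matches = sorted(out_dir.rglob(expected))
    if not matches:
        raise FileNotFoundError(f"{expected} not found after extracting {archive}")
    return matches[0]
```

Calling `extract_dataset()` from the directory that holds `archive.zip` should return the location of `mnist.mat`, which can then be passed to `train.py` via `-d`.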
usage: python train.py [-h] -d DATA [-hd HIDDEN] [-ld LATENT] [-lr LEARNING] [-e EPOCHS] [-b BATCH_SIZE] [-m MODEL]
optional arguments:
  -h, --help            show this help message and exit
  -d DATA, --data DATA  path/to/train/data
  -hd HIDDEN, --hidden HIDDEN
                        number of hidden unit
  -ld LATENT, --latent LATENT
                        number of latent unit
  -lr LEARNING, --learning LEARNING
                        learning rate
  -e EPOCHS, --epochs EPOCHS
                        epochs
  -b BATCH_SIZE, --batch_size BATCH_SIZE
                        Batch size
  -m MODEL, --model MODEL
                        path/to/model/saving/location
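For reference, a parser that yields help text similar to the above could be sketched with `argparse` as follows. This is a reconstruction from the usage message, not the actual `train.py`: the function name `build_parser`, the argument types, and the absence of default values are all assumptions, and the values in the example call are placeholders.

```python
import argparse


def build_parser():
    # Hypothetical reconstruction from the help text; types are assumptions
    # and no defaults are set because the help text does not show any.
    parser = argparse.ArgumentParser(prog="python train.py")
    parser.add_argument("-d", "--data", required=True, help="path/to/train/data")
    parser.add_argument("-hd", "--hidden", type=int, help="number of hidden unit")
    parser.add_argument("-ld", "--latent", type=int, help="number of latent unit")
    parser.add_argument("-lr", "--learning", type=float, help="learning rate")
    parser.add_argument("-e", "--epochs", type=int, help="epochs")
    parser.add_argument("-b", "--batch_size", type=int, help="Batch size")
    parser.add_argument("-m", "--model", help="path/to/model/saving/location")
    return parser


# Parse an explicit argument list instead of sys.argv (placeholder values).
args = build_parser().parse_args(["-d", "mnist.mat", "-e", "10", "-lr", "0.001"])
print(args.data, args.epochs, args.learning)
```

Only `-d` is required, which matches the usage line where it is the one flag shown without square brackets; options that are not passed come back as `None` in this sketch.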
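import torch

# torch.load unpickles the full model object, so the model class must be
# defined or imported in this script before loading.
model = torch.load("path/to/model/file/located")
model.eval()  # switch to inference mode (affects layers such as dropout)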