A simple PyTorch classifier using a custom, modified VGG16 architecture.
Tutorial video (in Czech): https://youtu.be/9U-RQQJmE6E
- conda env create --file environment.yml
- conda activate pytorch_classifier
- Download the pretrained model from the uloz.to server here.
- Save it into the ./checkpoints folder.
- Run demo:
python demo.py
- Download the animals10 dataset from here.
- Save the dataset zip (archive.zip) into the dataset folder.
- Preprocess the dataset (this splits it into TRAIN and VAL folders):
python prepare_dataset.py
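For readers curious what a TRAIN/VAL split involves, here is a minimal sketch using only the standard library. It is not the code in prepare_dataset.py: the function name `split_dataset`, the 20% validation fraction, and the assumption that the unpacked dataset has one subfolder per class are all illustrative choices.

```python
import random
import shutil
from pathlib import Path


def split_dataset(src_dir, dst_dir, val_fraction=0.2, seed=0):
    """Copy images from src_dir into dst_dir/TRAIN and dst_dir/VAL.

    Assumption: src_dir holds one subfolder per class, each containing
    image files. Returns {class_name: (train_count, val_count)}.
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    counts = {}
    for class_dir in sorted(p for p in src_dir.iterdir() if p.is_dir()):
        files = sorted(f for f in class_dir.iterdir() if f.is_file())
        rng.shuffle(files)
        n_val = int(len(files) * val_fraction)
        splits = {"VAL": files[:n_val], "TRAIN": files[n_val:]}
        for split_name, split_files in splits.items():
            # Mirror the class folder under TRAIN/ or VAL/
            target = dst_dir / split_name / class_dir.name
            target.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy2(f, target / f.name)
        counts[class_dir.name] = (len(files) - n_val, n_val)
    return counts
```

Calling `split_dataset("./dataset/raw", "./dataset/data")` (the `raw` path is hypothetical) would produce a folder layout matching the `./dataset/data/TRAIN` and `./dataset/data/VAL` paths used in the training example.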
- Run training:
python train.py --train_dir [path/to/train/data/folder] \
--val_dir [path/to/val/data/folder] \
--checkpoint_dir [path/to/checkpoints/dir] \
--pretrained_model [path/to/pretrained/model.pkl]
Example
python train.py --train_dir ./dataset/data/TRAIN \
--val_dir ./dataset/data/VAL \
--checkpoint_dir ./checkpoints \
--pretrained_model ./checkpoints/vgg16_0048.pkl
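The flags in the commands above could be parsed with `argparse` roughly as follows. This is a sketch, not the parser in train.py: only the four flag names come from this README, while the function name `build_parser`, which flags are required, and the default values are assumptions.

```python
import argparse


def build_parser():
    """Build a parser for the four training flags shown in this README."""
    parser = argparse.ArgumentParser(description="Train the classifier (sketch).")
    # Assumption: both data folders are mandatory.
    parser.add_argument("--train_dir", required=True,
                        help="folder with training images")
    parser.add_argument("--val_dir", required=True,
                        help="folder with validation images")
    # Assumption: checkpoints default to ./checkpoints.
    parser.add_argument("--checkpoint_dir", default="./checkpoints",
                        help="where checkpoints are written")
    # Assumption: omitting this flag means training starts from scratch.
    parser.add_argument("--pretrained_model", default=None,
                        help="optional .pkl checkpoint to start from")
    return parser


# Parse the same arguments as the example command above.
args = build_parser().parse_args([
    "--train_dir", "./dataset/data/TRAIN",
    "--val_dir", "./dataset/data/VAL",
    "--checkpoint_dir", "./checkpoints",
    "--pretrained_model", "./checkpoints/vgg16_0048.pkl",
])
```

After parsing, each flag is available as an attribute, for example `args.train_dir` or `args.pretrained_model`.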
- Run a test on a single image:
python image_tests.py --image_path [path/to/img] \
--checkpoint_path [path/to/checkpoint.pkl]
Example
python image_tests.py --image_path ./demo/butterfly.jpg \
--checkpoint_path ./checkpoints/vgg16_0048.pkl