/PSCapsNet

A PyTorch implementation of Parameter-sharing Capsule Network based on the paper "Evaluating Generalization Ability of Convolutional Neural Networks and Capsule Networks for Image Classification via Top-2 Classification"

Primary LanguagePython

PSCapsNet

A PyTorch implementation of Parameter-sharing Capsule Network based on the paper Evaluating Generalization Ability of Convolutional Neural Networks and Capsule Networks for Image Classification via Top-2 Classification.

Requirements

conda install pytorch torchvision -c pytorch
  • PyTorchNet
pip install git+https://github.com/pytorch/tnt.git@master
  • capsule-layer
pip install git+https://github.com/leftthomas/CapsuleLayer.git@master
  • opencv
conda install opencv

Datasets

We have uploaded the datasets into BaiduYun(access code:tr4u). You needn't download the datasets by yourself, the code will download them automatically.

Usage

Train model

python -m visdom.server -logging_level WARNING & python main.py --data_type CIFAR10 --net_mode CNN --num_epochs 300
optional arguments:
--data_type                dataset type [default value is 'MNIST'](choices:['MNIST', 'FashionMNIST', 'SVHN', 'CIFAR10', 'STL10'])
--net_mode                 network mode [default value is 'Capsule'](choices:['Capsule', 'CNN'])
--capsule_type             capsule network type [default value is 'ps'](choices:['ps', 'fc'])
--routing_type             routing type [default value is 'k_means'](choices:['k_means', 'dynamic'])
--num_iterations           routing iterations number [default value is 3]
--batch_size               train batch size [default value is 64]
--num_epochs               train epochs number [default value is 100]
--use_da                   use data augmentation or not [default value is False]

Visdom now can be accessed by going to 127.0.0.1:8097/env/$data_type in your browser, $data_type means the dataset type which you are training. If you want to interrupt this process, just type ps aux | grep visdom to find the PID, then kill PID.

Visualization

python vis.py --data_type CIFAR10
optional arguments:
--data_type                dataset type [default value is 'STL10'](choices:['MNIST', 'FashionMNIST', 'SVHN', 'CIFAR10', 'STL10'])
--data_mode                visualized data mode [default value is 'test_single'](choices:['test_single', 'test_multi'])
--num_iterations           routing iterations number [default value is 3]

Generated results are on the same directory with README.md.

Generate figures

python utils.py

Generated figures are on the same directory with README.md.

Results

The train loss、accuracy, test loss、accuracy and confusion matrix are showed on visdom. The pretrained models and more results can be found in BaiduYun(access code:kvmj).

MNIST

  • PS-Capsule

result

  • FC-Capsule

result

  • CNN

result

FashionMNIST

  • PS-Capsule

result

  • FC-Capsule

result

  • CNN

result

SVHN

  • PS-Capsule

result

  • FC-Capsule

result

  • CNN

result

CIFAR10

  • PS-Capsule

result

  • FC-Capsule

result

  • CNN

result

STL10

  • PS-Capsule

result

  • FC-Capsule

result

  • CNN

result