
A TensorFlow implementation of CapsNet (Capsule Network) applied to the German traffic sign dataset

Primary language: Jupyter Notebook. License: Apache License 2.0.

Capsnet - Traffic sign classifier - Tensorflow



This implementation is based on the paper Dynamic Routing Between Capsules (https://arxiv.org/abs/1710.09829) by Sara Sabour, Nicholas Frosst, and Geoffrey E. Hinton.

This repository is a work-in-progress implementation of a Capsule Network. Because I am using a different dataset (not MNIST), some details of the architecture differ from the paper. The code for the CapsNet is located in caps_net.py, while the whole model is created in model.py. The two main methods used to build the CapsNet are conv_caps_layer and fully_connected_caps_layer.
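To make the paper's core idea concrete, here is a minimal NumPy sketch of the squash non-linearity and routing-by-agreement as described in Dynamic Routing Between Capsules. It is not the code from caps_net.py: the function names (squash, softmax, route), the tensor shapes, and the random inputs are assumptions chosen for illustration only.

```python
import numpy as np


def squash(s, axis=-1, eps=1e-8):
    """Squash non-linearity from the paper.

    Keeps the direction of s and maps its length into [0, 1).
    """
    sq_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    scale = sq_norm / (1.0 + sq_norm)
    return scale * s / np.sqrt(sq_norm + eps)


def softmax(x, axis):
    """Numerically stable softmax along the given axis."""
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def route(u_hat, iterations=3):
    """Routing-by-agreement between two capsule layers.

    u_hat: prediction vectors, shape (n_in, n_out, dim_out) (assumed layout).
    Returns the output capsules, shape (n_out, dim_out).
    """
    n_in, n_out, _ = u_hat.shape
    b = np.zeros((n_in, n_out))  # routing logits, start uniform
    for _ in range(iterations):
        c = softmax(b, axis=1)  # coupling coefficients over output capsules
        s = np.sum(c[:, :, None] * u_hat, axis=0)  # weighted sum per output capsule
        v = squash(s)
        # Agreement: dot product between each prediction and the current output
        b = b + np.sum(u_hat * v[None, :, :], axis=-1)
    return v


# Illustrative shapes only: 8 input capsules, 43 output capsules
# (the German traffic sign dataset has 43 classes), 16-D outputs.
rng = np.random.RandomState(0)
u_hat = rng.normal(size=(8, 43, 16))
v = route(u_hat)
lengths = np.linalg.norm(v, axis=-1)  # one length per class capsule
```

The length of each output capsule can be read as a score for the corresponding class, which is why squash keeps it below 1.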

Requirements

  • Python 3
  • NumPy 1.13.1
  • TensorFlow 1.3.0
  • docopt 0.6.2
  • scikit-learn 0.18.1
  • Matplotlib

Install

$> git clone https://github.com/thibo73800/capsnet_traffic_sign_classifier.git
$> cd capsnet_traffic_sign_classifier
$> wget https://d17h27t6h515a5.cloudfront.net/topher/2017/February/5898cd6f_traffic-signs-data/traffic-signs-data.zip
$> unzip traffic-signs-data.zip
$> mkdir dataset
$> mv *.p dataset/
$> rm traffic-signs-data.zip
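After these steps, the dataset/ folder contains the pickled .p files. As a quick sanity check, a pickled split can be inspected as sketched below. This is a sketch under assumptions, not code from the repository: the helper name load_split, the file name dataset/train.p, and the dictionary keys "features" and "labels" are all assumed, so check the actual file names and data.keys() before relying on them.

```python
import os
import pickle


def load_split(path):
    """Load one pickled split and return (features, labels).

    Assumes the file holds a dict with "features" and "labels" keys.
    """
    with open(path, "rb") as f:
        data = pickle.load(f)
    return data["features"], data["labels"]


if __name__ == "__main__":
    path = os.path.join("dataset", "train.p")  # assumed file name
    if os.path.exists(path):
        features, labels = load_split(path)
        print(len(features), "training examples")
```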

Train

$> python train.py -h
$> python train.py dataset/

During training, checkpoints are saved by default in the outputs/checkpoints/ folder. The exact path and name of each checkpoint are printed during training.

Test

To measure the accuracy and the loss on the test dataset, use the test.py script as follows:

$> python test.py outputs/checkpoints/ckpt_name dataset/ 
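For reference, the two quantities reported by a CapsNet evaluation can be sketched in NumPy as follows. This assumes the loss is the margin loss from the paper (with its default values m+ = 0.9, m- = 0.1, lambda = 0.5) and that the predicted class is the capsule with the greatest length. Whether test.py computes exactly this is an assumption; the paper's reconstruction loss is left out, and the function names and sample values are illustrative only.

```python
import numpy as np


def margin_loss(lengths, labels, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss from the paper, averaged over the batch.

    lengths: (batch, n_classes) lengths of the class capsules.
    labels:  (batch,) integer class ids.
    """
    n_classes = lengths.shape[1]
    t = np.eye(n_classes)[labels]  # one-hot targets
    present = t * np.square(np.maximum(0.0, m_plus - lengths))
    absent = lam * (1.0 - t) * np.square(np.maximum(0.0, lengths - m_minus))
    return np.mean(np.sum(present + absent, axis=1))


def accuracy(lengths, labels):
    """Fraction of examples whose longest capsule matches the label."""
    return np.mean(np.argmax(lengths, axis=1) == labels)


# Made-up capsule lengths for 3 examples and 3 classes
lengths = np.array([[0.95, 0.05, 0.10],
                    [0.20, 0.85, 0.30],
                    [0.60, 0.10, 0.40]])
labels = np.array([0, 1, 2])

loss = margin_loss(lengths, labels)
acc = accuracy(lengths, labels)  # 2 of the 3 examples are classified correctly
```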

Metrics / TensorBoard

Accuracy:

  • Train: 99%
  • Validation: 98%
  • Test: 97%

Checkpoints and TensorBoard files are stored in the outputs folder.

Examples of predictions: