Using Kaggle cats vs dogs dataset

Keras Image Classification

Classifies an image as containing either a dog or a cat (using Kaggle's public dataset), but could easily be extended to other image classification problems.

To run these scripts and notebooks, you must have keras, numpy, scipy, and h5py installed. Enabling GPU acceleration is highly recommended if it's available.
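As a quick sanity check before running anything, a small script along these lines can report which of the required packages are missing. The `missing_packages` helper is a hypothetical name used for illustration and is not part of this repo.

```python
import importlib.util

# Packages this README lists as requirements.
REQUIRED = ["keras", "numpy", "scipy", "h5py"]


def missing_packages(names):
    """Return the package names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```

This only checks that the packages can be located. It does not verify versions or whether a GPU backend is configured.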

img_clf.py:

After some hyperparameter tuning, this reaches around 96-98% accuracy on the validation data and achieved a log loss of around 0.18 on Kaggle's hidden test data.
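For reference, binary log loss can be computed as in the sketch below. This is a plain-Python illustration of the standard formula, not code from img_clf.py, and the clipping value `eps` is an assumption chosen only to avoid taking `log(0)`.

```python
import math


def binary_log_loss(y_true, y_prob, eps=1e-15):
    """Mean negative log-likelihood for labels in {0, 1} and predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        # Clip so that log() never receives exactly 0 or 1 - 1 = 0.
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


if __name__ == "__main__":
    labels = [1, 0, 1, 0]
    probs = [0.9, 0.2, 0.7, 0.4]
    print(binary_log_loss(labels, probs))
```

Lower is better: confident wrong predictions are penalized heavily, which is why a model with high accuracy can still have a mediocre log loss.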

Most of the code / strategy here was based on this Keras tutorial.

Pre-trained VGG16 model weights can be downloaded here.

The data directory structure I used was:

  • project
    • data
      • train
        • dogs
        • cats
      • validation
        • dogs
        • cats
      • test
        • test

cats_n_dogs.ipynb:

This produced a slightly better score (0.161 log loss on the Kaggle test set). The improvement most likely comes from using larger images and ensembling a few models, even though the notebook uses no image augmentation.

You might run into memory errors because of the large image dimensions. If so, reducing the number of folds and saving the model weights to disk rather than keeping the models in memory should do the trick. The notebook uses a slightly flatter directory structure, with the validation split happening after the images are loaded:

  • project
    • data
      • train
        • dogs
        • cats
      • test
        • test
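Splitting into folds after loading can be sketched as below. This is an illustration of a generic k-fold split over sample indices, not the notebook's actual code; `kfold_indices`, the fold count, and the seed are all assumptions.

```python
import random


def kfold_indices(n_samples, n_folds, seed=0):
    """Yield (train_indices, val_indices) pairs, one per fold."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices into n_folds roughly equal groups.
    folds = [indices[i::n_folds] for i in range(n_folds)]
    for k in range(n_folds):
        val_idx = folds[k]
        train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
        yield train_idx, val_idx


if __name__ == "__main__":
    for fold, (train_idx, val_idx) in enumerate(kfold_indices(10, 5)):
        print(fold, sorted(val_idx))
```

To keep memory use down, one option is to train one model per fold, write its weights to disk with Keras's `model.save_weights(...)`, and release the model before starting the next fold. The ensemble can then be rebuilt at prediction time by loading each set of weights in turn.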

cats_n_dogs_BN.ipynb:

This produced the best score (0.069 log loss without any ensembling). The notebook incorporates some of the techniques from Jeremy Howard's deep learning class, with the inclusion of batch normalization being the biggest factor. I also added extra layers of augmentation to the prediction script, which greatly improved performance.
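Augmentation at prediction time usually means predicting on several transformed copies of each image and averaging the results. The sketch below shows the idea with a single horizontal flip. The README does not say which transforms the notebook uses, so the flip, the channels-last image layout `(n, height, width, channels)`, and the `predict_fn` callable are all assumptions made for illustration.

```python
import numpy as np


def predict_with_tta(predict_fn, images):
    """Average predictions over the original images and their horizontal flips.

    predict_fn: hypothetical callable mapping a batch of images to probabilities.
    images: array of shape (n, height, width, channels) (assumed layout).
    """
    original = np.asarray(predict_fn(images), dtype=float)
    # Reverse the width axis to mirror each image left-to-right.
    flipped = np.asarray(predict_fn(images[:, :, ::-1, :]), dtype=float)
    return (original + flipped) / 2.0


if __name__ == "__main__":
    # Stand-in for a trained model: "probability" is the mean pixel value.
    def fake_predict(batch):
        return batch.mean(axis=(1, 2, 3))

    batch = np.random.rand(4, 8, 8, 3)
    print(predict_with_tta(fake_predict, batch))
```

With a Keras model, `predict_fn` could be the model's `predict` method. More transforms (small shifts, zooms, rotations) can be averaged in the same way at the cost of extra prediction passes.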

Pre-trained model weights for VGG16 w/ batch normalization can be downloaded here.

The VGG16BN class is defined in vgg_bn.py, and the data directory structure used was:

  • project
    • data
      • train
        • dogs
        • cats
      • validation
        • dogs
        • cats
      • test
        • test