Python module to download and extract Zalando's Fashion-MNIST database for training and testing deep learning neural networks in computer vision.

fashion_mnist.py

This Python module provides a simple-to-use function to download and extract Zalando's Fashion-MNIST database of article images, which is hosted at http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/.

Function:

load_FashionMNIST(path=None, normalise=True, flatten=True, onehot=True)

Keyword arguments:

 path      - str: Fashion-MNIST dataset directory.
                  Defaults to <current directory>/FashionMNIST.
                  Created if nonexistent. Any missing files are downloaded.
 normalise - boolean: True  -> grayscale pixel values [0,255] divided by 255.
                      False -> grayscale pixel values left in [0,255].
 flatten   - boolean: True  -> pixels of all images stored as a 2D numpy array.
                      False -> pixels of all images stored as a 3D numpy array.
 onehot    - boolean: True  -> labels stored as a one-hot encoded numpy array.
                      False -> labels stored as integer class values.
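
What the three flags do to the data can be reproduced with plain NumPy. The following is only a sketch on synthetic data, not the module's actual implementation; the variable names (raw_pixels, raw_labels, num_classes) are made up for illustration.

```python
import numpy as np

# Synthetic stand-in for raw data: 4 grayscale images of 28x28 uint8 pixels.
rng = np.random.default_rng(0)
raw_pixels = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
raw_labels = np.array([0, 3, 9, 3], dtype=np.uint8)

# normalise=True: scale [0, 255] down to [0.0, 1.0] as float32.
normalised = raw_pixels.astype(np.float32) / 255.0

# flatten=True: one row of 784 values per image (2D instead of 3D).
flattened = normalised.reshape(len(normalised), -1)

# onehot=True: one row per label with a single 1 in the class column.
num_classes = 10
onehot = np.eye(num_classes, dtype=np.uint8)[raw_labels]

print(flattened.shape, flattened.dtype)  # (4, 784) float32
print(onehot.shape, onehot.dtype)        # (4, 10) uint8
```

With all three flags set to False, the data would simply stay in the raw_pixels / raw_labels form shown at the top of the sketch.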

Returns a nested dictionary:

 {'train': {'images': train_images, 'labels': train_labels},
  'test': {'images': test_images, 'labels': test_labels}}
 where,
  train_images = FashionMNISTimages(magic_number=2051, nimages=60000,
                                    nrows=28, ncols=28, pixels=np.array())
        if normalise, pixels.dtype='float32' & [0.0(white), 1.0(black)]
        else,         pixels.dtype='uint8' & [0(white), 255(black)]
        if flatten,   pixels.shape=(60000, 784)
        else,         pixels.shape=(60000, 28, 28)
  train_labels = FashionMNISTlabels(magic_number=2049, nlabels=60000,
                                    labels=np.array() dtype='uint8')
        if onehot,    labels.shape=(60000, 10)
        else,         labels.shape=(60000,)
  test_images = FashionMNISTimages(magic_number=2051, nimages=10000,
                                   nrows=28, ncols=28, pixels=np.array())
        if normalise, pixels.dtype='float32' & [0.0(white), 1.0(black)]
        else,         pixels.dtype='uint8' & [0(white), 255(black)]
        if flatten,   pixels.shape=(10000, 784)
        else,         pixels.shape=(10000, 28, 28)
  test_labels = FashionMNISTlabels(magic_number=2049, nlabels=10000,
                                   labels=np.array() dtype='uint8')
        if onehot,    labels.shape=(10000, 10)
        else,         labels.shape=(10000,)
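
The magic numbers 2051 and 2049 come from the IDX file format in which the images and labels are stored: a big-endian header (magic number, item count, and for images the row and column counts) followed by one unsigned byte per pixel or label. The sketch below shows how such a buffer can be parsed. It is not the module's own code; parse_idx_images and parse_idx_labels are hypothetical helper names, and the buffers are tiny synthetic ones. For a real file, the bytes would be read from the downloaded archive first (the files are normally distributed gzip-compressed, so something like gzip.open(...).read() would be needed).

```python
import struct

import numpy as np


def parse_idx_images(buf: bytes) -> np.ndarray:
    """Parse an IDX image buffer into a (nimages, nrows, ncols) uint8 array."""
    # Header: four big-endian unsigned 32-bit integers.
    magic, nimages, nrows, ncols = struct.unpack(">IIII", buf[:16])
    if magic != 2051:
        raise ValueError(f"unexpected image magic number: {magic}")
    pixels = np.frombuffer(buf, dtype=np.uint8, offset=16)
    return pixels.reshape(nimages, nrows, ncols)


def parse_idx_labels(buf: bytes) -> np.ndarray:
    """Parse an IDX label buffer into a (nlabels,) uint8 array."""
    # Header: two big-endian unsigned 32-bit integers.
    magic, nlabels = struct.unpack(">II", buf[:8])
    if magic != 2049:
        raise ValueError(f"unexpected label magic number: {magic}")
    labels = np.frombuffer(buf, dtype=np.uint8, offset=8)
    if labels.size != nlabels:
        raise ValueError("label count does not match header")
    return labels


# Tiny synthetic buffers: 2 blank 28x28 images and 2 labels.
image_buf = struct.pack(">IIII", 2051, 2, 28, 28) + bytes(2 * 28 * 28)
label_buf = struct.pack(">II", 2049, 2) + bytes([7, 1])

images = parse_idx_images(image_buf)
labels = parse_idx_labels(label_buf)
print(images.shape, labels.tolist())  # (2, 28, 28) [7, 1]
```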

Remarks:

FashionMNISTimages() and FashionMNISTlabels() are dataklass objects. On my system, they performed ~25x faster than Python 3's built-in dataclass objects and 5x faster than namedtuple.
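
To get comparable numbers on your own system, a timing harness along the following lines could be used. This is only a sketch: the remark above does not say which operation was timed, so the example assumes instantiation cost, and it covers only the two standard-library alternatives (dataklass is third-party and is left out here). LabelsDataclass and LabelsTuple are hypothetical stand-ins for the label container.

```python
import timeit
from collections import namedtuple
from dataclasses import dataclass


@dataclass
class LabelsDataclass:
    """Hypothetical stand-in for the label container, as a built-in dataclass."""
    magic_number: int
    nlabels: int
    labels: list


# The same three fields as a namedtuple.
LabelsTuple = namedtuple("LabelsTuple", ["magic_number", "nlabels", "labels"])

payload = [0] * 10   # small dummy label list
runs = 100_000       # number of instantiations per measurement

t_dataclass = timeit.timeit(lambda: LabelsDataclass(2049, 10, payload), number=runs)
t_namedtuple = timeit.timeit(lambda: LabelsTuple(2049, 10, payload), number=runs)

print(f"dataclass : {t_dataclass:.4f} s for {runs} instantiations")
print(f"namedtuple: {t_namedtuple:.4f} s for {runs} instantiations")
```

Absolute timings and ratios depend on the Python version and the machine, so no expected output is given.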

How to use?

from fashion_mnist import load_FashionMNIST     # Import function from module
fmdb = load_FashionMNIST()                      # Get FashionMNIST database using default settings
train_images = fmdb['train']['images'].pixels   # A 60000x784 numpy array with float32 values
train_labels = fmdb['train']['labels'].labels   # A 60000x10 numpy array with uint8 values
test_images = fmdb['test']['images'].pixels     # A 10000x784 numpy array with float32 values
test_labels = fmdb['test']['labels'].labels     # A 10000x10 numpy array with uint8 values
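
Once the labels are loaded, a common next step is turning one-hot rows back into readable class names. The sketch below does this on a small synthetic array rather than the real data. CLASS_NAMES follows the label order given in Zalando's Fashion-MNIST documentation; it is not something this module provides, and onehot_labels is a made-up stand-in for a label array loaded with onehot=True.

```python
import numpy as np

# Class names in label order (index = label value), per Zalando's documentation.
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# Synthetic stand-in for a one-hot label array of shape (n, 10).
onehot_labels = np.eye(10, dtype=np.uint8)[[9, 0, 5]]

# argmax along each row recovers the integer class value.
class_ids = onehot_labels.argmax(axis=1)
names = [CLASS_NAMES[i] for i in class_ids]
print(names)  # ['Ankle boot', 'T-shirt/top', 'Sandal']
```

If the data was loaded with onehot=False, the argmax step is unnecessary and the label values can index CLASS_NAMES directly.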