
A NumPy implementation of a Restricted Boltzmann Machine (RBM).


Restricted Boltzmann Machine

Figure: input images and their reconstructions (the first column is the input; the other columns are reconstructed outputs).

Requirements

Python 3.6

numpy >= 1.14

matplotlib >= 3.0.0

How to use

mnist_bin.npy is a NumPy binary (.npy) file, obtained from MNIST or a GitHub source, that contains 60,000 images of handwritten digits (0-9), each 28x28 pixels. Load it with NumPy.

import numpy as np
mnist = np.load('mnist_bin.npy')  # 60000x28x28
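The file name suggests the images are already binarized. If you start from grayscale MNIST instead, the following minimal sketch shows one way to produce binary images with the same 60000x28x28 layout. It assumes uint8 pixels in [0, 255] and a threshold of 0.5 after scaling; the helper name `binarize` and the threshold are illustrative assumptions, not part of this repository.

```python
import numpy as np


def binarize(images: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Map grayscale images to {0, 1} (hypothetical helper, not in rbm.py).

    Assumes uint8 pixels in [0, 255]; they are scaled to [0, 1] and then
    thresholded, keeping the original array shape (e.g. 60000x28x28).
    """
    scaled = images.astype(np.float64) / 255.0
    return (scaled > threshold).astype(np.float64)


# Quick check on a tiny synthetic "image" instead of the real dataset.
demo = np.array([[[0, 100], [200, 255]]], dtype=np.uint8)  # shape 1x2x2
print(binarize(demo))
```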

To use the RBM class from rbm.py, specify the number of hidden units (n_hidden) and visible units (m_observe) when constructing it. For 28x28 images there is one visible unit per pixel.

from rbm import RBM

rbm = RBM(n_hidden=100, m_observe=28 * 28)
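For orientation, a binary RBM with m_observe visible units and n_hidden hidden units is usually parameterized by a weight matrix W of shape (m_observe, n_hidden), a visible bias a and a hidden bias b. The sketch below shows that standard parameterization, its energy E(v, h) = -a·v - b·h - vᵀWh, and the two conditional probabilities it implies. The names `init_params`, `energy`, `p_h_given_v` and `p_v_given_h`, and the initialization scale, are assumptions for illustration; rbm.py may organize its parameters differently.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def init_params(n_hidden, m_observe, seed=0):
    """Small random weights and zero biases (illustrative initialization)."""
    rng = np.random.RandomState(seed)
    W = rng.normal(0.0, 0.01, size=(m_observe, n_hidden))  # visible x hidden
    a = np.zeros(m_observe)  # visible bias
    b = np.zeros(n_hidden)   # hidden bias
    return W, a, b


def energy(v, h, W, a, b):
    """E(v, h) = -a.v - b.h - v^T W h for a single binary configuration."""
    return -(a @ v) - (b @ h) - (v @ W @ h)


def p_h_given_v(v, W, b):
    """P(h_j = 1 | v) = sigmoid(b_j + sum_i v_i W_ij); v may be a batch."""
    return sigmoid(v @ W + b)


def p_v_given_h(h, W, a):
    """P(v_i = 1 | h) = sigmoid(a_i + sum_j W_ij h_j); h may be a batch."""
    return sigmoid(h @ W.T + a)


W, a, b = init_params(n_hidden=100, m_observe=28 * 28)
v = np.zeros(28 * 28)
print(p_h_given_v(v, W, b).shape)  # (100,)
```

Because the units within a layer are conditionally independent given the other layer, each conditional is a single vectorized sigmoid, which is what makes RBMs cheap to sample layer by layer.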

Train the RBM by passing data to its train method (here, the first 200 images for 10 epochs).

rbm.train(mnist[:200], epochs=10)
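This README does not say which learning rule rbm.py uses. A common choice for RBMs is contrastive divergence with a single Gibbs step (CD-1). The sketch below shows one CD-1 update on a mini-batch of flattened binary images. The function name `cd1_update`, the learning rate, and the use of probabilities rather than binary samples for the negative-phase visible units are assumptions, not details taken from rbm.py.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def cd1_update(v0, W, a, b, lr=0.1, rng=None):
    """One CD-1 step on a batch v0 of shape (batch, m_observe).

    Updates W, a, b in place and returns the mean squared reconstruction
    error, which is only a rough indicator of training progress.
    """
    rng = np.random.RandomState(0) if rng is None else rng
    # Positive phase: hidden probabilities and a binary sample given the data.
    ph0 = sigmoid(v0 @ W + b)
    h0 = (rng.random_sample(ph0.shape) < ph0).astype(v0.dtype)
    # Negative phase: one Gibbs step down to the visible layer and back up.
    pv1 = sigmoid(h0 @ W.T + a)
    ph1 = sigmoid(pv1 @ W + b)
    n = v0.shape[0]
    # Gradient estimate: data statistics minus one-step model statistics.
    W += lr * (v0.T @ ph0 - pv1.T @ ph1) / n
    a += lr * (v0 - pv1).mean(axis=0)
    b += lr * (ph0 - ph1).mean(axis=0)
    return float(((v0 - pv1) ** 2).mean())


# Random binary data stands in for 200 flattened 28x28 images.
rng = np.random.RandomState(0)
data = (rng.random_sample((200, 28 * 28)) < 0.2).astype(np.float64)
W = rng.normal(0.0, 0.01, size=(28 * 28, 100))
a = np.zeros(28 * 28)
b = np.zeros(100)
errors = []
for epoch in range(10):
    errors.append(cd1_update(data, W, a, b, lr=0.1, rng=rng))
```

If rbm.py expects flattened input, a 200x28x28 batch such as mnist[:200] would first need reshaping to 200x784, for example with `mnist[:200].reshape(200, -1)`.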

After training, you can sample from the RBM. The result should be an image of a handwritten digit generated by the model rather than one copied from the original dataset. In practice, starting the sampler from a real image (such as mnist[0]) usually gives better results than starting from a randomly initialized input.

v = rbm.sample(num_iter=200, v_init=mnist[0])
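Sampling from an RBM is typically done with block Gibbs sampling: alternately draw all hidden units given the visible units, then all visible units given the hidden units. The sketch below mirrors the num_iter and v_init arguments used above. `gibbs_sample` is a hypothetical stand-alone function, not the rbm.py method, and returning the final visible probabilities (rather than a binary sample) is an assumption, made because probabilities usually display more smoothly.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gibbs_sample(W, a, b, num_iter=200, v_init=None, seed=0):
    """Run block Gibbs sampling and return the last visible probabilities.

    W has shape (m_observe, n_hidden); a and b are the visible and hidden
    biases. If v_init is None the chain starts from random binary noise.
    With num_iter=0 the (flattened) starting point is returned unchanged.
    """
    rng = np.random.RandomState(seed)
    m_observe = W.shape[0]
    if v_init is None:
        v = (rng.random_sample(m_observe) < 0.5).astype(np.float64)
    else:
        # Flatten a 28x28 image into a vector of m_observe pixels.
        v = np.asarray(v_init, dtype=np.float64).reshape(m_observe)
    pv = v
    for _ in range(num_iter):
        ph = sigmoid(v @ W + b)                                    # P(h | v)
        h = (rng.random_sample(ph.shape) < ph).astype(np.float64)  # sample h
        pv = sigmoid(h @ W.T + a)                                  # P(v | h)
        v = (rng.random_sample(pv.shape) < pv).astype(np.float64)  # sample v
    return pv


rng = np.random.RandomState(1)
W = rng.normal(0.0, 0.01, size=(28 * 28, 100))
a = np.zeros(28 * 28)
b = np.zeros(100)
v = gibbs_sample(W, a, b, num_iter=200, v_init=np.zeros((28, 28)))
print(v.shape)  # (784,)
```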

Visualize the output with matplotlib.

import matplotlib.pyplot as plt

plt.imshow(v.reshape((28, 28)), cmap="gray")
plt.show()

Figure: the sampled image v.
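The figure at the top of this README places an input in the first column and reconstructions in the following columns. One way to build such a figure is to tile the flattened images into a single 2-D array and pass it to plt.imshow in one call. The helper `tile_grid` below is hypothetical and uses only NumPy; it assumes every image is a flattened 28x28 vector and that images are listed row by row.

```python
import numpy as np


def tile_grid(images, n_rows, n_cols, img_shape=(28, 28), pad=1):
    """Arrange flattened images row by row into one 2-D array for imshow.

    images has shape (n_rows * n_cols, h * w); gaps of `pad` pixels stay 0.
    """
    h, w = img_shape
    grid = np.zeros((n_rows * (h + pad) - pad, n_cols * (w + pad) - pad))
    for idx, img in enumerate(np.asarray(images)):
        r, c = divmod(idx, n_cols)          # row-major placement
        top, left = r * (h + pad), c * (w + pad)
        grid[top:top + h, left:left + w] = img.reshape(h, w)
    return grid


# Column 0 would hold inputs and columns 1.. reconstructions; random
# stand-ins are used here instead of real RBM outputs.
imgs = np.random.RandomState(0).random_sample((3 * 4, 28 * 28))
grid = tile_grid(imgs, n_rows=3, n_cols=4)
print(grid.shape)  # (86, 115)
# plt.imshow(grid, cmap="gray"); plt.show()
```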

The full script:

import numpy as np
import matplotlib.pyplot as plt
from rbm import RBM

mnist = np.load('mnist_bin.npy')  # 60000x28x28
n_imgs, n_rows, n_cols = mnist.shape
img_size = n_rows * n_cols
print(mnist.shape)

# construct rbm model: one visible unit per pixel
rbm = RBM(n_hidden=100, m_observe=img_size)

print("Start RBM training.")
# train rbm model on the first 200 mnist images
rbm.train(mnist[:200], epochs=10)
print("Finish RBM training.")

# sample from rbm model, starting the chain from the first image
v = rbm.sample(num_iter=200, v_init=mnist[0])
plt.imshow(v.reshape((n_rows, n_cols)), cmap="gray")
plt.show()

For details about RBMs, refer to this report.