Prediction problem
Opened this issue · 3 comments
Amazing work!
But there is one strange thing that I can't figure out. I'm trying to train on my own dataset - almost MNIST, but slightly extended to 21 symbols. The records are exactly like MNIST, the loader too, and so on, but one random class is never predicted (for example 7). I have no idea why. The output contains 21 labels, but the 7th label always has very, very small values.
I haven't encountered this problem. Could you tell me the following:
- Apart from the data reading pipeline, did you modify the code, and if so, which part?
- Is the number of training samples for each class balanced?
The dataset is balanced,
but I wrote my own test code:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.framework.errors_impl import OutOfRangeError
from sklearn.metrics import accuracy_score, confusion_matrix
from seaborn import heatmap
import capslayer as cl
import os
import cv2
# Class id -> printable symbol. Ids are 1-based: 1-9 are the digits 1-9,
# 10 is the digit 0 and 11-21 are the letters A B C E H K M P T X Y.
chars_map = dict(enumerate("1234567890ABCEHKMPTXY", start=1))

# Image size in pixels; images are reshaped as (HEIGHT, WIDTH).
WIDTH, HEIGHT = 20, 30
def elem_conv(elem):
    """Split one parsed element into a 2-D image and its label.

    `elem` is indexed with the keys 'images' and 'labels'. The image entry is
    reshaped to (HEIGHT, WIDTH); the label entry is returned unchanged.
    """
    picture = np.reshape(elem['images'], newshape=(HEIGHT, WIDTH))
    return picture, elem['labels']
def parse_fn(serialized_example):
    """Decode one serialized example into a flat image and an int32 label.

    The record is expected to carry a raw float32 image buffer ('image'), an
    integer 'label' and the integer dimensions 'height', 'width' and 'depth'.

    Returns a dict with:
        'images': 1-D float32 tensor of static length HEIGHT * WIDTH, scaled
                  by 1/255.
        'labels': scalar int32 tensor.
    """
    schema = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=schema)

    rows = tf.cast(parsed['height'], tf.int32)
    cols = tf.cast(parsed['width'], tf.int32)
    channels = tf.cast(parsed['depth'], tf.int32)

    # Flatten using the dimensions stored in the record, then pin the static
    # shape so downstream batching knows the vector length.
    pixels = tf.decode_raw(parsed['image'], tf.float32)
    pixels = tf.reshape(pixels, shape=[rows * cols * channels])
    pixels.set_shape([HEIGHT * WIDTH * 1])
    pixels = tf.cast(pixels, tf.float32) * (1. / 255)

    return {
        'images': pixels,
        'labels': tf.cast(parsed['label'], tf.int32),
    }
def get_model(num_label=21):
    """Build the CapsNet inference graph together with its decoder.

    The variable scopes and the order in which layers are created are kept
    stable on purpose: auto-generated variable names depend on them, and the
    checkpoint is restored by name.

    Args:
        num_label: number of output capsules (classes). Defaults to 21, the
            size of the symbol set this script was written for.

    Returns:
        A tuple (in_images, recon_imgs, probs):
            in_images: float32 placeholder of shape [None, HEIGHT * WIDTH].
            recon_imgs: decoder output reshaped to [-1, HEIGHT, WIDTH, 1].
            probs: class activations returned by `cl.layers.dense`
                (presumably one value per class along axis 1 - confirm
                against the capslayer version in use).
    """
    # Vector CapsNet
    in_images = tf.placeholder(tf.float32, [None, HEIGHT * WIDTH])

    with tf.variable_scope('Conv1_layer'):
        # VALID padding with a 9x9 kernel and stride 1 yields
        # [batch_size, HEIGHT - 8, WIDTH - 8, 256], i.e. 22x12 for 30x20 input
        # (not 20x20 as in the 28x28 MNIST setup).
        inputs = tf.reshape(in_images, shape=[-1, HEIGHT, WIDTH, 1])
        conv1 = tf.layers.conv2d(inputs,
                                 filters=256,
                                 kernel_size=9,
                                 strides=1,
                                 padding='VALID',
                                 activation=tf.nn.relu)

    with tf.variable_scope('PrimaryCaps_layer'):
        primaryCaps, activation = cl.layers.primaryCaps(conv1,
                                                        filters=32,
                                                        kernel_size=9,
                                                        strides=2,
                                                        out_caps_dims=[8, 1],
                                                        method="norm")

    with tf.variable_scope('DigitCaps_layer'):
        routing_method = "EMRouting"
        # Flatten the spatial/channel grid of primary capsules into one axis.
        num_caps = np.prod(cl.shape(primaryCaps)[1:4])
        primaryCaps = tf.reshape(primaryCaps, shape=[-1, num_caps, 8, 1])
        activation = tf.reshape(activation, shape=[-1, num_caps])
        poses, probs = cl.layers.dense(primaryCaps,
                                       activation,
                                       num_outputs=num_label,
                                       out_caps_dims=[16, 1],
                                       routing_method=routing_method)

    # Decoder structure
    # Reconstruct the inputs with 3 FC layers
    with tf.variable_scope('Decoder'):
        # No ground-truth labels are fed here, so the capsule of the
        # *predicted* class is the one kept for reconstruction.
        logits_idx = tf.to_int32(tf.argmax(cl.softmax(probs, axis=1), axis=1))
        one_hot = tf.one_hot(logits_idx, depth=num_label, axis=-1, dtype=tf.float32)
        labels_one_hoted = tf.reshape(one_hot, (-1, num_label, 1, 1))
        masked_caps = tf.multiply(poses, labels_one_hoted)
        flat_size = np.prod(masked_caps.get_shape().as_list()[1:])
        active_caps = tf.reshape(masked_caps, shape=(-1, flat_size))
        fc1 = tf.layers.dense(active_caps, units=512, activation=tf.nn.relu)
        fc2 = tf.layers.dense(fc1, units=1024, activation=tf.nn.relu)
        num_outputs = HEIGHT * WIDTH * 1
        recon_imgs = tf.layers.dense(fc2,
                                     units=num_outputs,
                                     activation=tf.sigmoid)
        recon_imgs = tf.reshape(recon_imgs, shape=[-1, HEIGHT, WIDTH, 1])

    return in_images, recon_imgs, probs
def show_reconstruct(original, reconstruct, true_lbl, pred_lbl, lbl=None):
    """Plot one input image side by side with its reconstruction.

    Args:
        original: batch of input images; the selected entry is reshaped to
            (HEIGHT, WIDTH).
        reconstruct: batch of reconstructed images aligned with `original`;
            the selected entry is reshaped to (HEIGHT, WIDTH).
        true_lbl: true class ids of the batch (keys of `chars_map`).
        pred_lbl: predicted class ids of the batch (keys of `chars_map`).
        lbl: if given, show the first sample whose true label equals `lbl`;
            otherwise show the first sample of the batch.

    Raises:
        IndexError: if `lbl` is given but no sample in the batch carries it.
    """
    if lbl is not None:
        matches = np.flatnonzero(np.asarray(true_lbl) == lbl)
        if matches.size == 0:
            raise IndexError("no sample with true label %r in this batch" % (lbl,))
        ind = matches[0]
    else:
        ind = 0

    original_image = np.reshape(original[ind], newshape=(HEIGHT, WIDTH))
    # Use the reconstruction of the *same* sample. The previous code always
    # took index 0, which paired the chosen input with an unrelated output.
    # The reshape makes hstack work for both (H, W) and (H, W, 1) entries.
    reconstruct_image = np.reshape(reconstruct[ind], newshape=(HEIGHT, WIDTH))

    title = chars_map[true_lbl[ind]] + 20 * ' ' + chars_map[pred_lbl[ind]]
    res = np.hstack((original_image, reconstruct_image))
    plt.imshow(res, cmap='gray')
    plt.title(title)
    plt.show()
def test(records):
    """Evaluate the restored model on `records` and plot a confusion matrix.

    param records: list of .record files

    Raises:
        IOError: if no checkpoint can be found in the checkpoint directory.
    """
    batch_size = 128
    dataset = tf.data.TFRecordDataset(records)
    dataset = dataset.map(parse_fn).batch(batch_size).repeat(1).shuffle(buffer_size=5000, seed=3)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    inputs, recon_imgs, probs = get_model()
    saver = tf.train.Saver()
    true_labels, predicted = [], []

    with tf.Session() as sess:
        # NOTE: get_checkpoint_state() reads the directory's checkpoint state,
        # so the checkpoint recorded there as the latest is restored, which is
        # not necessarily step 6600.
        ckpt_dir = os.path.dirname('../models/models/results/logdir/model.ckpt-6600')
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            # Without a restore the variables are uninitialised; fail here
            # with a clear message instead of deep inside sess.run().
            raise IOError("no checkpoint found in '%s'" % ckpt_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)

        while True:
            try:
                elem = sess.run(next_element)
            except OutOfRangeError:
                # The one-shot iterator is exhausted: evaluation is complete.
                break
            raw_images = elem['images']
            # Shift to the 1-based ids used as keys of chars_map.
            true_lbls = elem['labels'] + 1
            reconstructed, batch_probs = sess.run([recon_imgs, probs],
                                                  feed_dict={inputs: raw_images})
            # Drop only the channel axis so a batch of size 1 keeps its
            # batch dimension.
            reconstructed = np.squeeze(reconstructed, axis=-1)
            pred_lbls = np.argmax(batch_probs, axis=1) + 1
            #show_reconstruct(raw_images, reconstructed, true_lbls, pred_lbls, lbl=9)
            predicted.extend(pred_lbls)
            true_labels.extend(true_lbls)

    # Cover every class, not just the first ten, and pass the ids explicitly
    # so the matrix keeps one row/column per class even when a class never
    # occurs or is never predicted.
    labels_id = sorted(chars_map)
    labels = [chars_map[lbl] for lbl in labels_id]
    conf_matr = confusion_matrix(true_labels, predicted, labels=labels_id)
    heatmap(conf_matr, annot=True, fmt='d',
            xticklabels=labels, yticklabels=labels)
    plt.show()
if __name__ == '__main__':
    # Script entry point: run the evaluation on the symbols eval set.
    test(records=['data/symbols/eval_symbols.tfrecord'])
Which checkpoint is used (6600 here) does not matter: I trained for 50000 steps, and the problem was the same.
The code looks fine. I'm not sure what the problem is. My guess is that CapsNet might have a bias problem. But before drawing that conclusion, I suggest:
- visualize the input image and print its corresponding label (both for training and validation set) to make sure the dataset is right;
- then remove one or more classes from the dataset (not the 7th class), train the model from scratch and test it again, to see whether the 7th or any other class is still never predicted.
Of course it might be an implementation problem, I will check my code again.