IndexError when n_classes < 6
seankortschot opened this issue · 1 comment
Hello,
First, thank you very much for your code. It has helped me out a lot.
I tried to change n_classes to 2, as I am only classifying between two states. However, I receive an IndexError whenever I reduce n_classes below 6. The error message is below:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-15-581f85d3f7ff> in <module>()
43 feed_dict={
44 x: X_test,
---> 45 y: one_hot(y_test)
46 }
47 )
<ipython-input-13-7d65b978d73d> in one_hot(y_, n_classes)
50 # Function to encode output labels from number indexes
51 y_ = y_.reshape(len(y_))
---> 52 return np.eye(n_classes)[np.array(y_, dtype=np.int32)] # Returns FLOATS
IndexError: index 2 is out of bounds for axis 0 with size 2
I'm not sure whether I'm misunderstanding what n_classes is supposed to represent or whether there is a bug when it is set below 6 (any value of 6 or higher still works).
My data has the shape:
X_train: (6312, 50, 9)
y_train: (6312, 1)
X_test: (1578, 50, 9)
y_test: (1578, 1)
where the single column in each y array holds a label of either 1 or 2 for my two classes.
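Since the one_hot helper further down uses each label directly as a row index into np.eye(n_classes), one quick check before training is whether every label falls between 0 and n_classes - 1. A minimal sketch, using a small made-up stand-in for y_test (the real array is (1578, 1)); `labels_in_range` is a hypothetical helper, not part of the original code:

```python
import numpy as np

def labels_in_range(y, n_classes):
    """Hypothetical helper: True if every label can index a row of np.eye(n_classes)."""
    y = np.asarray(y).reshape(-1)  # flatten (N, 1) -> (N,)
    return bool(y.min() >= 0 and y.max() < n_classes)

# Made-up labels shaped like y_test, using the 1/2 encoding described above
y_example = np.array([[1], [2], [2], [1]])

print(np.unique(y_example))               # distinct labels present
print(labels_in_range(y_example, 2))      # label 2 has no row in a 2x2 identity
print(labels_in_range(y_example - 1, 2))  # labels shifted to 0/1
```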
My hyperparameters are currently set to:
training_data_count = len(X_train)
test_data_count = len(X_test)
n_steps = len(X_train[0])
n_input = len(X_train[0][0])
# NN Internal Structure
n_hidden = 32
n_classes = 2
# Training
learning_rate = 0.001
lambda_loss_amount = 0.0015
training_iters = training_data_count * 300 # Loop 300 times on the dataset
batch_size = 1500
display_iter = 30000 # To show test set accuracy during training
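For reference, the length-based expressions above reduce to the array shapes listed earlier. A small sketch with zero-filled placeholder arrays of the same shapes (not the real data) shows the values they produce; note that n_classes is not derived from the data here and has to be set separately:

```python
import numpy as np

# Placeholder arrays with the shapes reported above; the contents don't matter here
X_train = np.zeros((6312, 50, 9), dtype=np.float32)
X_test = np.zeros((1578, 50, 9), dtype=np.float32)

training_data_count = len(X_train)          # number of training windows
test_data_count = len(X_test)               # number of test windows
n_steps = len(X_train[0])                   # same as X_train.shape[1]: timesteps per window
n_input = len(X_train[0][0])                # same as X_train.shape[2]: features per timestep
training_iters = training_data_count * 300  # 300 passes over the training set

print(n_steps, n_input, training_iters)
```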
The one_hot function I'm using is a fix you suggested in another issue:
import numpy as np

def one_hot(y_, n_classes=n_classes):
    # Function to encode output labels from number indexes
    # (each label is used as a row index, so it must be in 0 .. n_classes - 1)
    y_ = y_.reshape(len(y_))
    return np.eye(n_classes)[np.array(y_, dtype=np.int32)]  # Returns FLOATS
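To see where the error comes from: np.eye(n_classes)[labels] picks row `label` of an n_classes by n_classes identity matrix for each label, so only labels 0 to n_classes - 1 are valid. A small sketch reproducing both cases, re-declaring the function above with n_classes defaulting to 2 (the sample label arrays are made up for illustration):

```python
import numpy as np

def one_hot(y_, n_classes=2):
    # Same logic as above: each label selects one row of the identity matrix
    y_ = y_.reshape(len(y_))
    return np.eye(n_classes)[np.array(y_, dtype=np.int32)]  # returns floats

# Labels 0 and 1 map onto the two rows of np.eye(2)
zero_based = np.array([[0], [1], [1]])
print(one_hot(zero_based))

# Labels 1 and 2: label 2 asks for a third row that np.eye(2) does not have
one_based = np.array([[1], [2], [2]])
message = None
try:
    one_hot(one_based)
except IndexError as err:
    message = str(err)
print("IndexError:", message)
```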
Any help with this is much appreciated.
Best,
Sean
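I seem to have fixed the problem by relabelling my classes as 0 and 1 instead of 1 and 2. one_hot uses each label as a row index into np.eye(n_classes), so the labels have to run from 0 to n_classes - 1; with labels 1 and 2, the label 2 needed a third row that a 2 by 2 identity matrix doesn't have.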
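The relabelling can also be done in code rather than in the data files. A minimal sketch, assuming integer labels stored in (N, 1) arrays as described in the issue: np.unique with return_inverse=True maps the sorted distinct training labels to 0, 1, ..., the test labels are mapped with the same table, and the number of distinct labels gives n_classes. `to_zero_based` is a hypothetical helper name:

```python
import numpy as np

def to_zero_based(y_train, y_test):
    """Hypothetical helper: map integer labels to 0..k-1 using the training-set classes.

    Assumes every test label also appears in y_train.
    """
    classes, train_idx = np.unique(y_train.reshape(-1), return_inverse=True)
    flat_test = y_test.reshape(-1)
    # Look up each test label's position in the sorted class list
    test_idx = np.clip(np.searchsorted(classes, flat_test), 0, len(classes) - 1)
    if not np.array_equal(classes[test_idx], flat_test):
        raise ValueError("y_test contains labels not seen in y_train")
    return classes, train_idx.reshape(-1, 1), test_idx.reshape(-1, 1)

# Made-up labels using the 1/2 encoding from the issue
y_train = np.array([[1], [2], [2], [1]])
y_test = np.array([[2], [1]])

classes, y_train0, y_test0 = to_zero_based(y_train, y_test)
n_classes = len(classes)
print(classes, n_classes)
print(y_train0.ravel(), y_test0.ravel())
```

For labels that are exactly 1 and 2, `y_train - 1` and `y_test - 1` do the same job; the np.unique version also copes with labels that have gaps or don't start at 1.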
Apologies for raising an unnecessary issue.