Error while using CRF with LSTM
himanshudce opened this issue
Here is my model:
```python
input = Input(shape=(19,))
word_embedding_size = 100
n_words = len(word2idx)
n_tags = len(tag2idx)
model = Embedding(input_dim=n_words, output_dim=word_embedding_size, input_length=19)(input)
model = Bidirectional(LSTM(units=word_embedding_size,
                           return_sequences=True,
                           dropout=0.5,
                           recurrent_dropout=0.5,
                           kernel_initializer=keras.initializers.he_normal()))(model)
model = LSTM(units=word_embedding_size * 2,
             return_sequences=True,
             dropout=0.5,
             recurrent_dropout=0.5,
             kernel_initializer=keras.initializers.he_normal())(model)
model = TimeDistributed(Dense(n_tags, activation="relu"))(model)  # previously softmax output layer
crf = CRF(n_tags)  # CRF layer
out = crf(model)  # output
model = Model(input, out)
```
I am getting the following error. Is this repo not compatible with Python 3.8 or newer TensorFlow versions?
```
TypeError: in user code:

    /home/himanshu/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:571 train_function *
        outputs = self.distribute_strategy.run(
    /home/himanshu/.local/lib/python3.8/site-packages/keras_contrib/layers/crf.py:292 call *
        test_output = self.viterbi_decoding(X, mask)
    /home/himanshu/.local/lib/python3.8/site-packages/keras_contrib/layers/crf.py:564 viterbi_decoding *
        argmin_tables = self.recursion(input_energy, mask, return_logZ=False)
    /home/himanshu/.local/lib/python3.8/site-packages/keras_contrib/layers/crf.py:523 recursion *
        target_val_last, target_val_seq, _ = K.rnn(_step, input_energy,
    /home/himanshu/.local/lib/python3.8/site-packages/keras/backend/tensorflow_backend.py:3104 rnn *
        reachable = tf_utils.get_reachable_from_inputs([learning_phase()],
    /home/himanshu/.local/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py:140 get_reachable_from_inputs **
        raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))

TypeError: Expected Operation, Variable, or Tensor, got
```
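The traceback fails inside keras_contrib's `CRF.viterbi_decoding`, which calls `self.recursion`, which in turn runs the decoding step through the Keras backend's `K.rnn`. For readers unfamiliar with what that step computes, here is a minimal plain-Python sketch of the standard Viterbi recursion: given per-token tag scores (in this model, the `TimeDistributed(Dense(n_tags))` output) and a tag-to-tag transition matrix, it finds the highest-scoring tag sequence. This is an illustration only, not keras_contrib's implementation; it omits masking and any start/end boundary scores, and the function name `viterbi_decode` is hypothetical.

```python
from typing import List, Sequence, Tuple


def viterbi_decode(
    emissions: Sequence[Sequence[float]],
    transitions: Sequence[Sequence[float]],
) -> Tuple[List[int], float]:
    """Hypothetical helper: return the best tag path and its total score.

    emissions[t][j]   -- score of tag j at time step t
    transitions[i][j] -- score of moving from tag i to tag j
    (No masking or start/end boundary scores, unlike a full CRF layer.)
    """
    n_tags = len(transitions)
    # score[j] = best score of any path that ends in tag j at the current step
    score = list(emissions[0])
    backpointers: List[List[int]] = []

    for t in range(1, len(emissions)):
        new_score: List[float] = []
        best_prev: List[int] = []
        for j in range(n_tags):
            # Previous tag i that maximizes score[i] + transitions[i][j]
            i_best = max(range(n_tags), key=lambda i: score[i] + transitions[i][j])
            best_prev.append(i_best)
            new_score.append(score[i_best] + transitions[i_best][j] + emissions[t][j])
        score = new_score
        backpointers.append(best_prev)

    # Backtrack from the best final tag through the stored backpointers
    last = max(range(n_tags), key=lambda j: score[j])
    path = [last]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    path.reverse()
    return path, score[last]


if __name__ == "__main__":
    emissions = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    # A penalty of -2 for switching tags makes staying on tag 0 the best path
    print(viterbi_decode(emissions, [[0.0, -2.0], [-2.0, 0.0]]))
```

With the switching penalty the decoded path is `[0, 0, 0]` (score 2.0); with an all-zero transition matrix the same emissions decode to `[0, 1, 0]` (score 3.0). This interplay between per-token scores and transitions is what distinguishes CRF decoding from a per-token argmax over a softmax output.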
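Found the solution: there is a compatibility issue with newer Python versions. It works with Python 3.7 or earlier (`python==3.7` or lower). This is likely tied to TensorFlow as well: TensorFlow 1.x wheels were never published for Python 3.8, so on 3.8 pip can only install TensorFlow 2.x.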
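To fail fast instead of hitting the traceback above partway through training, a script could check the interpreter version up front. This is a minimal sketch; the constant `MAX_REPORTED_WORKING_PYTHON` and the function `python_version_reported_working` are hypothetical names, and the 3.7 bound is only this thread's empirical finding, not a documented keras_contrib requirement.

```python
import sys
from typing import Tuple

# Hypothetical constant: the newest Python version this thread reports as
# working with keras_contrib's CRF (an empirical finding, not an official pin).
MAX_REPORTED_WORKING_PYTHON: Tuple[int, int] = (3, 7)


def python_version_reported_working(
    version: Tuple[int, int] = tuple(sys.version_info[:2]),
) -> bool:
    """Return True if (major, minor) is at or below the reported working version."""
    return tuple(version) <= MAX_REPORTED_WORKING_PYTHON


if __name__ == "__main__":
    major, minor = sys.version_info[:2]
    if not python_version_reported_working((major, minor)):
        print(f"Warning: Python {major}.{minor} is newer than 3.7; "
              "keras_contrib's CRF may fail as in the traceback above.")
```

Tuples compare element by element, so `(3, 6)` and `(3, 7)` pass while `(3, 8)` and later fail the check.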