Keras implementation of Column Attention
anshudaur opened this issue · 0 comments
HI All, I need some help in coding column attention using keras :
Here is my code for aggregate column prediction: (max_len is the maximum length of question/columns )
n_h = 128 # number of hidden units
question_input = Input(shape=(max_len,),name='Question_input')
column_input = Input(shape=(max_len,),name='Column_input')
embedding= Embedding(max_token_index, n_h, input_length=max_len,name='embedding')
Q_embedding= embedding(question_input)
C_embedding= embedding(column_input)
encoder_question = Bidirectional(LSTM(n_h, return_state=True, return_sequences=True))
Q_enc , Q_state_h1, Q_state_h2 = encoder_question(Q_embedding)
encoder_column = Bidirectional(LSTM(n_h, return_state=True, return_sequences=True))
C_enc , C_state_h1, C_state_h2 = encoder_column(C_embedding)
########## Column Attention Code ########
Q_num_att = Dense(max_len,activation='relu')(Q_enc)
Q_self = Dense(max_len,activation='relu')(Q_num_att)
att_val_qc_num = Concatenate()([Q_self,C_enc])
att_prob_qc_num = Dense(maxlen,activation='softmax')(att_val_qc_num)
q_weighted_num = (Q_enc * att_prob_qc_num).sum(axis=0, keepdims=True)
########## Column Attention Code ############
col_num_out_q = Dense(max_len,activation='relu')(q_weighted_num)
col_num_out = Dense(max_len,activation='tanh')(col_num_out_q)
#con=Concatenate()([Q_state_h1,Q_state_h2,C_state_h1,C_state_h2])
final=Dense(6,activation='softmax')(col_num_out)
model = Model([question_input, column_input], final)
model.summary()
Please correct me if i am wrong.