wasiahmad/NeuralCodeSum

Split decoder


Hi, is the split decoder part implemented? I tried your code with args.split_decoder set to True and got this error:

Epoch = 1 [perplexity = x.xx, ml_loss = x.xx]: 0% 0/939 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "../../main/train.py", line 708, in <module>
    main(args)
  File "../../main/train.py", line 653, in main
    train(args, train_loader, model, stats)
  File "../../main/train.py", line 283, in train
    net_loss = model.update(ex)
  File ".../model.py", line 173, in update
    example_weights=ex_weights)
  File ".../module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File ".../transformer.py", line 435, in forward
    **kwargs)
  File ".../transformer.py", line 363, in _run_forward_ml
    summ_emb)
  File ".../module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "...transformer.py", line 295, in forward
    return self.decode(tgt_pad_mask, tgt_emb, memory_bank, state)
  File "...transformer.py", line 273, in decode
    f_t = self.fusion_sigmoid(torch.cat([copier_out, dec_out], dim=-1))
TypeError: expected Tensor as element 0 in argument 0, but got list
Epoch = 1 [perplexity = x.xx, ml_loss = x.xx]: 0% 0/939 [00:01<?, ?it/s]

Should we use torch.stack()? Thanks in advance.

Yes, the split decoder mechanism is implemented; however, I found a mistake that causes this error. The problem is that copier_out is a list of Tensors rather than a single Tensor: TransformerDecoder returns a list of representations, one per decoder layer, where each item is that layer's output. So, to make the split decoder mechanism work, you need to make two changes.
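To see why the original line fails, and why torch.stack() is not the right fix either, here is a small reproduction. It is a sketch under assumptions: the tensor sizes and the three-layer depth are made up, and plain lists of random Tensors stand in for what TransformerDecoder returns.

```python
import torch

# Hypothetical sizes: batch of 2, target length 5, hidden size 8, 3 decoder layers.
batch, tgt_len, hidden, num_layers = 2, 5, 8, 3

# Stand-ins for two TransformerDecoder outputs: one Tensor per layer.
copier_out = [torch.randn(batch, tgt_len, hidden) for _ in range(num_layers)]
dec_out = [torch.randn(batch, tgt_len, hidden) for _ in range(num_layers)]

# Passing the lists themselves reproduces the TypeError from the traceback.
try:
    torch.cat([copier_out, dec_out], dim=-1)
    raised = False
except TypeError:
    raised = True

# torch.stack on a per-layer list only adds a leading layer dimension;
# it does not combine the two decoders, so it is not the fix.
stacked = torch.stack(copier_out)  # shape: (num_layers, batch, tgt_len, hidden)

# Taking the final layer's output from each decoder yields two Tensors
# that concatenate cleanly along the hidden dimension.
fused_input = torch.cat([copier_out[-1], dec_out[-1]], dim=-1)

print(raised, tuple(stacked.shape), tuple(fused_input.shape))  # True (3, 2, 5, 8) (2, 5, 16)
```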

First, in the following line,

f_t = self.fusion_sigmoid(torch.cat([copier_out, dec_out], dim=-1))

copier_out and dec_out are both lists, so use the last element of each (the output of the final decoder layer):

f_t = self.fusion_sigmoid(torch.cat([copier_out[-1], dec_out[-1]], dim=-1))
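The corrected line can be sketched in context as a small standalone module. This is a hypothetical illustration, not NeuralCodeSum's code: only the names fusion_sigmoid, fusion_gate, f_t, and gate_input come from this thread, while the class name HypotheticalFusion, the helper fake_layers, the layers inside each nn.Sequential, and the way gate_input is built are assumptions.

```python
import torch
import torch.nn as nn


class HypotheticalFusion(nn.Module):
    """Gated fusion of two decoders' final-layer outputs (illustrative only)."""

    def __init__(self, hidden: int):
        super().__init__()
        # Assumed layer choices; the issue only names these two modules.
        self.fusion_sigmoid = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.Sigmoid())
        self.fusion_gate = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.ReLU())

    def forward(self, copier_out, dec_out):
        # Both inputs are per-layer lists; keep only the final layer of each.
        copier_last, dec_last = copier_out[-1], dec_out[-1]
        f_t = self.fusion_sigmoid(torch.cat([copier_last, dec_last], dim=-1))
        # Assumed: f_t gates the second decoder's output before fusing.
        gate_input = torch.cat([copier_last, torch.mul(f_t, dec_last)], dim=-1)
        return self.fusion_gate(gate_input)  # a single Tensor, not a list


def fake_layers(num_layers=3, batch=2, tgt_len=5, hidden=8):
    """Random stand-ins for a decoder's per-layer outputs (hypothetical helper)."""
    return [torch.randn(batch, tgt_len, hidden) for _ in range(num_layers)]


torch.manual_seed(0)
fusion = HypotheticalFusion(hidden=8)
out = fusion(fake_layers(), fake_layers())
print(tuple(out.shape))  # (2, 5, 8)
```

Note that fusion_gate returns one Tensor, which is why decoder_outputs stops being a list on the split-decoder path.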

Second,

When the split decoder mechanism is enabled, decoder_outputs is no longer a list of Tensors but a single Tensor. Therefore, to keep statements 1 and 2 working, since they expect a list, you can simply wrap it:

decoder_outputs = self.fusion_gate(gate_input)
decoder_outputs = [decoder_outputs]
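If you would rather not depend on which path produced decoder_outputs, a defensive variant of the same wrapping is sketched here. The helper name as_layer_list is hypothetical, and it assumes statements 1 and 2 only need to index the list (for example, take its last element).

```python
import torch


def as_layer_list(decoder_outputs):
    """Return decoder outputs as a list of per-layer Tensors (hypothetical helper).

    The split-decoder path produces one fused Tensor, while the regular
    TransformerDecoder path produces a list with one Tensor per layer.
    """
    if torch.is_tensor(decoder_outputs):
        return [decoder_outputs]  # same effect as decoder_outputs = [decoder_outputs]
    return list(decoder_outputs)


fused = torch.randn(2, 5, 8)  # split-decoder output: a single Tensor
per_layer = [torch.randn(2, 5, 8) for _ in range(3)]  # regular decoder: one per layer

# Either way, indexing [-1] yields the final representation.
print(as_layer_list(fused)[-1] is fused, len(as_layer_list(per_layer)))  # True 3
```

Compared with unconditionally writing decoder_outputs = [decoder_outputs], the type check makes the helper safe to call on either path.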

Hopefully this works, but please verify it. A pull request is welcome.