lizekang/ITDD

argmax between first encoder and second decoder


You are using argmax between the first encoder and the second decoder. How does backpropagation happen, since argmax is non-differentiable?

Unlike the original Deliberation Network (Xia et al., 2017), which proposes a joint learning framework based on a Monte Carlo method, we do not backpropagate from the second decoder to the first decoder, following the approach of "Modeling coherence for discourse neural machine translation".
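
A minimal sketch of the idea (hypothetical tensor names and toy dimensions, not the repository's code): because the first-pass logits are converted to hard token ids with `argmax`, the computation graph is cut there, and the second decoder's loss cannot reach the first decoder.

    import torch

    # toy first-pass decoder output: (batch, tgt_len, vocab)
    first_logits = torch.randn(2, 5, 100, requires_grad=True)

    # argmax returns integer token ids and has no gradient,
    # so the computation graph is cut at this point
    first_tokens = first_logits.argmax(dim=-1)

    # the second pass re-embeds the hard tokens and decodes them again
    embedding = torch.nn.Embedding(100, 32)
    second_decoder = torch.nn.Linear(32, 100)
    loss2 = second_decoder(embedding(first_tokens)).sum()
    loss2.backward()

    print(first_logits.grad)  # None: no gradient flowed back through argmax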

So you are updating the weights of the first decoder and the second decoder using their individual losses separately?

Then why are you detaching the first decoder at the end of every decoder time step?

Yes, we update the weights using their individual losses separately.
For the second question, could you provide some details?
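
As a rough illustration of that training scheme (a toy sketch with stand-in modules, not ITDD's sharded loss code): each loss is backpropagated on its own, and a single optimizer step then applies whatever gradients each backward pass produced.

    import torch
    import torch.nn.functional as F

    vocab, hidden = 100, 32
    first_decoder = torch.nn.Linear(hidden, vocab)    # stand-in for the first-pass decoder
    second_decoder = torch.nn.Linear(hidden, vocab)   # stand-in for the second-pass decoder
    optimizer = torch.optim.Adam(
        list(first_decoder.parameters()) + list(second_decoder.parameters()))

    states = torch.randn(8, hidden)
    target = torch.randint(0, vocab, (8,))

    # first-pass loss, backpropagated on its own
    loss1 = F.cross_entropy(first_decoder(states), target)
    loss1.backward()

    # second-pass loss, backpropagated on its own; in ITDD its gradients
    # stop at the argmax and never reach the first decoder
    loss2 = F.cross_entropy(second_decoder(states), target)
    loss2.backward()

    optimizer.step()       # one step applies both sets of accumulated gradients
    optimizer.zero_grad()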

When computing the gradients with loss.backward() for the second decoder, shouldn't the gradients of the encoder also change?

I tried printing the gradients of some of the encoder parameters, but they did not change:

    # first-pass decoder: compute its loss and backpropagate (sharded)
    print('first_decoder')
    batch_stats1 = self.train_loss.sharded_compute_loss(
        batch, first_outputs, first_attns, j,
        trunc_size, self.shard_size, normalization)

    # inspect the gradient of one context-encoder parameter
    for name, param in self.model.named_parameters():
        if 'encoder.htransformer.layer_norm.weight' in name:
            print(param.grad)

    # second-pass decoder: compute its loss and backpropagate (sharded)
    print('second_decoder')
    batch_stats2 = self.train_loss.sharded_compute_loss(
        batch, second_outputs, second_attns, j,
        trunc_size, self.shard_size, normalization)

    # the same parameter's gradient is unchanged after this second backward pass
    for name, param in self.model.named_parameters():
        if 'encoder.htransformer.layer_norm.weight' in name:
            print(param.grad)

In such cases, how do you ensure proper backpropagation?

The context encoder's gradients don't change because the second decoder only uses the knowledge representation and the first-pass output representation, so no gradient propagates to the context encoder (encoder.htransformer).
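
One way to verify this (a hypothetical helper, assuming a plain PyTorch model): zero all gradients, call `backward()` only on the second decoder's loss, and list the parameters that received no gradient. The context encoder's parameters should show up in that list.

    def report_params_without_grad(model):
        """Print parameters whose gradient is missing or all zero."""
        for name, param in model.named_parameters():
            if param.grad is None or param.grad.abs().sum().item() == 0.0:
                print('no gradient:', name)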

Thanks for your response, it cleared up a lot of things. By the way, have you ever used Gumbel-Softmax?

No, I haven't used Gumbel-Softmax, but I think it would work. You can give it a try.
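
For anyone trying that, a minimal sketch of a Gumbel-Softmax relaxation (toy shapes, not part of ITDD): with `hard=True` the forward pass still produces one-hot vectors, but the straight-through estimator lets gradients flow through the soft probabilities, so the second pass's loss could in principle reach the first-pass logits.

    import torch
    import torch.nn.functional as F

    # toy first-pass decoder logits: (batch, tgt_len, vocab)
    first_logits = torch.randn(2, 5, 100, requires_grad=True)

    # one-hot samples in the forward pass, soft gradients in the backward pass
    one_hot = F.gumbel_softmax(first_logits, tau=1.0, hard=True, dim=-1)

    # multiply the one-hot vectors with the embedding matrix instead of
    # doing an index lookup on argmax ids, so the path stays differentiable
    embedding = torch.nn.Embedding(100, 32)
    second_input = one_hot @ embedding.weight   # (batch, tgt_len, 32)

    loss2 = second_input.sum()
    loss2.backward()
    print(first_logits.grad is not None)        # True: gradients reach the first pass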