threelittlemonkeys/lstm-crf-pytorch

transition from last word to <END> tag not added in the score function

zzyxzz opened this issue · 1 comment

Hi, I may be wrong, but the following function confuses me: SOS_IDX is prepended to the tag sequence, but EOS_IDX is never appended. As a result, the
tag sequence looks like <START> A B C D. When calculating the transition score, the loop stops at self.trans[seq[3 + 1], seq[3]] (i.e. the transition score C->D), so the transition score from D to <END> is never calculated.

    def score(self, y, y0, mask): # calculate the score of a given sequence
        score = Tensor(BATCH_SIZE).fill_(0.)
        y0 = torch.cat([LongTensor(BATCH_SIZE, 1).fill_(SOS_IDX), y0], 1)
        for t in range(y.size(1)): # iterate through the sequence
            mask_t = mask[:, t]
            emit = torch.cat([y[b, t, y0[b, t + 1]].unsqueeze(0) for b in range(BATCH_SIZE)])
            trans = torch.cat([self.trans[seq[t + 1], seq[t]].unsqueeze(0) for seq in y0]) * mask_t
            score = score + emit + trans
        return score
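
To make the indexing concrete, here is a minimal pure-Python sketch that lists the (from, to) tag pairs the loop visits for a four-tag sequence. The helper `scored_transitions` and the tag names are hypothetical. It assumes `self.trans[i, j]` holds the score of moving from tag `j` to tag `i`, which is what the C->D reading above implies.

```python
def scored_transitions(tags, sos="<START>"):
    """Return the (from_tag, to_tag) pairs visited by the loop in score()."""
    # Mirrors torch.cat([SOS_IDX column, y0], 1): only the start tag is prepended.
    seq = [sos] + list(tags)
    # The loop reads trans[seq[t + 1], seq[t]] for each t in range(len(tags)).
    return [(seq[t], seq[t + 1]) for t in range(len(tags))]


pairs = scored_transitions(["A", "B", "C", "D"])
for src, dst in pairs:
    print(f"{src} -> {dst}")
```

Under these assumptions the last pair visited is C -> D, so a D -> <END> transition is only scored if the tag sequence passed in already ends with the end tag.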

Hi, if you look at train.py, you will see that EOS is added in the load_data() function. I agree it is a bit confusing that SOS and EOS are added in two different functions, though. I will try to find a way to fix this inconsistency!
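
One way to remove the inconsistency would be to add both boundary tags in a single helper. The sketch below is a pure-Python illustration for one unpadded sequence, not the repository's code: `add_boundary_tags`, `sequence_score`, the tag names, and the dictionary-based score tables are all hypothetical. It also assumes the end tag contributes only a transition and no emission, which may differ from how load_data() treats EOS.

```python
SOS, EOS = "<START>", "<END>"  # hypothetical tag names


def add_boundary_tags(tags):
    """Add the start and end tags in one place."""
    return [SOS] + list(tags) + [EOS]


def sequence_score(emit, trans, tags):
    """Score one unpadded tag sequence.

    emit[t][tag]  : emission score of `tag` at position t
    trans[(a, b)] : transition score from tag a to tag b
    """
    seq = add_boundary_tags(tags)
    # Emissions are summed over the real tags only (assumption: none for EOS).
    emit_score = sum(emit[t][tag] for t, tag in enumerate(tags))
    # Transitions cover every adjacent pair, including last tag -> EOS.
    trans_score = sum(trans[(a, b)] for a, b in zip(seq, seq[1:]))
    return emit_score + trans_score


emit = [{"A": 1.0, "B": 0.0}, {"A": 0.0, "B": 2.0}]
trans = {(SOS, "A"): 0.5, ("A", "B"): 0.25, ("B", EOS): 4.0}
total = sequence_score(emit, trans, ["A", "B"])
print(total)  # 1.0 + 2.0 + 0.5 + 0.25 + 4.0 = 7.75
```

Keeping both tags in one helper means the scoring code never has to know whether the data loader already appended the end tag.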