allanj/pytorch_neural_crf

computation of partition function

y1450 opened this issue · 11 comments

y1450 commented

I am trying to compute a probability from the Viterbi score, but I could not get the values to make sense.

As a sanity check I tried pytorch-crf (https://pytorch-crf.readthedocs.io/en/stable/), and its z values matched my naive implementation, but I could not get the same results with this implementation.
Is there any way to verify the computation of the unlabeled score?
I used the same unary and pairwise potential matrices.
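
For reference, my naive check just enumerates every tag sequence and log-sum-exps the path scores, roughly like this (a toy sketch, only feasible for tiny inputs; the function and argument names are my own):

import itertools
import torch

def brute_force_log_z(emissions, transitions, start_trans, end_trans):
    # emissions: (seq_len, num_tags) for a single instance
    seq_len, num_tags = emissions.shape
    scores = []
    for tags in itertools.product(range(num_tags), repeat=seq_len):
        score = start_trans[tags[0]] + emissions[0, tags[0]]
        for t in range(1, seq_len):
            score = score + transitions[tags[t - 1], tags[t]] + emissions[t, tags[t]]
        score = score + end_trans[tags[-1]]
        scores.append(score)
    # z in log space: log-sum-exp over all num_tags ** seq_len paths
    return torch.logsumexp(torch.stack(scores), dim=0)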

I looked for tests in the repository but could not find any. Any suggestions for writing a test case would be really helpful.

You mean testing the correctness against pytorch-crf?

y1450 commented

Yes, the correctness of the implementation.
I tested it with a small sequence length (N=3): basically I hard-coded the transition and emission scores in the respective implementations and compared the forward labeled scores.
I can share a dummy sample notebook if you would like to take a look.

Sure, no problem. That would be appreciated.

Not sure how you tried it, but I tried the example here (https://pytorch-crf.readthedocs.io/en/stable/), and everything looks fine.


import torch
from torchcrf import CRF
num_tags = 5  # number of tags is 5
model = CRF(num_tags)
seq_length = 3
batch_size = 2
emissions = torch.randn(seq_length, batch_size, num_tags)
tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.long)
print(model(emissions, tags))


from src.model.module import LinearCRF
from src.data.data_utils import START_TAG, STOP_TAG, PAD
labels = ['A', 'B', 'C', 'D', 'E', START_TAG, STOP_TAG, PAD]
label2idx = {label: idx for idx, label in enumerate(labels)}
mycrf = LinearCRF(label_size=8, label2idx=label2idx, add_iobes_constraint=False, idx2labels=labels)
## pad the emissions with -10000.0 for START/STOP/PAD so those tags are effectively impossible
converted_emissions = torch.cat([emissions.transpose(0, 1), torch.full((batch_size, seq_length, 3), -10000.0)], dim=-1)
converted_tags = tags.transpose(0, 1)
mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).bool()
mycrf.transition.data[:5, :5] = model.transitions.data ## transition between labels
mycrf.transition.data[5, :5] = model.start_transitions.data ## transition from start tag
mycrf.transition.data[:5, 6] = model.end_transitions.data ## transition to end tag
unlabeled, labeled = mycrf(converted_emissions, torch.LongTensor([3, 3]), converted_tags, mask) ## [3, 3] are the sequence lengths
loglikelihood = labeled - unlabeled
print(loglikelihood)
y1450 commented

Thanks for the example. I'll try it and update you on it.

y1450 commented

Yes, it works. The log-likelihood and normalization computations are correct and the same for both implementations.
My mistake was that I thought num_tags in pytorch-crf included the pad, start, and stop tags.
Thank you very much. You have been very helpful with these issues.

y1450 commented

I am slightly confused about computing the probability of the most likely sequence and would appreciate your help.
Viterbi score to probability:
As per my understanding it is
np.exp(viterbi_score - z), where both viterbi_score and z are in log space.
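
For a single sequence, a toy example (numbers made up just to illustrate the log-space arithmetic):

import numpy as np

viterbi_score = -3.2  # log score of the best path (made-up value)
z = -1.7              # log partition function (made-up value)
print(np.exp(viterbi_score - z))  # probability of the most likely sequence, ~0.22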

If I change my dataloader's batch_size, then the value of z changes, since it depends on the batch (input). The problem is that I then can't use batch_size > 1, which is not efficient.

Intuitively, the probability of a sequence should be the same given the same model parameters, irrespective of the batch_size.
Is my reasoning correct?

Not really.

Your Viterbi score always has size batch_size, and so does z.
So even after you exponentiate, you still get a final probability of size batch_size.

So the size of your batch will never affect the value for each instance.
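
For example, with pytorch-crf you can see this directly using its decode method and reduction='none' (a quick sketch, not from this repo):

import torch
from torchcrf import CRF

num_tags, seq_length, batch_size = 5, 3, 4
model = CRF(num_tags)
emissions = torch.randn(seq_length, batch_size, num_tags)

best_paths = model.decode(emissions)                       # list of batch_size tag sequences
best_tags = torch.tensor(best_paths).transpose(0, 1)       # (seq_length, batch_size)
log_probs = model(emissions, best_tags, reduction='none')  # (batch_size,): score - z per instance
print(log_probs.exp())  # one probability per instance, independent of batch_size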

y1450 commented

OK, I'll try to write a test case and validate.
As for the size of z, off the top of my head I think it is a scalar, not a vector of size batch_size, which is what forces me to set batch_size=1.

y1450 commented

My mistake, I assumed the z values returned were of size batch_size, but forward_unlabeled already sums them before returning.


Here, the last_alphas in forward_unlabeled are the z values.
Maybe it would be a good idea to return the individual z values rather than sum(z); that would make the forward_unlabeled API more useful, e.g. for computing probabilities, as in the sketch below.
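
Roughly what I mean, as a self-contained sketch of a forward pass that returns per-instance z values (my own function name and shapes, not the repo's actual code):

import torch

def log_partition_per_instance(emissions, transitions, start_trans, end_trans):
    # Forward algorithm over emissions of shape (batch_size, seq_len, num_tags),
    # returning one log z per instance instead of their sum.
    batch_size, seq_len, num_tags = emissions.shape
    alpha = start_trans.unsqueeze(0) + emissions[:, 0]  # (batch_size, num_tags)
    for t in range(1, seq_len):
        # alpha[b, j] = logsumexp_i(alpha[b, i] + transitions[i, j]) + emissions[b, t, j]
        alpha = torch.logsumexp(alpha.unsqueeze(2) + transitions.unsqueeze(0), dim=1) + emissions[:, t]
    last_alphas = torch.logsumexp(alpha + end_trans.unsqueeze(0), dim=1)  # (batch_size,)
    return last_alphas  # instead of last_alphas.sum(), so callers can compute per-instance probabilities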
Again, thanks for the help.

Yeah, it seems that might make the code more readable. I will see how I can do that.