kanishkamisra/minicons

Discrepancies in output between simonepri/lm-scorer and minicons libraries

Closed this issue · 8 comments

Issue

I am the author of hashformers, a state-of-the-art library for hashtag segmentation.

I am currently transitioning from using simonepri/lm-scorer to using minicons as the backbone for my library. So my goal right now is to replicate the exact scores produced by lm-scorer using minicons.

Here's the original code snippet using lm-scorer:

import torch
from lm_scorer.models.auto import AutoLMScorer as LMScorer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 1
scorer = LMScorer.from_pretrained("gpt2", device=device, batch_size=batch_size)
logprobs = scorer.tokens_score("I like this package.")
print(logprobs)

The corresponding scorer in the lm-scorer library can be found in lm-scorer/gpt2.py (excerpted in the code comparison below).

In my attempts to duplicate this functionality with minicons, I came up with the following code:

import torch
from minicons import scorer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = scorer.IncrementalLMScorer('gpt2', device)
logprobs = model.compute_stats(model.prepare_text("I like this package."))
print(logprobs)

However, this code doesn't provide the expected output. The differences include:

  • The number of tokens returned by the two libraries doesn't match.
  • The absolute value of each token score differs between the two outputs.
  • The relative ordering of the token scores also differs: np.argsort returns different results for each library's output.

This is the output produced by lm-scorer (the relevant scores are the first list in the tuple):

(
  [0.018321018666028976, 0.0066428035497665405, 0.08063317090272903, 0.000607448979280889, 0.277709037065506, 0.0036384568084031343], 
  [40, 588, 428, 5301, 13, 50256], 
  ['I', 'Ġlike', 'Ġthis', 'Ġpackage', '.', '<|endoftext|>']
)

This is the output produced by minicons:

Using pad_token, but it is not set yet.

[[-6.164241790771484, -3.1028060913085938, -7.756439208984375, -1.4581527709960938]]

Code comparison

It's hard to spot where the difference lies because the code in both libraries is rather similar:

minicons

minicons/scorer.py:

        ids = [
            [i for i in instance if i != self.tokenizer.pad_token_id]
            for instance in encoded["input_ids"].tolist()
        ]

        ## Ignore the probabilities of the first token.
        effective_ids = [id[1:] for id in ids]

        with torch.no_grad():
            logits = self.model(**encoded).logits.detach()

        logits[:, :, self.tokenizer.pad_token_id] = float("-inf")

        logits = logits.split([1] * len(offsets))

        ## Set up storage variables
        scores = []
        if rank:
            ranks = []

        for logit, idx, offset in zip(logits, effective_ids, offsets):
            length = len(idx)
            logit = logit.squeeze(0)[:, :-1][
                torch.arange(offset, length),
            ]

            logprob_distribution = logit - logit.logsumexp(1).unsqueeze(1)
            query_ids = idx[offset:]

...


                    score = logprob_distribution[
                        torch.arange(length - offset), query_ids
                    ].tolist()

...

            scores.append(score)

...

            return scores

lm-scorer

lm-scorer/gpt2.py:

with torch.no_grad():
    ids = encoding["input_ids"].to(self.model.device)
    attention_mask = encoding["attention_mask"].to(self.model.device)
    nopad_mask = ids != self.tokenizer.pad_token_id
    logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0]

for sent_index in range(len(text)):
    sent_nopad_mask = nopad_mask[sent_index]
    # len(tokens) = len(text[sent_index]) + 1
    sent_tokens = [
        tok
        for i, tok in enumerate(encoding.tokens(sent_index))
        if sent_nopad_mask[i] and i != 0
    ]

    # sent_ids.shape = [len(text[sent_index]) + 1]
    sent_ids = ids[sent_index, sent_nopad_mask][1:]
    # logits.shape = [len(text[sent_index]) + 1, vocab_size]
    sent_logits = logits[sent_index, sent_nopad_mask][:-1, :]
    sent_logits[:, self.tokenizer.pad_token_id] = float("-inf")
    # ids_scores.shape = [seq_len + 1]
    sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1)
    # log_prob.shape = [seq_len + 1]
    sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1)

    sent_log_probs = cast(torch.DoubleTensor, sent_log_probs)
    sent_ids = cast(torch.LongTensor, sent_ids)

    output = (sent_log_probs, sent_ids, sent_tokens)
    outputs.append(output)

return outputs
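
For reference, here is a minimal, self-contained sketch of the recipe both excerpts share (my own simplification in plain transformers, not the actual code of either library): the logits at position i give the distribution over token i+1, so per-token scores come from aligning logits[:, :-1] with input_ids[:, 1:].

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

ids = tokenizer("I like this package.", return_tensors="pt")["input_ids"]

with torch.no_grad():
    # logits.shape = [1, seq_len, vocab_size]; logits[:, i] is the distribution over token i + 1
    logits = model(ids).logits

# Normalize to log-probabilities, then pick out the score of each observed next token.
log_probs = logits.log_softmax(-1)
token_scores = log_probs[0, :-1].gather(1, ids[0, 1:].unsqueeze(1)).squeeze(1)
print(token_scores.tolist())  # one score per token, starting from the second token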

I'd appreciate any insights or suggestions on how to make the minicons output match the lm-scorer output.

Thank you for your assistance!

Hi - sorry for responding to this late! The main difference here is that lm-scorer prepends all sentences with the <|endoftext|> token, and that way you also get the probability of the first token in the sentence. Since the target audience for minicons was mostly psycholinguists and folks using psycholinguistic methods to evaluate/use LMs, it was recommended to me to start from P(second token | first token)... and so on. In principle, if you prepend the sentence input with the <|endoftext|> token, you are assuming that the first real word of the sentence is also the first word in a document, which may not be completely valid.
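
Concretely, for "I like this package." (a quick, illustrative tokenizer-only sketch of why the token counts differ):

from transformers import GPT2TokenizerFast

tok = GPT2TokenizerFast.from_pretrained("gpt2")
ids = tok("I like this package.")["input_ids"]
print(tok.convert_ids_to_tokens(ids))
# ['I', 'Ġlike', 'Ġthis', 'Ġpackage', '.']
# minicons conditions each token only on the preceding ones -> 4 scores (no score for 'I')
# lm-scorer scores <|endoftext|> + ids + <|endoftext|>       -> 6 scores, including P('I' | <|endoftext|>)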

If your task really needs the logprob/prob of the first token, I can spend some time adding a flag to the compute_stats function and that should work, though I cannot promise anything to be done quickly since I have very little bandwidth in the next two weeks.

Regardless, I would recommend you try out minicons and see if the results do matter -- they ideally shouldn't.

Does this make sense?

Regardless, I recommend you try out minicons and assess if the results matter -- ideally, they shouldn't.

Does this make sense?

I need to clarify that I've already tested minicons extensively as a replacement for the lm-scorer library. The results achieved by lm-scorer are significantly better. In some tasks, lm-scorer attains F1-scores of around 95%, whereas minicons reaches only about 60% under the same conditions.

This discrepancy stems from the fact that hashtag segmentation depends on comparing probabilities between very similar token sequences, differing only by the space position (e.g., "#hello world" vs. "#hell oworld"). In this context, every token is essential, and omitting P(first token | <|endoftext|>) has a profound negative impact on the results.
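
To make that concrete, here is roughly how candidates end up being compared (a simplified sketch built on the compute_stats/prepare_text calls from my snippet above; the candidate strings are just illustrative):

import torch
from minicons import scorer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = scorer.IncrementalLMScorer('gpt2', device)

# Two candidate segmentations of the same hashtag, differing only in the space position.
candidates = ["hello world", "hell oworld"]

# Rank candidates by their total log-probability under the model.
scores = [sum(model.compute_stats(model.prepare_text(c))[0]) for c in candidates]
best = max(zip(scores, candidates))[1]
print(best)

# Without P(first token | <|endoftext|>), the score of the very first word
# ('hello' vs. 'hell') never enters this comparison.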

If your task indeed requires the logprob/prob of the first token, I can add a flag to the compute_stats function. That should work.

I would greatly appreciate that. I've already transitioned the backend of the hashformers library to minicons, as documented in the project wiki.

I'm excited about this PR because I believe it could potentially set a new state-of-the-art for hashtag segmentation. While lm-scorer limits us to GPT-2, minicons will allow us to experiment with a wider range of models. All it needs is a minor adjustment to cater to our specific task.

Got it -- I will try to get that done asap! I had no idea LMs were being used for hashtag segmentation -- this seems like a nice idea :D For clarification, this would only be a change for autoregressive/incremental models like GPT-2; I think BERT-style models already give you the first token probability. For clarification, would you also want an <|endoftext|> at the end? I guess I can make them separate flags, since bos_token and eos_token might differ across models.

Got it -- I will try to get that done asap! I had no idea LMs were being used for hashtag segmentation -- this seems like a nice idea :D

It's actually the state-of-the-art approach, as documented in this LREC 2022 paper.

For clarification, this would only be a change for autoregressive/incremental models like GPT-2; I think BERT-style models already give you the first token probability.

That's great. The probabilities produced by BERT-style models are suitable only for breaking ties between the top segmentations suggested by an autoregressive/incremental model, so we need to solve this issue first before I can test them.

By the way: correct me if I'm wrong, but it seems to me that Seq2SeqScorer also does not skip the first token probability.

For clarification, would you also want an <|endoftext|> at the end? I guess I can make them separate flags, since bos_token and eos_token might differ across models.

It's best to have bos_token and eos_token as separate flags to make the code as flexible as possible. lm-scorer also defines bos_token and eos_token as separate flags, despite both of them being equivalent to <|endoftext|> for GPT-2.
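
For example, just checking the tokenizer attributes with transformers (the second model name is only an illustration):

from transformers import AutoTokenizer

for name in ["gpt2", "bigscience/bloom-560m"]:
    tok = AutoTokenizer.from_pretrained(name)
    print(name, repr(tok.bos_token), repr(tok.eos_token))

# gpt2 uses '<|endoftext|>' for both, while other models define distinct bos/eos tokens.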

By the way: correct me if I'm wrong, but it seems to me that Seq2SeqScorer also does not skip the first token probability.

Yes, I think so -- I am not quite sure what the right way to do Seq2Seq scoring is; looping in the wonderful @aaronmueller, who implemented that functionality and has done a lot of work using scores elicited from LMs such as T5 and BART.

It's best to have bos_token and eos_token as separate flags to make the code as flexible as possible. lm-scorer also defines bos_token and eos_token as separate flags, despite both of them being equivalent to <|endoftext|> for GPT-2.

Sounds good -- I will try to get this done soon (should be straightforward in practice).

I come bearing gifts! This issue seems to be resolved (I will let you close it). My latest commit adds flags for the bos and eos tokens to the incremental LM scorer; you can now do:

pip install -U minicons

and then:

import torch
from minicons import scorer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = scorer.IncrementalLMScorer('gpt2', device)
print(model.token_score("I like this package.", bos_token=True, eos_token=True, prob=True))

'''
[[('<|endoftext|>', 0.0),
  ('I', 0.018320878967642784),
  ('like', 0.006643006112426519),
  ('this', 0.08063255250453949),
  ('package', 0.0006074582342989743),
  ('.', 0.27771538496017456),
  ('<|endoftext|>', 0.003638095920905471)]]
'''

If you want the weird "Ġ" prefixes, pass the additional flag decode=False.

If you prefer the compute_stats method, do:

probs = model.compute_stats(model.prepare_text("I like this package.", eos_token=True, bos_token=True), prob=True)
print(probs)

'''
[[0.018320878967642784, 0.006643006112426519, 0.08063255250453949, 0.0006074582342989743, 0.27771538496017456, 0.003638095920905471]]
'''

Thanks again for raising this issue, this has made minicons a better library :D

I've tested your solution against lm-scorer and the scores matched exactly, just as we thought they would. With this, I'm happy to say that the issue is now solved. Thank you for responding so quickly and helping us fix this big problem. This was holding us back from releasing hashformers v2.0, but now we're good to go.
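
For the record, the check was along these lines (a simplified sketch; the tolerance is arbitrary):

import torch
from lm_scorer.models.auto import AutoLMScorer as LMScorer
from minicons import scorer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
sentence = "I like this package."

# Per-token probabilities from lm-scorer (first element of the returned tuple).
lm_probs, _, _ = LMScorer.from_pretrained("gpt2", device=device, batch_size=1).tokens_score(sentence)

# Per-token probabilities from minicons with the new bos/eos flags.
model = scorer.IncrementalLMScorer('gpt2', device)
mc_probs = model.compute_stats(model.prepare_text(sentence, bos_token=True, eos_token=True), prob=True)[0]

# Both should yield the same six values for this sentence.
assert len(lm_probs) == len(mc_probs)
assert all(abs(a - b) < 1e-4 for a, b in zip(lm_probs, mc_probs))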

That is great -- glad that minicons is proving to be useful! Thanks again :)