keras-team/keras-nlp

Make the local variable per_token_loss_fn in the score method an instance attribute, so that the loss function can be modified.


Is your feature request related to a problem? Please describe.

I want to use a custom loss function for per_token_loss_fn, so I want to make it accessible outside the class so it can be updated.

Describe the solution you'd like

From this:

per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)

to this:

self.per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)
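For illustration, this is roughly how I would use it once exposed. This is only a sketch of the proposed change: the GPT2CausalLM preset name and the custom loss below are placeholders, and assigning per_token_loss_fn is hypothetical (it is not an attribute today).

import keras
import keras_nlp

lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

def custom_per_token_loss(token_ids, logits):
    # Placeholder: any callable mapping (targets, logits) -> per-token losses.
    return keras.losses.sparse_categorical_crossentropy(
        token_ids, logits, from_logits=True
    )

# Hypothetical usage: only possible if per_token_loss_fn becomes an attribute.
lm.per_token_loss_fn = custom_per_token_loss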

@deveshklt can you explain more about your overall use case? E.g., what other loss would you like to pass, and for what end purpose?

The score() function is new functionality being added by @RyanMullins, in particular for interpretability applications (though it's a general function with many uses). Feedback is welcome as we are building it out, and understanding the user journeys here will help.

My use case is comparing two sentences and calculating the divergence between them.
I want to use Kullback-Leibler (KL) divergence loss for this, which means I need a different per-token loss function.

Hi @deveshklt! As @mattdangerw mentioned, I've added a .score() API to KerasNLP's implementations of Gemma, Llama, Mistral, and GPT-2. This function is inspired by the scoring mode in Google's T5X modeling framework (GitHub, paper); you provide a tokenized representation of a sequence and this API computes either the logits or the per-token loss for that sequence from the model, depending on the value of scoring_mode. If run in "logits" mode (the default), you can compute any custom loss you like from the tensor this API returns (this was an intentional design choice to support use cases like yours).

import keras
import keras_nlp

lm = keras_nlp.models.CausalLM.from_preset("some_preset")

generations = ...  # I assume you already have a list of strings here.

preprocessed = lm.preprocessor.generate_preprocess(generations)
generation_ids = preprocessed["token_ids"]
padding_mask = preprocessed["padding_mask"]

logits = lm.score(
    token_ids=generation_ids,
    padding_mask=padding_mask,
)

model_loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)
# Compute loss with model_loss_fn

kldiv_loss_fn = keras.losses.KLDivergence(...)
# Compute loss with kldiv_loss_fn

# ...and so on with the other loss functions you're exploring.
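If you want the per-token loss directly instead of logits, the same call can be run in loss mode. A rough sketch, where the scoring mode value "loss" and the target_ids argument are assumptions based on the description above, so check the score() docstring for the exact signature:

per_token_loss = lm.score(
    token_ids=generation_ids,
    padding_mask=padding_mask,
    scoring_mode="loss",        # assumed value selecting per-token loss mode
    target_ids=generation_ids,  # assumed: score the sequence against itself
)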

I assume that since you're interested in KL divergence, you have a dataset with some ground truth that you can use as the value of y_true. In that case, you can use the .score() API to compute the logits for the ground-truth and generation sequences and then pass those into the loss function as, for example, kldiv_loss_fn(gt_logits, gen_logits).
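Putting that together, a minimal sketch: ground_truth_ids and gt_padding_mask are hypothetical names for your preprocessed ground-truth sequence, both sequences are assumed to be padded to the same length, and since keras.losses.KLDivergence expects probability distributions, the logits are softmaxed first.

# Score both sequences with the same model (see the earlier example for how
# the token ids and padding masks are produced).
gt_logits = lm.score(token_ids=ground_truth_ids, padding_mask=gt_padding_mask)
gen_logits = lm.score(token_ids=generation_ids, padding_mask=padding_mask)

# KLDivergence operates on probabilities, so softmax over the vocabulary axis.
gt_probs = keras.activations.softmax(gt_logits)
gen_probs = keras.activations.softmax(gen_logits)

kldiv_loss_fn = keras.losses.KLDivergence(reduction="none")
per_token_kl = kldiv_loss_fn(gt_probs, gen_probs)  # shape: (batch, sequence)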

Thank you @RyanMullins for the explanation.