keras-team/keras-io

TypeError: Sampler.__call__() got an unexpected keyword argument 'end_token_id'

Closed this issue · 2 comments

Issue Type

Bug

Source

binary

Keras Version

3.2.1

Custom Code

No

OS Platform and Distribution

Colab

Python version

3.10.12

GPU model and memory

Colab T4 GPU

Current Behavior?

I've been trying to run the English-to-Spanish translation example with KerasNLP, but I get stuck at the prediction/evaluation step, which raises the TypeError in the title (full traceback under Relevant log output).


The model trains successfully, though.

Standalone code to reproduce the issue or tutorial link

Colab: https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/neural_machine_translation_with_keras_nlp.ipynb (Runtime: T4 GPU)
Docs: https://keras.io/examples/nlp/neural_machine_translation_with_keras_nlp/

Relevant log output

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-17-7f8fd291e436> in <cell line: 35>()
     35 for i in range(2):
     36     input_sentence = random.choice(test_eng_texts)
---> 37     translated = decode_sequences([input_sentence])
     38     translated = translated.numpy()[0].decode("utf-8")
     39     translated = (

<ipython-input-17-7f8fd291e436> in decode_sequences(input_sentences)
     22     prompt = ops.concatenate((start, pad), axis=-1)
     23 
---> 24     generated_tokens = keras_nlp.samplers.GreedySampler()(
     25         next,
     26         prompt,

TypeError: Sampler.__call__() got an unexpected keyword argument 'end_token_id'

This is how I fixed it in my code. I suspect there's a cleaner fix, but in short, three things needed to change:

  1. Rename the 'end_token_id' argument to 'stop_token_ids'.
  2. Wrap the token id passed to it in a list.
  3. Add '.to_tensor()' to the encoder input tokens before padding them.
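Fix 3 is needed because, for a batch of sentences, the tokenizer output is ragged (one list of token ids per sentence, each of a different length), so it has to be made dense before it can be concatenated with padding. The sketch below is a hypothetical pure-Python illustration of that step, not KerasNLP or TensorFlow code; pad_to_length and PAD_ID are names made up for this example, and MAX_SEQUENCE_LENGTH = 40 is assumed to match the tutorial. Unlike the snippet further down, it also truncates rows that are too long and handles more than one sentence per batch.

```python
# Hypothetical pure-Python illustration of fix 3 (not KerasNLP/TensorFlow code):
# a ragged batch is a list of token-id lists of different lengths, and it must
# become fixed-width rows before it can be fed to the model.

MAX_SEQUENCE_LENGTH = 40  # assumed to match the tutorial's constant
PAD_ID = 0                # the snippet pads the encoder input with 0


def pad_to_length(batch, length=MAX_SEQUENCE_LENGTH, pad_id=PAD_ID):
    """Right-pad (or truncate) every row so each has exactly `length` ids."""
    padded = []
    for row in batch:
        row = list(row[:length])                     # truncate long rows
        row.extend([pad_id] * (length - len(row)))   # pad short rows
        padded.append(row)
    return padded


print(pad_to_length([[2, 15, 7], [2, 9]], length=5))
# prints [[2, 15, 7, 0, 0], [2, 9, 0, 0, 0]]
```

An untested suggestion for the cleaner fix suspected above: with the TensorFlow backend, a RaggedTensor's to_tensor(default_value=0, shape=(None, MAX_SEQUENCE_LENGTH)) should do this padding and truncation in a single call, replacing the if/concatenate logic.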

(Note: I extended mine from Spanish to French, so my target-language tokenizer is named "fr_..." instead of "spa_...".)

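from keras import ops
import keras_nlp

# Assumes the tutorial's globals: eng_tokenizer, fr_tokenizer, transformer, MAX_SEQUENCE_LENGTH.
def decode_sequences(input_sentences):
    batch_size = 1

    # Tokenize the encoder input. For a batch of sentences the tokenizer output
    # is ragged, so it must be made dense before it can be padded.
    encoder_input_tokens = ops.convert_to_tensor(eng_tokenizer(input_sentences))
    # Note: if the input already has MAX_SEQUENCE_LENGTH or more tokens, this
    # branch is skipped and the tokens stay ragged.
    if len(encoder_input_tokens[0]) < MAX_SEQUENCE_LENGTH:
        pads = ops.full((1, MAX_SEQUENCE_LENGTH - len(encoder_input_tokens[0])), 0)
        # encoder_input_tokens = ops.concatenate([encoder_input_tokens, pads], 1)  # <-- Original
        encoder_input_tokens = ops.concatenate([encoder_input_tokens.to_tensor(), pads], 1)  # <-- Added ".to_tensor()" to densify the ragged tokens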

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def next(prompt, cache, index):
        logits = transformer([encoder_input_tokens, prompt])[:, index - 1, :]
        # Ignore hidden states for now; only needed for contrastive search.
        hidden_states = None
        return logits, hidden_states, cache

    # Build a prompt of length 40 with a start token and padding tokens.
    length = 40
    start = ops.full((batch_size, 1), fr_tokenizer.token_to_id("[START]"))
    pad = ops.full((batch_size, length - 1), fr_tokenizer.token_to_id("[PAD]"))
    prompt = ops.concatenate((start, pad), axis=-1)

    generated_tokens = keras_nlp.samplers.GreedySampler()(
        next,
        prompt,
        # end_token_id=fr_tokenizer.token_to_id("[END]"),  # <-- Original
        stop_token_ids=[fr_tokenizer.token_to_id("[END]")],  # <-- Renamed argument; the id is wrapped in a list
        index=1,  # Start sampling after start token.
    )
    generated_sentences = fr_tokenizer.detokenize(generated_tokens)
    return generated_sentences
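For context on fixes 1 and 2: the sampler now takes a collection of stop token ids rather than a single end token id, and a sequence is finished as soon as it produces any id in that collection. The loop below is a simplified, hypothetical sketch of greedy decoding with stop token ids, written in plain Python for a single sequence; it is not KerasNLP's GreedySampler implementation, and greedy_decode, toy_next and the toy vocabulary ids are made up for illustration.

```python
# Simplified, hypothetical greedy decoder for ONE sequence (plain Python).
# It is not KerasNLP's GreedySampler; it only mirrors the idea of
# "fill the prompt left to right, stop at any id in stop_token_ids".


def greedy_decode(next_fn, prompt, stop_token_ids, index=1):
    """Fill `prompt` from position `index` onward by greedy (argmax) choice.

    next_fn(tokens, index) must return a list of scores over the vocabulary
    for the token at position `index`. Decoding ends when the chosen token is
    one of `stop_token_ids` or when the prompt is full.
    """
    tokens = list(prompt)          # copy so the caller's prompt is untouched
    stop = set(stop_token_ids)     # expects a collection, e.g. [end_id]
    while index < len(tokens):
        scores = next_fn(tokens, index)
        best = max(range(len(scores)), key=scores.__getitem__)  # argmax
        tokens[index] = best
        index += 1
        if best in stop:
            break
    return tokens


# Toy vocabulary (hypothetical ids): 0=[PAD], 1=[START], 2=[END], 3 and 4 = words.
PAD, START, END = 0, 1, 2
SCRIPT = {1: 3, 2: 4, 3: END}  # position -> token the toy "model" prefers


def toy_next(tokens, index):
    scores = [0.0] * 5
    scores[SCRIPT.get(index, END)] = 1.0
    return scores


print(greedy_decode(toy_next, [START] + [PAD] * 5, stop_token_ids=[END]))
# prints [1, 3, 4, 2, 0, 0]
```

Because the stop ids are treated as a collection, a bare integer such as stop_token_ids=2 would raise a TypeError at set(stop_token_ids) in this sketch; wrapping the id in a list, as in fix 2, avoids that.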
