setting `max-length` to a larger value doesn't affect the output length
xu3kev opened this issue · 9 comments
Hi, I'm trying out the sampling code. I set max-length to a larger value because I'd like it to output a longer sequence, but I still get a very short one, as shown below.
~/CodeGen$ python3 -m jaxformer.hf.sample --model codegen-350M-mono --context "def hello_world():" --max-length 1024
loading parameters
loading parameters took 14.13s
loading tokenizer
loading tokenizer took 7.02s
sampling
====================================================================================================
print("Hello World")
hello_world()
#
====================================================================================================
def hello_world():
print("Hello World")
hello_world()
====================================================================================================
sampling took 0.47s
done.
Is this expected behavior? Thanks!
Yes, this is expected.
I believe that, in your case, you want to set min_length rather than max_length, i.e. model.generate(min_length=1024, ...), here:
https://github.com/salesforce/CodeGen/blob/main/jaxformer/hf/sample.py#L120
This will pull in a LogitsProcessor that manipulates the score of the eos token, so generation cannot stop before min_length is reached.
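For illustration, here is a minimal sketch (not from sample.py; it assumes the eos id 50256 that appears in the Hugging Face logs below) of how the min-length constraint works in transformers: the MinLengthLogitsProcessor forces the eos score to -inf until the sequence reaches min_length.
import torch
from transformers import MinLengthLogitsProcessor

# Keep eos (id 50256) blocked until the sequence is at least 1024 tokens long.
processor = MinLengthLogitsProcessor(min_length=1024, eos_token_id=50256)

input_ids = torch.tensor([[1, 2, 3]])   # a sequence far shorter than min_length
scores = torch.zeros(1, 50257)          # dummy logits over the vocabulary
print(processor(input_ids, scores)[0, 50256])   # tensor(-inf): eos cannot be chosen yet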
Hope it helps.
Thanks for the reply! When I use the checkpoint from Hugging Face as in the example code, I get a much longer output, like the following result in Colab.
In[1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono").to(0)  # keep model and inputs on the same GPU
inputs = tokenizer("def hello_world():", return_tensors="pt").to(0)
sample = model.generate(**inputs, do_sample=True, max_length=512)
print(tokenizer.decode(sample[0]))
Out[1]:
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
def hello_world():
    global name
    name = "Samira"
    print("Hello")

hello_world()
print(name)

# In python, a global statement is at end of file

def show_info_me_name():
    print("my name is, " + name)

def set_name():
    global name
    name = "Samira"
    show_info_me_name()

set_name()
I tried various decoding settings (temperature sampling, greedy decoding) but still can't seem to match this. Do the two methods have different implementations of the generate function?
@xu3kev This is probably because of the optional truncate_before_pattern implemented for CodeGenTokenizer; our sampling code truncates based on a set of patterns by default. Please try decoding with those patterns, as shown in the example below.
Thanks for the suggestion! I removed all of them and the output is still the same. I can't find where the patterns are passed to the tokenizer.
This repo's default temperature and top_p are 0.2 and 0.95, respectively, but transformers' generate() sets them both to 1.0 by default. When you sample from our repo, please add arguments like --p 1.0 --t 1.0.
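For example, adding those flags (names as given above) to the command used earlier in this thread:
~/CodeGen$ python3 -m jaxformer.hf.sample --model codegen-350M-mono --context "def hello_world():" --max-length 512 --p 1.0 --t 1.0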
Where to add patterns in the tokenizer (transformers v4.21.3):
In [5]: import re
In [6]: patterns = [
...: '^#',
...: re.escape('<|endoftext|>'),
...: "^'''",
...: '^"""',
...: '\n\n\n'
...: ]
In [7]: import torch
...: from transformers import AutoTokenizer, AutoModelForCausalLM
...: tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
...: model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
...: inputs = tokenizer("def hello_world():", return_tensors="pt")
...: sample = model.generate(**inputs, do_sample=True, max_length=512)
...: print(tokenizer.decode(sample[0], truncate_before_pattern=patterns))
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
def hello_world():
    print ("Hello!")

if __name__=="__main__":
    hello_world()
Thanks! I tried again to compare the two methods by setting do_sample to False to try greedy decoding.
The output from the following Hugging Face approach is pretty long:
In[1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono").to(0)  # keep model and inputs on the same GPU
inputs = tokenizer("# this function prints hello world", return_tensors="pt").to(0)
sample = model.generate(**inputs, do_sample=False, max_length=512)
print(tokenizer.decode(sample[0]))
Out[1]:
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
# this function prints hello world
# this function takes a string as an argument
# it prints the string in the following format:
# Hello, World!
# this function takes a string as an argument
# it prints the string in the following format:
# Hello, World!
# this function takes a string as an argument
# it prints the string in the following format:
# Hello, World!
# this function takes a string as an argument
# it prints the string in the following format:
# Hello, World!
# this function takes a string as an argument
# it prints
On the other hand, the sampling code in this repo produced a nearly empty result.
Modified sampling code:
with torch.no_grad():
    input_ids = input_ids.to(device)
    tokens = model.generate(
        input_ids,
        do_sample=False,
        num_return_sequences=num_return_sequences,
        max_length=input_ids_len + max_length_sample,
        pad_token_id=pad_token_id,
        use_cache=True,
    )
    text = tokenizer.batch_decode(tokens[:, input_ids_len:, ...])
Output:
~/CodeGen$ python3 -m jaxformer.hf.sample --model codegen-350M-mono --context "# this function prints hello world" --max-length 1024
loading parameters
loading parameters took 13.88s
loading tokenizer
loading tokenizer took 6.99s
sampling
====================================================================================================
#
====================================================================================================
It seems very strange to me. Could you help me understand what might be the issue?
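As a side note, one way to narrow this down is to look at the raw generated ids before any decoding or truncation. Here is a small debugging sketch (not from the repo; it assumes pad_token_id in sample.py equals the eos id 50256, matching the Hugging Face warning above), placed right after the model.generate call in the modified code:
# Print the raw generated ids (prompt stripped) before batch_decode runs,
# to check whether eos (50256) shows up immediately under greedy decoding.
raw = tokens[0, input_ids_len:].tolist()
print(raw)
print("eos at position:", raw.index(50256) if 50256 in raw else "not generated")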
Thanks for the observation. Let us investigate and get back to you.
Thanks for the reply! If possible, could you reopen the issue, so it would be easier to track this? Or would it be better if I open another issue about the problem that I can't match the two greedy decoding methods?