Output Differs from Hugging Face Transformer Result and EasyDel Results
jchauhan opened this issue · 4 comments
Describe the bug
We are running a model on a TPU v3-8 instance using EasyDel. It worked great after your suggestion here: #114
However, the output produced by the EasyDel-hosted model is wrong:
```
Question: Which of the following is an example of monosomy?
Options:
46,XX
47,XXX
69,XYY
45,X
Please provide your choice first and then provide explanations if possible.
The correct answer is 46,XX.
Monosomy is a condition where a person has only one copy of a particular chromosome. In the case of 46,XX, a person has only one X chromosome. This is the most common form of monosomy and is typically associated with Turner syndrome, a genetic disorder that affects females.
47,XXX is also a form of monosomy, but it is a different type of chromosome. In this case, a person has three X chromosomes instead of the typical two. This condition is known as triple X syndrome and is also a genetic disorder that affects females.
69,XYY is a form of monosomy that involves an extra Y chromosome. This condition is known as Klinefelter syndrome and is also a genetic disorder that affects males.
45,X is a form of monosomy that involves a missing X chromosome. This condition is known as Turner syndrome and is a genetic disorder that affects females.
```
By contrast, the output produced by the Hugging Face hosting of the model is:
```
Question: Which of the following is an example of monosomy?
Options:
- 46,XX
- 47,XXX
- 69,XYY
- 45,X
Please provide your choice first and then provide explanations if possible.
### Assistant Output:
The correct answer is 45,X.
Monosomy is a condition where a person has only one copy of a particular chromosome. In this case, the person has only one X chromosome, which is a form of Turner syndrome. This condition is usually caused by a missing or partially deleted X chromosome.
The other options are not examples of monosomy:
- 46,XX: This is a normal karyotype, where a person has two X chromosomes.
- 47,XXX: This is a form of trisomy, where a person has three X chromosomes.
- 69,XYY: This is a form of trisomy, where a person has three X chromosomes and an extra Y chromosome.
```
The above is the correct answer.
To Reproduce
Run the following code to reproduce it
```python
import json
from typing import List, Union
from absl.app import run
from absl import flags
from EasyDel import JAXServer, JAXServerConfig
import jax
from fjformer import get_dtype
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, Qwen2Prompter
from EasyDel.serve.prompters.base_prompter import BasePrompter

FLAGS = flags.FLAGS
flags.DEFINE_enum(
    "prompter_type",
    enum_values=("gemma", "llama", "openchat", "qwen2", "medllama"),
    help="Prompter to be used to prompt the model",
    default="medllama"
)
flags.DEFINE_string(
    "pretrained_model_name_or_path",
    default="AdaptLLM/medicine-chat",
    help="The pretrained model path in huggingface.co/models"
)
flags.DEFINE_integer(
    "max_compile_tokens",
    default=256,
    help="Maximum number of compiled tokens"
)
flags.DEFINE_integer(
    "max_new_tokens_ratio",
    default=20,
    help="ratio multiplied by max_compile_tokens to get max_new_tokens"
)
flags.DEFINE_integer(
    "max_sequence_length",
    default=2048,
    help="max sequence length to be used in the model"
)
flags.DEFINE_enum(
    "dtype",
    enum_values=(
        "bf16",
        "fp16",
        "fp32"
    ),
    default="bf16",
    help="The data type of the model"
)
flags.DEFINE_list(
    "sharding_axis_dims",
    default=[1, 1, 1, -1],
    help="Sharding axis dimensions for the model"
)
flags.DEFINE_bool(
    "use_sharded_kv_caching",
    default=False,
    help="whether to use sharded kv caching for large-sequence models up to 1M"
)
flags.DEFINE_bool(
    "scan_ring_attention",
    default=True,
    help="whether to scan ring attention for large-sequence models up to 1M (works with attn_mechanism='ring')"
)
flags.DEFINE_bool(
    "use_scan_mlp",
    default=True,
    help="whether to scan MLP or FFN layers for large-sequence models up to 1M"
)
flags.DEFINE_enum(
    "attn_mechanism",
    enum_values=["normal", "flash", "ring", "splash"],
    default="normal",
    help="The attention mechanism to be used in the model"
)
flags.DEFINE_integer(
    "block_k",
    default=128,
    help="the number of chunks for the key block in attention (works with flash, splash, ring attention mechanisms)"
)
flags.DEFINE_integer(
    "block_q",
    default=128,
    help="the number of chunks for the query block in attention (works with flash, splash, ring attention mechanisms)"
)
flags.DEFINE_bool(
    "share_gradio",
    default=True,
    help="whether to share the gradio app"
)
flags.DEFINE_string(
    "gradio_root_path",
    default="",
    help="Root path to host the Gradio server"
)
from abc import ABC
from EasyDel.serve.prompters.base_prompter import BasePrompter
from typing import List, Optional


class MedLlama2Prompter(BasePrompter, ABC):
    def __init__(
            self,
    ):
        user_prefix = "[INST]"
        assistant_prefix = "[/INST]"
        super().__init__(
            user_message_token=user_prefix,
            assistant_message_token=assistant_prefix,
            prompter_type="medllama",
            end_of_turn_token="</s>"  # assumption: the original value was lost in the issue formatting
        )

    def format_history_prefix(
            self,
            history: list[list[str]],
            system_message: str,
    ):
        prompt = ""
        for user, assistant in history:
            prompt += f"{self.user_message_token}{user} "
            prompt += f"{self.assistant_message_token}{assistant} "
        print("format_history_prefix", prompt)
        return prompt

    def format_message(
            self,
            prompt: str,
            history: list[list[str]],
            system_message: Optional[str],
            prefix: Optional[str]
    ) -> str:
        dialogs = prefix if prefix is not None else ""
        for user, assistant in history:
            dialogs += f"{self.user_message_token}{user}"
            dialogs += f"{self.assistant_message_token}{assistant}"
        dialogs += f"{self.user_message_token}{prompt}"
        dialogs += self.assistant_message_token
        print("format_message", dialogs)
        return dialogs


import transformers
from typing import Optional, Mapping, Callable, Dict, Any
from jax.sharding import Mesh, PartitionSpec
from typing import Union, Sequence, List
from EasyDel.modules.auto_easydel_model import AutoEasyDelModelForCausalLM


def main(argv):
    server_config = JAXServerConfig(
        max_sequence_length=FLAGS.max_sequence_length,
        max_compile_tokens=FLAGS.max_compile_tokens,
        max_new_tokens=FLAGS.max_compile_tokens * FLAGS.max_new_tokens_ratio,
        dtype=FLAGS.dtype
    )
    prompters = {
        "gemma": GemmaPrompter(),
        "llama": Llama2Prompter(),
        "openchat": OpenChatPrompter(),
        "qwen2": Qwen2Prompter(),
        "medllama": MedLlama2Prompter()
    }
    prompter: BasePrompter = prompters[FLAGS.prompter_type]
    FLAGS.sharding_axis_dims = tuple([int(s) for s in FLAGS.sharding_axis_dims])

    class JAXServerMedLlama(JAXServer):
        @staticmethod
        def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
            return prompter.format_message(
                history=[],
                prompt=prompt,
                system_message=system,
                prefix=None
            )

        @staticmethod
        def format_instruct(system: str, instruction: str) -> str:
            return prompter.format_message(
                prefix=None,
                system_message=system,
                prompt=instruction,
                history=[]
            )

        @classmethod
        def from_torch_pretrained(
                cls,
                server_config: JAXServerConfig,
                pretrained_model_name_or_path: str,
                device=jax.devices('cpu')[0],
                dtype: jax.numpy.dtype = jax.numpy.float32,
                param_dtype: jax.numpy.dtype = jax.numpy.float32,
                precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"),
                sharding_axis_dims: Sequence[int] = (1, -1, 1, 1),
                sharding_axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"),
                query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
                key_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
                value_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
                bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, None, None),
                attention_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
                use_shard_map: bool = False,
                input_shape: Sequence[int] = (1, 1),
                shard_fns: Optional[Mapping[tuple, Callable]] = None,
                backend: Optional[str] = None,
                add_params_field: bool = True,
                do_memory_log: bool = False,
                model_config_kwargs: Optional[Mapping[str, Any]] = None,
                verbose: bool = True,
                **kwargs
        ):
            model, params = AutoEasyDelModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                device=device,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                sharding_axis_names=sharding_axis_names,
                sharding_axis_dims=sharding_axis_dims,
                query_partition_spec=query_partition_spec,
                attention_partition_spec=attention_partition_spec,
                value_partition_spec=value_partition_spec,
                key_partition_spec=key_partition_spec,
                bias_partition_spec=bias_partition_spec,
                use_shard_map=use_shard_map,
                shard_fns=shard_fns,
                input_shape=input_shape,
                backend=backend,
                config_kwargs=model_config_kwargs,
                **kwargs
            )
            rule = (
                ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),
                ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
                ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
                ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
                ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
                ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
                ("input_layernorm/kernel", PartitionSpec(None)),
                ("post_attention_layernorm/kernel", PartitionSpec(None)),
                ("model/norm/kernel", PartitionSpec(None)),
                ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
                (".*", PartitionSpec(None)),
            )
            model.config.get_partition_rules = lambda _: rule  # set the model config partition rules to this custom rule
            return cls.from_parameters(
                model=model,
                config_model=model.config,
                tokenizer=transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path),
                params=params,
                server_config=server_config,
                verbose=verbose,
                do_memory_log=do_memory_log,
                add_params_field=add_params_field
            )

    server = JAXServerMedLlama.from_torch_pretrained(
        server_config=server_config,
        pretrained_model_name_or_path=FLAGS.pretrained_model_name_or_path,
        device=jax.devices('cpu')[0],
        dtype=get_dtype(dtype=FLAGS.dtype),
        param_dtype=get_dtype(dtype=FLAGS.dtype),
        precision=jax.lax.Precision("fastest"),
        sharding_axis_dims=FLAGS.sharding_axis_dims,
        sharding_axis_names=("dp", "fsdp", "tp", "sp"),
        input_shape=(1, server_config.max_sequence_length),
        model_config_kwargs=dict(
            fully_sharded_data_parallel=True,
            attn_mechanism=FLAGS.attn_mechanism,
            scan_mlp_chunk_size=FLAGS.max_compile_tokens,
            use_scan_mlp=FLAGS.use_scan_mlp,
            scan_ring_attention=FLAGS.scan_ring_attention,
            block_k=FLAGS.block_k,
            block_q=FLAGS.block_q,
            use_sharded_kv_caching=FLAGS.use_sharded_kv_caching
        )
    )
    server.gradio_inference().launch(
        root_path=FLAGS.gradio_root_path,
        server_name="0.0.0.0",
        server_port=7680,
        show_api=True,
        share=FLAGS.share_gradio
    )


if __name__ == "__main__":
    run(main)
```
Is anything missing?
There were some warnings at the start of the execution of the script:
```
site-packages/transformers/generation/configuration_utils.py:406: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
/home/neo/research/belie/.venv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:411: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
```
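For context, these warnings refer to the decoding parameters that transformers reads from the checkpoint itself. A small sketch to see exactly what the checkpoint ships with (assuming the model loads the same way as in the Hugging Face snippet further down):
```python
from transformers import AutoModelForCausalLM

# Load the checkpoint and print the generation config transformers derived from it;
# the warnings above are about temperature/top_p being set while do_sample is False.
model = AutoModelForCausalLM.from_pretrained("AdaptLLM/medicine-chat", device_map="auto")
print(model.generation_config)
```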
Set top_k and top_p for both models to 1 and set do_sample to False, then both models will return the same result. The difference happens because you are using sampling in Hugging Face or non-greedy decoding in EasyDel. I'd recommend reading about these parameters.
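For reference, forcing deterministic (greedy) decoding on the Hugging Face side looks like this (assuming `model`, `tokenizer`, and `inputs` as in the Hugging Face snippet below; the EasyDel server would need its sampling settings made greedy as well, and the exact config fields depend on the EasyDel version, so check JAXServerConfig in your install rather than copying names from here):
```python
# Greedy decoding: with do_sample=False the sampling parameters
# (temperature, top_p, top_k) are not used, so the two backends
# can be compared like-for-like.
outputs = model.generate(input_ids=inputs, max_length=4096, do_sample=False)[0]
pred = tokenizer.decode(outputs[inputs.shape[-1]:], skip_special_tokens=True)
```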
I have tried multiple options; in all cases, the EasyDel model generates the wrong result. However, the Hugging Face code gives the correct result 1 out of 10 times.
The answer should be 45,X.
Here is the Hugging Face code. It is worth checking again. Please also check my EasyDel code above. Do you see anything fishy?
```python
!pip install bitsandbytes bitsandbytes datasets accelerate loralib
!pip install transformers@git+https://github.com/huggingface/transformers.git@main
!pip install peft@git+https://github.com/huggingface/peft.git
!pip install datasets
!pip install datasets --upgrade
!pip install evaluate

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("AdaptLLM/medicine-chat", device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("AdaptLLM/medicine-chat")

# Put your input here:
user_input = '''Question: Which of the following is an example of monosomy?
Options:
- 46,XX
- 47,XXX
- 69,XYY
- 45,X
Please provide your choice first and then provide explanations if possible.'''

# Apply the prompt template and system prompt of LLaMA-2-Chat demo for chat models (NOTE: NO prompt template is required for base models!)
our_system_prompt = "\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"  # Please do NOT change this
prompt = f"<s>[INST] <<SYS>>{our_system_prompt}<</SYS>>\n\n{user_input} [/INST]"

# # NOTE:
# # If you want to apply your own system prompt, please integrate it into the instruction part following our system prompt like this:
# your_system_prompt = "Please, answer this question faithfully."
# prompt = f"<s>[INST] <<SYS>>{our_system_prompt}<</SYS>>\n\n{your_system_prompt}\n{user_input} [/INST]"

inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
outputs = model.generate(input_ids=inputs, max_length=4096)[0]

answer_start = int(inputs.shape[-1])
pred = tokenizer.decode(outputs[answer_start:], skip_special_tokens=True)

print(f'### User Input:\n{user_input}\n\n### Assistant Output:\n{pred}')
```
Finally found the issue. It was a slight difference between the input passed to the Transformers model and the input passed to the EasyDel model.
Thanks a lot
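For anyone who hits the same mismatch: a quick way to surface this kind of input difference is to print and diff the exact prompt strings each path feeds to the model. A minimal sketch, reusing `user_input` and `our_system_prompt` from the Hugging Face snippet and the `MedLlama2Prompter` defined above:
```python
import difflib

# Prompt exactly as the Hugging Face script builds it...
hf_prompt = f"<s>[INST] <<SYS>>{our_system_prompt}<</SYS>>\n\n{user_input} [/INST]"

# ...and as the EasyDel prompter builds it, then diff the two strings.
easydel_prompt = MedLlama2Prompter().format_message(
    prompt=user_input, history=[], system_message=our_system_prompt, prefix=None
)
print("\n".join(difflib.unified_diff(
    hf_prompt.splitlines(), easydel_prompt.splitlines(), lineterm=""
)))
```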