Text Generation with Mixtral fails
Closed this issue · 2 comments
clintg6 commented
Calling `model.generate` throws the following error:
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1, 4096) for operand shape (1, 20).
To Reproduce
# Reproduction script for the reported `dynamic_update_slice` TypeError.
# Steps: load a Mixtral checkpoint through EasyDel, build a second, much
# smaller MixtralForCausalLM from a hand-written config, convert that small
# model's state dict to EasyDel params, then call `model.generate` on a
# prompt padded to 4096 tokens.
import copy
import jax
# NOTE(review): AutoEasyDelConfig and get_modules_by_type are not referenced
# anywhere in the visible snippet.
from EasyDel import AutoEasyDelModelForCausalLM, AutoEasyDelConfig, get_modules_by_type
from transformers import AutoTokenizer
# NOTE(review): GenerationConfig is imported but never used below; no
# generation settings are passed to `generate`.
from transformers import GenerationConfig
from transformers import MixtralForCausalLM
# NOTE(review): FlaxMixtralForCausalLM is not referenced in the visible snippet.
from EasyDel import MixtralConfig, FlaxMixtralForCausalLM
from EasyDel.transform.easydel_transform import huggingface_to_easydel
# Local filesystem path to the checkpoint; used for both tokenizer and model.
pretrained_model_name_or_path = "/LLMs/Mixtral-8x7B-v0.1/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
# Load the model and its params in bfloat16, placed on the first CPU device.
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
dtype=jax.numpy.bfloat16,
param_dtype=jax.numpy.bfloat16,
precision=jax.lax.Precision("fastest"),
device=jax.devices('cpu')[0]
)
# Hand-written small config (1 layer, hidden size 256) with
# max_position_embeddings tied to seq_len.
seq_len = 128
config = MixtralConfig(
hidden_size=256,
num_attention_heads=8,
num_hidden_layers=1,
num_key_value_heads=4,
intermediate_size=512,
num_local_experts=8,
max_position_embeddings=seq_len
)
# Built from the config only (no checkpoint path is passed); a deep copy of
# the config is handed over so `config` itself is not shared with the model.
torch_model = MixtralForCausalLM(
config=copy.deepcopy(config)
)
# This REBINDS `params`: the params returned by `from_pretrained` above are
# discarded and replaced by the converted state dict of the small torch model.
# NOTE(review): `model` was loaded from the 8x7B checkpoint path while these
# params come from the 256-hidden, 1-layer config -- presumably mismatched;
# confirm this is intentional.
params = {"params":
huggingface_to_easydel(
torch_model.state_dict(),
embedding_layer_names=["embed_tokens"],
device=jax.devices("cpu")[0]
)
}
# Reuse the EOS token as the padding token so that padding can be applied.
tokenizer.pad_token = tokenizer.eos_token
# The prompt is padded out to 4096 tokens and returned as JAX tensors.
# NOTE(review): 4096 matches the update shape (1, 4096) in the reported error;
# the operand shape (1, 20) is presumably generate's default maximum length,
# since no generation config / max length is passed below -- confirm.
# NOTE(review): 4096 also exceeds seq_len (128) used for
# max_position_embeddings in `config` above.
tokens = tokenizer("Can you tell me who is the current president of the united states?", max_length=4096, padding='max_length', return_tensors='jax')
input_ids, attention_mask = tokens.input_ids, tokens.attention_mask
# The call that raises the TypeError; it receives the rebound `params`.
predict = model.generate(
input_ids,
attention_mask=attention_mask,
params=params)
erfanzar commented
Hello, and I'm sorry: Mixtral models are not working right now (as I said in our previous discussions, they are not JIT-compatible),
and I'm working on fixing them as soon as possible.