erfanzar/EasyDeL

Text Generation with Mixtral fails

Closed this issue · 2 comments

Calling `model.generate` throws this error:

```
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1, 4096) for operand shape (1, 20).
```
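The error itself can be reproduced in isolation with plain JAX. This sketch assumes, as the message suggests, that `generate` preallocates a `(batch, max_length)` token buffer and copies the prompt into it with `lax.dynamic_update_slice`. In that case a default `max_length` of 20 (the default in `transformers`' `GenerationConfig`) would explain the `(1, 20)` operand, and the prompt padded to 4096 tokens would explain the `(1, 4096)` update. `write_prompt` is a hypothetical helper, not an EasyDeL function.

```python
import jax.numpy as jnp
from jax import lax

# Prompt padded to 4096 tokens, as in the reproduction (token values don't matter here).
prompt = jnp.zeros((1, 4096), dtype=jnp.int32)

def write_prompt(buffer_len: int) -> jnp.ndarray:
    """Hypothetical helper: copy the prompt into the front of a (1, buffer_len) buffer."""
    buffer = jnp.zeros((1, buffer_len), dtype=jnp.int32)
    # dynamic_update_slice requires update.shape <= operand.shape on every axis.
    return lax.dynamic_update_slice(buffer, prompt, (0, 0))

# Buffer sized like a default max_length of 20: raises the TypeError from the issue.
try:
    write_prompt(20)
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)

# Buffer with room for the 4096-token prompt plus 32 new tokens: succeeds.
ok = write_prompt(4096 + 32)
print(ok.shape)  # (1, 4128)
```

Under that assumption, the buffer must be at least as long as the padded prompt. A shorter tokenizer `max_length`, or a generation `max_length` larger than 4096, would avoid this particular shape error. It would not by itself address the JIT compatibility problem described in the reply.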

To Reproduce

```python
import copy

import jax
from transformers import AutoTokenizer, MixtralForCausalLM
from EasyDel import AutoEasyDelModelForCausalLM, MixtralConfig
from EasyDel.transform.easydel_transform import huggingface_to_easydel

# Load the full Mixtral-8x7B checkpoint with EasyDeL (bfloat16, on CPU).
pretrained_model_name_or_path = "/LLMs/Mixtral-8x7B-v0.1/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    dtype=jax.numpy.bfloat16,
    param_dtype=jax.numpy.bfloat16,
    precision=jax.lax.Precision("fastest"),
    device=jax.devices("cpu")[0],
)

# Build a tiny, randomly initialized Mixtral in PyTorch.
seq_len = 128
config = MixtralConfig(
    hidden_size=256,
    num_attention_heads=8,
    num_hidden_layers=1,
    num_key_value_heads=4,
    intermediate_size=512,
    num_local_experts=8,
    max_position_embeddings=seq_len,
)
torch_model = MixtralForCausalLM(config=copy.deepcopy(config))

# Convert the tiny model's weights to EasyDeL format.
# NOTE: this replaces the `params` loaded above, so the 8x7B model
# definition is paired with weights from the tiny config.
params = {
    "params": huggingface_to_easydel(
        torch_model.state_dict(),
        embedding_layer_names=["embed_tokens"],
        device=jax.devices("cpu")[0],
    )
}

# Pad the prompt to 4096 tokens, so input_ids has shape (1, 4096).
tokenizer.pad_token = tokenizer.eos_token
tokens = tokenizer(
    "Can you tell me who is the current president of the united states?",
    max_length=4096,
    padding="max_length",
    return_tensors="jax",
)
input_ids, attention_mask = tokens.input_ids, tokens.attention_mask

# No max_length / max_new_tokens is passed, so generate() falls back to its default.
predict = model.generate(
    input_ids,
    attention_mask=attention_mask,
    params=params,
)
```

Hello, and sorry: Mixtral models are not working right now (as I said in our previous discussions, they are not JIT-compatible). I'm working on fixing them as soon as possible.
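The reply does not say what makes the Mixtral code incompatible with JIT. As a general illustration only (this is not EasyDeL's decoding code), one requirement of `jax.jit` is that every array shape is fixed at trace time. A JIT-friendly greedy decoding loop therefore allocates its output buffer once, with length `prompt_len + max_new_tokens` (so `max_new_tokens` must be a static argument), and writes each new token in place with `lax.dynamic_update_slice` inside `lax.fori_loop`. `toy_logits` and `greedy_generate` are hypothetical names; `toy_logits` stands in for a model forward pass.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

VOCAB = 16  # toy vocabulary size (hypothetical)

def toy_logits(tokens: jnp.ndarray, pos: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical model stand-in: puts all mass on (previous token + 1) % VOCAB."""
    prev = lax.dynamic_slice_in_dim(tokens, pos - 1, 1, axis=1)[:, 0]
    return jax.nn.one_hot((prev + 1) % VOCAB, VOCAB)

@partial(jax.jit, static_argnames=("max_new_tokens",))
def greedy_generate(prompt: jnp.ndarray, max_new_tokens: int) -> jnp.ndarray:
    batch, prompt_len = prompt.shape           # shapes are static under jit
    total_len = prompt_len + max_new_tokens    # buffer length fixed at trace time
    buffer = jnp.zeros((batch, total_len), dtype=prompt.dtype)
    buffer = lax.dynamic_update_slice(buffer, prompt, (0, 0))

    def step(pos, buf):
        # Pick the argmax token for position `pos` and write it into the buffer.
        next_tok = jnp.argmax(toy_logits(buf, pos), axis=-1).astype(buf.dtype)
        return lax.dynamic_update_slice(buf, next_tok[:, None], (0, pos))

    return lax.fori_loop(prompt_len, total_len, step, buffer)

out = greedy_generate(jnp.array([[3, 4, 5]], dtype=jnp.int32), max_new_tokens=4)
print(out.tolist())  # [[3, 4, 5, 6, 7, 8, 9]]
```

Because the loop bounds and buffer size depend only on static values, the whole decode compiles to a single XLA program; changing `max_new_tokens` or the prompt length triggers a retrace.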

Thank you, @erfanzar. Please let me know when it is fixed.