Transformers-like API for inference
Froggy111 opened this issue · 19 comments
Is there a versatile transformers-like API (like model.generate()) for this? I tried JAXServer, but it is quite confusing, and I couldn't get flash attention to work. Could you maybe provide some guidance? Thanks very much, I appreciate it.
Also, how can we load quantized models (like GPTQ) onto TPUs?
Hello, and thanks for using EasyDeL.
No, you can't load quantized models into EasyDeL, but 80% of LLMs from HF and PyTorch are supported.
As for using flash attention with the generate function, tell me clearly what you need so I can create an example for you.
I need a flash attention generate function for Mistral-7B and Mixtral 8x7B. I have been trying with Mistral, but it gives an error saying block_q=128 has to be smaller than or equal to seq_len_q=1, and I am unable to find why this occurs. I am running on Google Cloud TPUs. Again, thanks very much for the help, it is really appreciated.
Also, are there other ways to load quantized models in JAX?
This is my code for reference:
import EasyDel, jax, transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="mistralai/Mistral-7B-v0.1",
)
tokenizer.pad_token = tokenizer.eos_token

input_ids = tokenizer(
["hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello "
    ] * 4,
    return_tensors="jax",
    pad_to_multiple_of=128,
    padding=True,
)
print(type(input_ids))
print(input_ids)

attention_mask = input_ids.attention_mask
print(type(attention_mask))
print(attention_mask)

input_ids = input_ids.input_ids
print(type(input_ids))
print(input_ids)

model, params = EasyDel.AutoEasyDelModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="mistralai/Mistral-7B-v0.1",
    device=jax.devices('cpu')[0],
    device_map="auto",
    dtype=jax.numpy.bfloat16,
    param_dtype=jax.numpy.bfloat16,
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=(1, -1, 1, 1),
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    backend="tpu",
    input_shape=(4, 2048),
    config_kwargs={
        "attn_mechanism": "flash",
    },
)
print(type(model))
print(model)
print(type(params))
print(params.keys())
# print(params)

# transformers.GenerationConfig()
generated_ids = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    params={"params": params},
    generation_config=transformers.GenerationConfig(
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        temperature=0.7,
        do_sample=True,
        num_beams=1,
        top_p=100,
        top_k=100,
        repetition_penalty=0.01,
    ),
    max_new_tokens=1024,
)
print(generated_ids)

output = tokenizer.decode(
    generated_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
print(output)
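(As an aside on the final decoding step: a minimal sketch, assuming EasyDeL's generate follows the Flax convention of returning an output object whose token IDs sit in a .sequences field, which is an assumption here and not confirmed in this thread. tokenizer.decode expects a single sequence, so a batch is usually decoded with batch_decode.)

# Hedged sketch: pull the token IDs out of the generate output if it has a
# `.sequences` field (Flax convention), otherwise use the value as-is, and
# decode the whole batch instead of passing the batch to tokenizer.decode.
sequences = getattr(generated_ids, "sequences", generated_ids)
texts = tokenizer.batch_decode(
    sequences,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
for text in texts:
    print(text)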
I know why you are getting the error in the generation process. Give me 5 hours and I'll fix it.
Can you try running that code again?
Hi, it still does not work.
Code:
import jax, transformers
from EasyDeL.lib.python import EasyDel
from jax.sharding import PartitionSpec
from typing import Sequence, Optional

tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="mistralai/Mistral-7B-v0.1",
)
tokenizer.pad_token = tokenizer.eos_token

input_ids = tokenizer(
["hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello "
    ] * 4,
    return_tensors="jax",
    pad_to_multiple_of=128,
    padding=True,
)
print(type(input_ids))
print(input_ids)

attention_mask = input_ids.attention_mask
print(type(attention_mask))
print(attention_mask)

input_ids = input_ids.input_ids
print(type(input_ids))
print(input_ids)


def load_model(
        pretrained_model_name_or_path: str,
        device=jax.devices('cpu')[0],  # Device to be used in order to load the model on (offload device)
        dtype: jax.numpy.dtype = jax.numpy.float32,
        param_dtype: jax.numpy.dtype = jax.numpy.float32,
        precision: Optional[jax.lax.Precision] = jax.lax.Precision("fastest"),
        sharding_axis_dims: Sequence[int] = (1, -1, 1, 1),
        sharding_axis_names: Sequence[str] = ("dp", "fsdp", "tp", "sp"),
        query_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
        generation_query_partition_spec=PartitionSpec(("dp", "fsdp"), "tp", None, None),
        key_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
        value_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
        bias_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), None, None, None),
        attention_partition_spec: PartitionSpec = PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
        use_shard_map: bool = False,
        input_shape: Sequence[int] = (1, 1),
        backend: Optional[str] = None,
        config_kwargs: dict = None,
):
    model, params = EasyDel.AutoEasyDelModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        device=device,
        dtype=dtype,
        param_dtype=param_dtype,
        precision=precision,
        sharding_axis_names=sharding_axis_names,
        sharding_axis_dims=sharding_axis_dims,
        query_partition_spec=query_partition_spec,
        generation_query_partition_spec=generation_query_partition_spec,
        key_partition_spec=key_partition_spec,
        value_partition_spec=value_partition_spec,
        bias_partition_spec=bias_partition_spec,
        attention_partition_spec=attention_partition_spec,
        use_shard_map=use_shard_map,
        input_shape=input_shape,
        backend=backend,
        config_kwargs=config_kwargs,
    )
    return model, params


model, params = load_model(
    pretrained_model_name_or_path="mistralai/Mistral-7B-v0.1",
    dtype=jax.numpy.bfloat16,
    param_dtype=jax.numpy.bfloat16,
    precision=jax.lax.Precision("fastest"),
    input_shape=(4, 2048),
    config_kwargs={
        "attn_mechanism": "flash",
    },
    backend="tpu",
)

# model, params = EasyDel.AutoEasyDelModelForCausalLM.from_pretrained(
#     pretrained_model_name_or_path="mistralai/Mistral-7B-v0.1",
#     device=jax.devices('cpu')[0],
#     device_map="auto",
#     dtype=jax.numpy.bfloat16,
#     param_dtype=jax.numpy.bfloat16,
#     precision=jax.lax.Precision("fastest"),
#     sharding_axis_dims=(1, -1, 1, 1),
#     # sharding_axis_names=("dp", "fsdp", "tp", "sp"),
#     # backend="tpu",
#     input_shape=(4, 2048),
#     config_kwargs={
#         "attn_mechanism": "flash",
#     },
# )

print(type(model))
print(model)
print(type(params))
print(params.keys())
# print(params)

# transformers.GenerationConfig()
generated_ids = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    params={"params": params},
    generation_config=transformers.GenerationConfig(
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        temperature=0.7,
        do_sample=True,
        num_beams=1,
        top_p=100,
        top_k=100,
        repetition_penalty=0.01,
    ),
    # max_new_tokens=1024,
)
print(generated_ids)

output = tokenizer.decode(
    generated_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
print(output)
The error is still the same as previously:
ValueError: block_q=128 should be smaller or equal to q_seq_len=1
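For context, here is a minimal sketch of the constraint behind this message (my own illustration of how a tiled flash-attention kernel behaves, not EasyDeL's actual code): the kernel tiles the query sequence into blocks of block_q rows, so block_q cannot exceed the query length. During prefill the padded prompt length is large and 128 fits, but during token-by-token generation each step runs the model on a single new token, so q_seq_len is 1 and the default block_q=128 trips this check.

# Hypothetical sketch of the check behind the error, not EasyDeL's code: a tiled
# attention kernel cannot use a query block larger than the query sequence itself.
def validate_block_q(q_seq_len: int, block_q: int = 128) -> None:
    if block_q > q_seq_len:
        raise ValueError(
            f"block_q={block_q} should be smaller or equal to q_seq_len={q_seq_len}"
        )

validate_block_q(q_seq_len=2048)  # prefill over the padded prompt: fine
validate_block_q(q_seq_len=1)     # single-token decode step: raises the error above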
Fixed. Can you try again? I can run the code.
Also, change config_kwargs to this:
config_kwargs = {
    "attn_mechanism": "flash",
    "gradient_checkpointing": ""
}
Can you send the code you used? I am still getting the same issue. Thanks.
Have you updated your EasyDeL?
Because the argument you are getting the error from has been removed.
import jax, transformers
import EasyDel
from jax.sharding import PartitionSpec
from typing import Sequence, Optional

dev_len = 6

tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="mistralai/Mistral-7B-v0.1",
)
tokenizer.pad_token = tokenizer.eos_token

input_ids = tokenizer(
["hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello hello "
    ] * dev_len,
    return_tensors="jax",
    max_length=512,
    padding="max_length",
)
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids

model, params = EasyDel.AutoEasyDelModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="mistralai/Mistral-7B-v0.1",
    device=jax.devices('cpu')[0],
    device_map="auto",
    dtype=jax.numpy.bfloat16,
    param_dtype=jax.numpy.bfloat16,
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=(1, -1, 1, 1),
    input_shape=(1, 2048),
    query_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
    generation_query_partition_spec=PartitionSpec(("dp", "fsdp"), None, None, "tp"),
    key_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
    value_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
    bias_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "sp"),
    attention_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
    config_kwargs={
        "attn_mechanism": "flash",
        "gradient_checkpointing": "",
    },
)

generated_ids = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    params={"params": params},
    generation_config=transformers.GenerationConfig(
        max_new_tokens=1024,
        max_length=512,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        temperature=0.7,
        do_sample=True,
        num_beams=1,
        top_p=0.1,
        top_k=2,
        repetition_penalty=1.25,
    ),
)
print(generated_ids)

output = tokenizer.decode(
    generated_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
print(output)
Using jax[tpu]==0.4.22 and the latest commit of EasyDeL on main, on a Google Cloud TPU v4-8 VM:
Running your code, after changing input_shape to (4, 2048) and dev_len to 4, I get the following error:
ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec(('dp', 'fsdp'), None, ('sp',), ('sp',)) has duplicate entries for `sp`
And when I change bias_partition_spec to
bias_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", None),
I get the old error:
ValueError: block_q=128 should be smaller or equal to q_seq_len=1
If I don't change the input shape to 4, I get this error:
ValueError: shard_map applied to the function 'functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:
The mesh given has shape (1, 4, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').
* args[0] of shape float32[1,32,2048,128], where args[0] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)'s parameter 'q', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', 'tp'), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1
* args[1] of shape float32[1,32,2048,128], where args[1] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)'s parameter 'k', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', 'tp'), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1
* args[2] of shape float32[1,32,2048,128], where args[2] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)'s parameter 'v', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', 'tp'), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1
* args[3] of shape bfloat16[1,32,2048,2048], where args[3] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)'s parameter 'ab', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', 'sp'), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1
Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<PjitFunction of <function flash_attention at 0x7f0e9d63b0a0>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=128, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=128, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=128, block_k_major_dq=128, block_k_dq=128, block_q_dq=128), debug=False)' appropriately.
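For reference, here is a minimal sketch of the divisibility rule shard_map is enforcing, using the mesh, shape, and PartitionSpec from the trace above (an illustration, not EasyDeL code): every array axis that a PartitionSpec maps to mesh axes must be evenly divisible by the product of those mesh axis sizes.

import numpy as np

mesh_shape = {"dp": 1, "fsdp": 4, "tp": 1, "sp": 1}   # mesh (1, 4, 1, 1) from the trace
q_shape = (1, 32, 2048, 128)                          # (batch, heads, q_seq_len, head_dim)
spec = (("dp", "fsdp"), None, "sp", "tp")             # in_specs[0] from the trace

for axis_size, names in zip(q_shape, spec):
    if names is None:
        continue  # unsharded axis, nothing to check
    names = names if isinstance(names, tuple) else (names,)
    mesh_size = int(np.prod([mesh_shape[n] for n in names]))
    ok = axis_size % mesh_size == 0
    print(f"axis {axis_size} over mesh axes {names} (size {mesh_size}):",
          "ok" if ok else "NOT divisible")

This prints "NOT divisible" for the batch axis: a batch of 1 cannot be split across ('dp', 'fsdp') of total size 4, which matches the error; a batch of 4 (input_shape=(4, 2048) and dev_len=4) makes every mapped axis divisible.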
I guess the issue is now fixed; the problem was in detecting the generation process in order to change the block sizes,
from
is_generating = query_states.shape[1] == 1
query_sequence_partition = self.generation_query_partition_spec if is_generating else self.query_partition_spec
bias_partition_spec = self.generation_bias_partition_spec if is_generating else self.bias_partition_spec
block_q = 1 if is_generating else self.block_q
block_q_major_dkv = 1 if is_generating else self.block_q_major_dkv
block_q_dkv = 1 if is_generating else self.block_q_dkv
block_q_dq = 1 if is_generating else self.block_q_dq
to
is_generating = query_states.shape[2] == 1
query_sequence_partition = self.generation_query_partition_spec if is_generating else self.query_partition_spec
bias_partition_spec = self.generation_bias_partition_spec if is_generating else self.bias_partition_spec
block_q = 1 if is_generating else self.block_q
block_q_major_dkv = 1 if is_generating else self.block_q_major_dkv
block_q_dkv = 1 if is_generating else self.block_q_dkv
block_q_dq = 1 if is_generating else self.block_q_dq
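A hedged illustration of why the check moved from shape[1] to shape[2] (my reading of the trace above, not the library's code): in the flash-attention path the query tensor is laid out as (batch, num_heads, q_seq_len, head_dim), e.g. float32[1, 32, 2048, 128] in the trace, so a single-token decode step shows up on axis 2; axis 1 is the head count (32 here), so the old check never detected generation and block_q stayed at 128.

# Shapes assumed from the flash-attention trace above: (batch, num_heads, q_seq_len, head_dim).
prefill_q_shape = (1, 32, 2048, 128)  # prompt/prefill pass
decode_q_shape = (1, 32, 1, 128)      # one new token per generate step

for shape in (prefill_q_shape, decode_q_shape):
    is_generating = shape[2] == 1          # corrected check: look at the sequence axis
    block_q = 1 if is_generating else 128  # shrink the query block while decoding
    print(shape, "generating" if is_generating else "prefill", "block_q =", block_q)

# With the old check (shape[1] == 1) the head axis was tested instead, so a decode
# step was never detected and the 128-row query block caused the earlier error.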
Now, I am getting this issue:
RuntimeError: Internal TPU kernel compiler error: Not implemented: Non-trivial layouts unsupported
The MLIR operation involved:
%130 = "tpu.repeat"(%129) {dimension = 1 : i32, in_layout = [#tpu.vpad<"32,{0,0},(1,128)">], out_layout = [#tpu.vpad<"32,{0,0},(1,128)">], times = 1 : i32} : (vector<1x128xf32>) -> vector<1x128xf32>
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
This happens with jax[tpu]==0.4.22 and jax[tpu]==0.4.23.
Versions 0.4.24 and above do not work, as it seems some APIs were removed.
;\ Which TPU version are you using?
Can you upgrade JAX to 0.4.25?
With jax-0.4.25, jaxlib-0.4.25, and libtpu-nightly-0.1.dev20240224:
AttributeError: 'Config' object has no attribute 'define_bool_state'
Is the error coming from EasyDeL or FJFormer?
I never created a value named define_bool_state, or even used one.