GPT2 (150M model) support on TPU v2-8. Example script goes out of memory
jchauhan opened this issue · 1 comment
jchauhan commented
Describe the bug
Out of memory with a small GPT2 model (~150M params)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 9.35G of 7.48G hbm. Exceeded hbm capacity by 1.86G.
Total hbm usage >= 9.86G:
reserved 530.00M
program 9.35G
arguments 0B
Output size 0B; shares 0B with arguments.
Program hbm requirement 9.35G:
HLO temp 9.35G (3.1% utilization: Unpadded (294.48M) Padded (9.35G), 0.0% fragmentation (30.0K))
Largest program allocations in hbm:
1. Size: 9.20G
Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/erf_inv" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
Shape: f32[19298688,2]{1,0:T(8,128)}
Unpadded size: 147.24M
Extra memory due to padding: 9.06G (64.0x expansion)
XLA label: copy.1 = copy(fusion.7)
Allocation type: HLO temp
==========================
2. Size: 147.38M
Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/erf_inv" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
Shape: f32[768,50257]{1,0:T(8,128)}
Unpadded size: 147.24M
Extra memory due to padding: 141.0K (1.0x expansion)
XLA label: reshape.84 = reshape(copy.1)
Allocation type: HLO temp
==========================
3. Size: 1.0K
Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/slice[start_indices=(1,) limit_indices=(2,) strides=(1,)]" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
Shape: (u32[1]{0:T(256)}, u32[1]{0:T(256)})
Unpadded size: 1.0K
XLA label: fusion.53 = fusion(Arg_0.1), kind=kLoop, calls=fused_computation.52
Allocation type: HLO temp
==========================
4. Size: 1.0K
Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/add" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
Shape: (u32[9649344]{0:T(1024)}, u32[9649344]{0:T(1024)})
Unpadded size: 1.0K
XLA label: fusion.50 = fusion(xor.27, bitcast.1, bitcast), kind=kLoop, calls=fused_computation.50
Allocation type: HLO temp
==========================
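The largest allocation is almost entirely tile padding rather than actual data: the f32[19298688,2] array is laid out with T(8,128) tiling, so its trailing dimension of 2 gets padded up to 128. A back-of-the-envelope check (my own sketch, not part of the log) reproduces the 64x expansion reported above:

rows, cols = 19_298_688, 2                                         # shape from allocation 1
unpadded = rows * cols * 4                                         # f32 bytes: ~147.24M, matches "Unpadded size"
padded = ((rows + 7) // 8 * 8) * ((cols + 127) // 128 * 128) * 4   # dims padded to (8, 128) tiles: ~9.2G
print(padded / unpadded)                                           # -> 64.0, the reported expansion factor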
To Reproduce
Install deps
pip install EasyDeL@git+https://github.com/erfanzar/EasyDeL.git@main
pip install jax[tpu]==0.4.22 -f https://storage.googleapis.com/libtpu-releases/index.html
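Optionally, a quick sanity check (not in the original report) that the TPU backend is visible before running the script:
python -c "import jax; print(jax.devices())"  # a v2-8 should list 8 TpuDevice entries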
Use the following code
from EasyDel.modules import AutoEasyDelModelForCausalLM
from EasyDel.serve import JAXServer
from transformers import AutoTokenizer
import jax

model_huggingface_repo_id = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_huggingface_repo_id, trust_remote_code=True)

# Load the model and its parameters from the Hugging Face checkpoint
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    model_huggingface_repo_id,
    jax.devices("cpu")[0],
    jax.numpy.float16,
    jax.numpy.float16,
    jax.lax.Precision("fastest"),
    (1, -1, 1, 1),
    device_map="auto"
)

# Re-initialize the parameters with a fresh PRNG key
params = model.init_weights(jax.random.PRNGKey(0), (1, 1))

server = JAXServer.from_parameters(
    model=model,
    config_model=model.config,
    tokenizer=tokenizer,
    params=params,
    add_params_field=True
)

# Stream the generated text
response_printed = 0
for response, tokens_used in server.process(
        "String To The Model", stream=True
):
    print(response[response_printed:], end="")
    response_printed = len(response)
erfanzar commented
Use this piece of code:
from EasyDel import JAXServer, JAXServerConfig
import jax
from jax import numpy as jnp, lax

scan_mlp_chunk_size = 128

# Server/generation settings
server_config = JAXServerConfig(
    max_sequence_length=1024,
    max_compile_tokens=scan_mlp_chunk_size,
    max_new_tokens=scan_mlp_chunk_size * 10,
    dtype="bf16"
)

# Build the server directly from the torch checkpoint
server = JAXServer.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path="gpt2",
    device=jax.devices("cpu")[0],
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    precision=lax.Precision("fastest"),
    sharding_axis_dims=(1, 1, 1, -1),
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism="normal",
        scan_mlp_chunk_size=scan_mlp_chunk_size,
        use_scan_mlp=True,
        scan_ring_attention=True,
        block_k=128,
        block_q=128,
        use_sharded_kv_caching=False
    )
)

# Stream the generated text
prompt = "string to model"
seq_len = 0
for f, o in server.sample(prompt):
    print(f[seq_len:], end="")
    seq_len = len(f)
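Compared with the script in the issue, this configuration runs everything in bfloat16, sets an explicit (1, 1, 1, -1) sharding layout with named axes, ties input_shape to max_sequence_length, and chunks the MLP with scan, which should keep HBM usage on a v2-8 within budget.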