erfanzar/EasyDeL

GPT2 (~150M-param model) support on TPU v2-8. Example script goes out of memory

jchauhan opened this issue · 1 comment

Describe the bug
Out of memory for a small GPT-2 model with ~150M params.

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 9.35G of 7.48G hbm. Exceeded hbm capacity by 1.86G.

Total hbm usage >= 9.86G:
    reserved        530.00M 
    program           9.35G 
    arguments            0B 

Output size 0B; shares 0B with arguments.

Program hbm requirement 9.35G:
    HLO temp          9.35G (3.1% utilization: Unpadded (294.48M) Padded (9.35G), 0.0% fragmentation (30.0K))

  Largest program allocations in hbm:

  1. Size: 9.20G
     Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/erf_inv" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
     Shape: f32[19298688,2]{1,0:T(8,128)}
     Unpadded size: 147.24M
     Extra memory due to padding: 9.06G (64.0x expansion)
     XLA label: copy.1 = copy(fusion.7)
     Allocation type: HLO temp
     ==========================

  2. Size: 147.38M
     Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/erf_inv" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
     Shape: f32[768,50257]{1,0:T(8,128)}
     Unpadded size: 147.24M
     Extra memory due to padding: 141.0K (1.0x expansion)
     XLA label: reshape.84 = reshape(copy.1)
     Allocation type: HLO temp
     ==========================

  3. Size: 1.0K
     Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/slice[start_indices=(1,) limit_indices=(2,) strides=(1,)]" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
     Shape: (u32[1]{0:T(256)}, u32[1]{0:T(256)})
     Unpadded size: 1.0K
     XLA label: fusion.53 = fusion(Arg_0.1), kind=kLoop, calls=fused_computation.52
     Allocation type: HLO temp
     ==========================

  4. Size: 1.0K
     Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/add" source_file="/home/neo/research/easydel/.venv/lib/python3.10/site-packages/flax/core/scope.py" source_line=979
     Shape: (u32[9649344]{0:T(1024)}, u32[9649344]{0:T(1024)})
     Unpadded size: 1.0K
     XLA label: fusion.50 = fusion(xor.27, bitcast.1, bitcast), kind=kLoop, calls=fused_computation.50
     Allocation type: HLO temp
     ==========================
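For what it's worth, the 64x expansion in allocation 1 is exactly what the TPU tile layout predicts: T(8,128) pads the minor dimension of the f32[19298688,2] buffer from 2 up to 128. A quick back-of-the-envelope check (my own arithmetic, not part of the log):

rows, cols = 19_298_688, 2   # f32[19298688,2] from the normal initializer
unpadded = rows * cols * 4   # float32 bytes, ~147.24 MiB ("Unpadded size" above)
padded = rows * 128 * 4      # minor dim padded 2 -> 128 by the T(8,128) tiling
print(f"{unpadded / 2**20:.2f} MiB -> {padded / 2**30:.2f} GiB ({padded // unpadded}x)")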

To Reproduce

Install deps

 pip install EasyDeL@git+https://github.com/erfanzar/EasyDeL.git@main
 pip install jax[tpu]==0.4.22 -f https://storage.googleapis.com/libtpu-releases/index.html
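A quick sanity check (my addition, not part of the original repro) that the TPU runtime came up after installing:

import jax
print(jax.devices())  # should list eight TpuDevice entries on a v2-8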

Use the following code

from EasyDel.modules import AutoEasyDelModelForCausalLM
from EasyDel.serve import JAXServer
from transformers import AutoTokenizer
import jax

model_huggingface_repo_id = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_huggingface_repo_id, trust_remote_code=True)
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    model_huggingface_repo_id,
    jax.devices("cpu")[0],         # device
    jax.numpy.float16,             # dtype
    jax.numpy.float16,             # param_dtype
    jax.lax.Precision("fastest"),  # precision
    (1, -1, 1, 1),                 # sharding_axis_dims
    device_map="auto"
)

# Note: this re-initializes the weights with the random normal initializer,
# discarding the converted `params` above; this is the jit(_normal)/erf_inv
# call that shows up in the OOM trace.
params = model.init_weights(jax.random.PRNGKey(0), (1, 1))
server = JAXServer.from_parameters(
    model=model,
    config_model=model.config,
    tokenizer=tokenizer,
    params=params,
    add_params_field=True
)

response_printed = 0
for response, tokens_used in server.process(
        "String To The Model", stream=True
):
    print(response[response_printed:], end="")
    response_printed = len(response)
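One workaround sketch, not a confirmed fix for this repo: since the OOM comes from jit(_normal)/erf_inv during init_weights, forcing that init onto the host CPU should avoid materializing the padded buffer in TPU HBM. jax.default_device is standard JAX; whether EasyDeL then reshards the resulting params correctly is an assumption on my part.

import jax

# Sketch: run the random init under a CPU default device so the
# erf_inv buffer is allocated in host RAM, not TPU HBM.
with jax.default_device(jax.devices("cpu")[0]):
    params = model.init_weights(jax.random.PRNGKey(0), (1, 1))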

Use this piece of code instead:

from EasyDel import JAXServer, JAXServerConfig
import jax

scan_mlp_chunk_size = 128

server_config = JAXServerConfig(
    max_sequence_length=1024,
    max_compile_tokens=scan_mlp_chunk_size,
    max_new_tokens=scan_mlp_chunk_size * 10,
    dtype="bf16"
)

server = JAXServer.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path="gpt2",
    device=jax.devices('cpu')[0],
    dtype=jax.numpy.bfloat16,
    param_dtype=jax.numpy.bfloat16,
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=(1, 1, 1, -1),
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism="normal",
        scan_mlp_chunk_size=scan_mlp_chunk_size,
        use_scan_mlp=True,
        scan_ring_attention=True,
        block_k=128,
        block_q=128,
        use_sharded_kv_caching=False
    )
)

prompt = "string to model"
seq_len = 0
for response, tokens_used in server.sample(prompt):
    print(response[seq_len:], end="")
    seq_len = len(response)
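As far as I can tell, the memory savings in this snippet come from three things: dtype="bf16" halves the footprint relative to f32, use_scan_mlp=True with scan_mlp_chunk_size=128 processes the MLP in chunks instead of materializing the full activations, and loading via from_torch_pretrained on the CPU device keeps the torch-to-JAX weight conversion out of TPU HBM.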