erfanzar/EasyDeL

Step time increasing as training progresses


In one of the longer training runs currently running on a TPU v3-8, I noticed the training ETA kept getting later and later.
The step-time wandb log (picture below) shows the same thing: the higher the step number, the longer the time per step.

[image: wandb step-time chart]

Any ideas what could be the cause? I looked a bit at DataLoader's prefetch_factor, but it's only available when using multiprocessing (num_workers > 0).
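For reference, this is the kind of setting I was looking at (a minimal sketch of a torch DataLoader, not EasyDeL's actual data pipeline; `collate_fn` here is a hypothetical placeholder):

```python
from torch.utils.data import DataLoader

# prefetch_factor is only honoured when worker processes exist to do the
# prefetching, i.e. when num_workers > 0.
loader = DataLoader(
    dataset_train,
    batch_size=2,
    num_workers=2,          # must be > 0 for prefetch_factor to take effect
    prefetch_factor=4,      # each worker keeps 4 batches ready
    collate_fn=collate_fn,  # hypothetical collator
)
```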

PS: Thanks for creating EasyDeL - it's amazing what you've created!

Hello,
thank you :).

About the issue: can I have access to your training arguments (batch size, etc.)?
I would also like a screenshot of your buffer size chart in WandB.

For this specific run the config was as follows.
(Looking at it, maybe it is due to the shuffle() operation?) Unfortunately TRC access ended today; I will retry without shuffle after renewing.

import datasets
import jax.numpy as jnp
from transformers import AutoTokenizer
import EasyDel
from EasyDel import (
    AutoEasyDelModelForCausalLM,
    TrainArguments,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
)

# model_id and tokenizer_id are defined elsewhere in the script
dataset_train = datasets.load_dataset("yhavinga/nedd_x_chat_instruct_tokenized_zephyr_7b_alpha_padright__b1_", split="train").shuffle()
context_length = len(dataset_train[0]['input_ids'])  # 1024
print(f"Using context length of {context_length}")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
model, params = AutoEasyDelModelForCausalLM.from_pretrained(model_id)
config = model.config
config.freq_max_position_embeddings = config.max_position_embeddings  # 32768
config.max_position_embeddings = context_length
config.c_max_position_embeddings = config.max_position_embeddings

max_length = config.max_position_embeddings

train_args = TrainArguments(
    model_class=EasyDel.FlaxMistralForCausalLM,
    configs_to_init_model_class={
        'config': config,
        'dtype': jnp.bfloat16,
        'param_dtype': jnp.bfloat16,
        'input_shape': (1, 1)
    },
    custom_rule=config.get_partition_rules(True),
    model_name='TowerDutchTest',
    num_train_epochs=1,
    learning_rate=1e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.WARM_UP_LINEAR,
    warmup_steps=500,
    weight_decay=0.1,
    total_batch_size=2,
    max_steps=48000,
    save_steps=8000,
    do_train=True,
    do_eval=False,
    backend='tpu',
    max_length=max_length,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),
    use_pjit_attention_force=False,
    use_flash_attention=False,
    gradient_accumulation_steps=8,
    remove_ckpt_after_load=True,
    ids_to_pop_from_dataset=['token_type_ids'],
    loss_remat="",
    dtype=jnp.bfloat16
)

[image: wandb charts for this run]
Wandb link: https://wandb.ai/yepster/EasyDeL-TowerDutchTest/runs/6cweckwk/workspace

To check whether shuffling a dataset could cause these increasing times, I plotted the per-sample reading time of a shuffled HF dataset. This is reading with an iterator, but indexed lookup looked almost the same.

[image: per-sample reading time of the shuffled HF dataset]
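Roughly how I measured it (a sketch with illustrative names and an arbitrary sample count, not the exact script):

```python
import time
import datasets

ds = datasets.load_dataset(
    "yhavinga/nedd_x_chat_instruct_tokenized_zephyr_7b_alpha_padright__b1_",
    split="train",
).shuffle()

# Per-sample read time via the iterator
iter_times = []
prev = time.perf_counter()
for i, _ in enumerate(ds):
    now = time.perf_counter()
    iter_times.append(now - prev)
    prev = now
    if i >= 10_000:
        break

# Per-sample read time via indexed lookup (looked almost the same)
index_times = []
for i in range(10_000):
    start = time.perf_counter()
    _ = ds[i]
    index_times.append(time.perf_counter() - start)
```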

I don't see any issue with your configuration. I guess it might be related to some background process you are running, or to your kernel. I have pre-trained more than 10 models and this is the first time I'm seeing the buffer size increase over time.


Do you have any suggestions for me to fix your issue? I guess you can try disabling shuffle in TrainingArguments and see if the buffer size is still increasing.

This issue is being closed because no response has been given.

Got renewed TRC access and looked into it a bit more. I tried a few things:

  1. Replace jnp with np in the data collator: a synthetic test showed per-loop time much faster than jnp, but unfortunately it didn't seem to solve the increasing step time.
  2. Replace jnp with np in the train loop's perplexity calculation.
  3. Comment out the mean loss and mean accuracy stats - the arrays they operate on grow with the step count (see the sketch after this list).
  4. Set track_mem to off.
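To illustrate change 3 (a standalone sketch of the pattern, not the trainer's exact code): if every step's loss is appended to a list and the mean is recomputed over the whole thing, the per-step cost grows linearly with the step count, whereas a running mean stays constant:

```python
import jax.numpy as jnp

# Growing pattern: the array being reduced gets one element longer every step,
# so building it and reducing it gets slower as training progresses.
losses = []
def log_mean_slow(step_loss):
    losses.append(step_loss)
    return jnp.mean(jnp.asarray(losses))   # O(step) work at every step

# Constant-cost alternative: keep only the running mean and a counter.
mean_loss, n_steps = 0.0, 0
def log_mean_fast(step_loss):
    global mean_loss, n_steps
    n_steps += 1
    mean_loss += (float(step_loss) - mean_loss) / n_steps   # O(1) per step
    return mean_loss
```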

Result below:
green is the unchanged easydel 0.0.42 tag
orange is with only change 1, on easydel from a few days ago
red is with all four changes above, on easydel from a few days ago

[image: wandb step-time comparison of the three runs (green / orange / red)]

Hello, and thank you for letting me know where the issue is and how I can make it better.

I have tried using numpy instead of jax.numpy, but that causes issues and a lot of errors in multi-host training.

If you want, you can create a pull request from HEAD and make the changes; otherwise you can tell me exactly which parts you want modified in order to make it more efficient.

I will do some more experimentation; I suspect only changing the mean calculations is sufficient. When finished I'll create a PR.
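Roughly the shape of the change I have in mind (a sketch of a general JAX pattern under my assumptions, not the final PR; `state`, `batch`, and `train_iter` are placeholders): keep the loss math in jnp inside the jitted step, and only pull the already-reduced scalar to the host as a Python float for the running statistics, so nothing numpy-specific leaks into the multi-host path:

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(state, batch):
    # ... forward/backward exactly as before; the loss stays a jnp scalar on
    # device, so sharded / multi-host arrays are handled by JAX as usual.
    loss = jnp.float32(0.0)   # placeholder for the real loss
    return state, loss

# Host-side bookkeeping uses plain Python floats, so the accumulators never
# become device arrays that grow with the step count.
mean_loss, n_steps = 0.0, 0
for step, batch in enumerate(train_iter):   # train_iter is illustrative
    state, loss = train_step(state, batch)
    n_steps += 1
    mean_loss += (float(jax.device_get(loss)) - mean_loss) / n_steps
```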