erfanzar/EasyDeL

Import error with EasyDeL libraries in examples/flash_attention_training_example.py

Closed this issue · 6 comments

Describe the bug
The EasyDeL libraries fail to import.

To Reproduce

# %%
!pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -U
!pip install easydel==0.0.69 wandb sentencepiece zstandard -q
HF_TOKEN = ""
!python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<HF_TOKEN>')"
WANDB_TOKEN = ""
!wandb login <WANDB_TOKEN> 
!apt-get update && apt-get upgrade -y -q && apt-get install golang -y -q

# %%
from easydel import (
    AutoEasyDeLModelForCausalLM,
    TrainArguments,
    CausalLanguageModelTrainer,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    get_modules_by_type
)
from datasets import load_dataset
from huggingface_hub import HfApi
from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp
import jax
from fjformer import GenerateRNG

rng_g = GenerateRNG()
api = HfApi()


def launch():
    pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"  # changeable
    device_num = len(jax.devices())
    model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device=jax.devices('cpu')[0],
        input_shape=(device_num, 1),
        device_map="auto"
    )

    config = model.config

    model_parameters = FrozenDict({"params": params})

    # Switch the attention mechanism to the flash kernel with 512-token blocks.
    config.add_basic_configurations(
        attn_mechanism="flash",
        block_b=1,
        block_q=512,
        block_k=512,
        block_k_major=512
    )

    # Keep the rotary frequency range at the original context length,
    # then extend the usable context window to 4096 tokens.
    original_max_position_embeddings = config.max_position_embeddings
    config.freq_max_position_embeddings = config.max_position_embeddings
    config.max_position_embeddings = 4096
    config.c_max_position_embeddings = config.max_position_embeddings

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True
    )

    max_sequence_length = config.max_position_embeddings

    configs_to_initialize_model_class = {
        'config': config,
        'dtype': jnp.bfloat16,
        'param_dtype': jnp.bfloat16,
        'input_shape': (device_num, config.block_q)
    }

    # Ensure a pad token exists so padding="max_length" works below.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(
        "yahma/alpaca-cleaned",
        split="train",
    )
    # Tokenize the "prompt" field of each example, padding to the full sequence length.
    tokenization_process = lambda data_chunk: tokenizer(
        data_chunk["prompt"],
        add_special_tokens=False,
        max_length=max_sequence_length,
        padding="max_length"
    )
    dataset = dataset.map(
        tokenization_process,
        num_proc=18,
        remove_columns=dataset.column_names
    )
    train_args = TrainArguments(
        model_class=get_modules_by_type(config.model_type)[1],
        configs_to_initialize_model_class=configs_to_initialize_model_class,
        custom_rule=config.get_partition_rules(True),
        model_name="FlashAttentionTest",
        num_train_epochs=1,
        learning_rate=8e-5,
        learning_rate_end=3e-05,
        warmup_steps=200,
        optimizer=EasyDeLOptimizers.ADAMW,
        scheduler=EasyDeLSchedulers.LINEAR,
        weight_decay=0.02,
        total_batch_size=4,
        max_sequence_length=max_sequence_length,
        gradient_checkpointing=EasyDeLGradientCheckPointers.EVERYTHING_SAVEABLE,
        sharding_array=(1, -1, 1, 1),
        gradient_accumulation_steps=2,
        dtype=jnp.bfloat16,
        init_input_shape=(device_num, config.block_q),
        step_start_point=0,
        training_time="7H"
    )

    trainer = CausalLanguageModelTrainer(
        train_args,
        dataset.shuffle(),  # a single shuffle already permutes the dataset
        checkpoint_path=None
    )

    output = trainer.train(
        model_parameters=model_parameters,
        state=None
    )


if __name__ == "__main__":
    launch()
/usr/local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
WARNING: Logging before InitGoogle() is written to STDERR
E0000 00:00:1720125449.236290      13 common_lib.cc:800] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:481
/usr/local/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()

Apparently this warning is not fatal, but after it nothing else is printed, which suggests the script never gets past importing the libraries, loading the data, etc.

Hello, and thanks for using EasyDeL.
Are you using Kaggle?

Yes. Does it work on Colab? I'll use that after the testing phase (TPU v3-8 on Kaggle).

As far as I can tell, it's a problem with the new Kaggle environment. Can you pin to any environment version before 2024-03-20?
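
For reference, a quick way to confirm which environment build a notebook is actually running is to print the JAX package versions and the devices JAX can see. This is only a diagnostic sketch, not part of the original report:

import jax
import jaxlib

# Print the installed JAX/jaxlib versions and the visible accelerators,
# to confirm which Kaggle environment build is active.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("devices:", jax.devices())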

Are there any plans to support the latest Kaggle and Colab environments?
A second option would be to install all the corresponding libraries from the older environments (see the sketch below), but I haven't found how to do that after a quick search.
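
One way to approximate the older environment, sketched below under the assumption that you can still open a notebook pinned to it, is to snapshot its package versions and replay them in the new environment (requirements_old.txt is only a placeholder name):

# %%
# In a notebook pinned to the older Kaggle environment: record the exact package versions.
!pip freeze > requirements_old.txt

# %%
# In the current (newer) environment: reinstall those pinned versions.
!pip install -r requirements_old.txt -q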

I'll try to fix that before 0.0.71.

@s-smits The import issue is now fixed.
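
For anyone landing here later, a minimal way to verify the fix, assuming it ships in a release newer than 0.0.69 (e.g. the 0.0.71 mentioned above), is to upgrade and re-run the imports from the reproduction script:

# %%
!pip install -U easydel -q

# %%
# Re-run the imports that previously hung/failed.
from easydel import (
    AutoEasyDeLModelForCausalLM,
    TrainArguments,
    CausalLanguageModelTrainer,
)
print("EasyDeL imports resolved successfully")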