Import error EasyDeL libraries examples/flash_attention_training_example.py
Closed this issue · 6 comments
s-smits commented
Describe the bug
Does not import libraries.
To Reproduce
# %%
# Notebook setup cell: installs the Python dependencies, stores the Hugging Face
# and Weights & Biases credentials, and installs golang via apt.
# JAX is pinned to 0.4.28 with the TPU extra; -f points pip at the libtpu
# release index.
!pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -U
# EasyDeL is pinned to 0.0.69, the version this report was filed against.
!pip install easydel==0.0.69 wandb sentencepiece zstandard -q
# NOTE(review): HF_TOKEN is assigned here, but the command below passes the
# literal text '<HF_TOKEN>' -- presumably a redaction placeholder; substitute a
# real token before running.
HF_TOKEN = ""
!python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<HF_TOKEN>')"
# NOTE(review): same as above -- '<WANDB_TOKEN>' is passed literally, not the
# WANDB_TOKEN variable; presumably a placeholder.
WANDB_TOKEN = ""
!wandb login <WANDB_TOKEN>
# Refresh and upgrade the system packages, then install golang.
!apt-get update && apt-get upgrade -y -q && apt-get install golang -y -q
# %%
from easydel import (
AutoEasyDeLModelForCausalLM,
TrainArguments,
CausalLanguageModelTrainer,
EasyDeLOptimizers,
EasyDeLSchedulers,
EasyDeLGradientCheckPointers,
get_modules_by_type
)
from datasets import load_dataset
from huggingface_hub import HfApi
from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp
import jax
from fjformer import GenerateRNG
# Module-level helper instances created at import time. Neither name is
# referenced by launch() in the visible script.
rng_g = GenerateRNG()
api = HfApi()
def launch():
    """Fine-tune a causal LM with EasyDeL using the "flash" attention mechanism.

    Steps, in order:

    1. Load the tokenizer and tokenize the training split.
    2. Load the pretrained checkpoint on the CPU device.
    3. Switch the model config to flash attention and a 4096-token context.
    4. Build ``TrainArguments`` and run ``CausalLanguageModelTrainer.train``.

    Tokenization is deliberately done *before* the first ``jax.devices()``
    call: ``Dataset.map(num_proc=...)`` forks worker processes, and JAX emits
    a ``RuntimeWarning`` that ``os.fork()`` after it has started "will likely
    lead to a deadlock".

    Returns:
        Whatever ``trainer.train`` returns.
    """
    pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"  # changeable
    # Context length used for both tokenization and the model config.
    max_sequence_length = 4096

    # --- Data preparation (no JAX device access yet, see docstring) ---------
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        # Padding to max_length needs a pad token; reuse EOS when none is set.
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(
        "yahma/alpaca-cleaned",
        split="train",
    )

    def tokenization_process(data_chunk):
        """Tokenize one row into a fixed-length, padded/truncated sample."""
        # NOTE(review): this reads a "prompt" column; confirm the dataset
        # actually exposes one (alpaca-style sets commonly use
        # instruction/input/output instead).
        return tokenizer(
            data_chunk["prompt"],
            add_special_tokens=False,
            max_length=max_sequence_length,
            padding="max_length",
            # Without truncation, rows longer than max_length would keep
            # their full length and break the fixed sequence shape.
            truncation=True,
        )

    dataset = dataset.map(
        tokenization_process,
        num_proc=18,
        remove_columns=dataset.column_names
    )

    # --- Model loading -------------------------------------------------------
    device_num = len(jax.devices())
    model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device=jax.devices('cpu')[0],
        input_shape=(device_num, 1),
        device_map="auto"
    )
    config = model.config
    model_parameters = FrozenDict({"params": params})

    config.add_basic_configurations(
        attn_mechanism="flash",
        block_b=1,
        block_q=512,
        block_k=512,
        block_k_major=512
    )

    # Keep the checkpoint's original limit in freq_max_position_embeddings,
    # then raise the working context length.
    config.freq_max_position_embeddings = config.max_position_embeddings
    config.max_position_embeddings = max_sequence_length
    config.c_max_position_embeddings = config.max_position_embeddings

    configs_to_initialize_model_class = {
        'config': config,
        'dtype': jnp.bfloat16,
        'param_dtype': jnp.bfloat16,
        'input_shape': (device_num, config.block_q)
    }

    # --- Training ------------------------------------------------------------
    train_args = TrainArguments(
        model_class=get_modules_by_type(config.model_type)[1],
        configs_to_initialize_model_class=configs_to_initialize_model_class,
        custom_rule=config.get_partition_rules(True),
        model_name="FlashAttentionTest",
        num_train_epochs=1,
        learning_rate=8e-5,
        learning_rate_end=3e-05,
        warmup_steps=200,
        optimizer=EasyDeLOptimizers.ADAMW,
        scheduler=EasyDeLSchedulers.LINEAR,
        weight_decay=0.02,
        total_batch_size=4,
        max_sequence_length=max_sequence_length,
        gradient_checkpointing=EasyDeLGradientCheckPointers.EVERYTHING_SAVEABLE,
        sharding_array=(1, -1, 1, 1),
        gradient_accumulation_steps=2,
        dtype=jnp.bfloat16,
        # NOTE(review): leading dimension is hard-coded to 8 while the model
        # init above uses device_num; confirm this is intended on hosts with
        # a different device count.
        init_input_shape=(8, config.block_q),
        step_start_point=0,
        training_time="7H"
    )

    trainer = CausalLanguageModelTrainer(
        train_args,
        # A single unseeded shuffle already yields a random permutation.
        dataset.shuffle(),
        checkpoint_path=None
    )

    output = trainer.train(
        model_parameters=model_parameters,
        state=None
    )
    return output


if __name__ == "__main__":
    launch()
/usr/local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
WARNING: Logging before InitGoogle() is written to STDERR
E0000 00:00:1720125449.236290 13 common_lib.cc:800] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:481
/usr/local/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
This warning does not appear to be fatal, but after it nothing more is printed, which suggests the script never gets as far as importing the libraries, loading the data, etc.
erfanzar commented
Hello, and thanks for using EasyDeL.
Are you using Kaggle?
s-smits commented
Yes, does it work on Colab? I'll use that after the testing phase (TPU-v3x8 on Kaggle).
erfanzar commented
As far as I can tell, this is a problem with the new Kaggle environment. Can you pin the notebook to an environment version from before 2024-03-20?
s-smits commented
Are there any plans to implement support for the latest Kaggle and Colab environments?
A second option would be to install all the corresponding libraries from the older environments, but I have not found out how to do that after a quick search.
erfanzar commented
I'll try to fix that before 0.0.71.