erfanzar/EasyDeL

TPU-v3 Kaggle not working after update

s-smits opened this issue · 5 comments

Describe the bug
It would be great to keep support for TPU-v3's on Kaggle.
After 0.0.66 I get this error:


ImportError Traceback (most recent call last)
File /usr/local/lib/python3.10/site-packages/easydel/utils/lazy_import.py:58, in _LazyModule._get_module(self, module_name)
57 try:
---> 58 return importlib.import_module("." + module_name, self.name)
59 except Exception as e:

File /usr/local/lib/python3.10/importlib/__init__.py:126, in import_module(name, package)
125 level += 1
--> 126 return _bootstrap._gcd_import(name[level:], package, level)

File <frozen importlib._bootstrap>:1050, in _gcd_import(name, package, level)

File <frozen importlib._bootstrap>:1027, in _find_and_load(name, import_)

File <frozen importlib._bootstrap>:992, in _find_and_load_unlocked(name, import_)

File <frozen importlib._bootstrap>:241, in _call_with_frames_removed(f, *args, **kwds)

File <frozen importlib._bootstrap>:1050, in _gcd_import(name, package, level)

File <frozen importlib._bootstrap>:1027, in _find_and_load(name, import_)

File <frozen importlib._bootstrap>:1006, in _find_and_load_unlocked(name, import_)

File <frozen importlib._bootstrap>:688, in _load_unlocked(spec)

File <frozen importlib._bootstrap_external>:883, in exec_module(self, module)

File <frozen importlib._bootstrap>:241, in _call_with_frames_removed(f, *args, **kwds)

File /usr/local/lib/python3.10/site-packages/easydel/modules/__init__.py:35
1 from . import (
2 llama,
3 deepseek_v2,
(...)
32 mistral
33 )
---> 35 from .auto_easydel_model import (
36 AutoEasyDeLModelForCausalLM as AutoEasyDeLModelForCausalLM,
37 AutoEasyDeLConfig as AutoEasyDeLConfig,
38 AutoShardAndGatherFunctions as AutoShardAndGatherFunctions
39 )

File /usr/local/lib/python3.10/site-packages/easydel/modules/auto_easydel_model.py:12
11 import jax.numpy
---> 12 from fjformer import match_partition_rules, make_shard_and_gather_fns
14 from flax.traverse_util import unflatten_dict

File /usr/local/lib/python3.10/site-packages/fjformer/__init__.py:68
67 from . import pallas_operations as pallas_operations
---> 68 from . import optimizers as optimizers
69 from . import linen as linen

File /usr/local/lib/python3.10/site-packages/fjformer/optimizers/__init__.py:1
----> 1 from .adamw import (
2 get_adamw_with_cosine_scheduler as get_adamw_with_cosine_scheduler,
3 get_adamw_with_warm_up_cosine_scheduler as get_adamw_with_warm_up_cosine_scheduler,
4 get_adamw_with_warmup_linear_scheduler as get_adamw_with_warmup_linear_scheduler,
5 get_adamw_with_linear_scheduler as get_adamw_with_linear_scheduler
6 )
7 from .lion import (
8 get_lion_with_cosine_scheduler as get_lion_with_cosine_scheduler,
9 get_lion_with_with_warmup_linear_scheduler as get_lion_with_with_warmup_linear_scheduler,
10 get_lion_with_warm_up_cosine_scheduler as get_lion_with_warm_up_cosine_scheduler,
11 get_lion_with_linear_scheduler as get_lion_with_linear_scheduler
12 )

File /usr/local/lib/python3.10/site-packages/fjformer/optimizers/adamw.py:3
2 import chex
----> 3 import optax
6 def get_adamw_with_cosine_scheduler(
7 steps: int,
8 learning_rate: float = 5e-5,
(...)
16
17 ):

File /usr/local/lib/python3.10/site-packages/optax/__init__.py:17
15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
18 from optax import losses

File /usr/local/lib/python3.10/site-packages/optax/contrib/__init__.py:21
20 from optax.contrib.complex_valued import SplitRealAndImaginaryState
---> 21 from optax.contrib.dadapt_adamw import dadapt_adamw
22 from optax.contrib.dadapt_adamw import DAdaptAdamWState

File /usr/local/lib/python3.10/site-packages/optax/contrib/dadapt_adamw.py:27
26 from optax._src import base
---> 27 from optax._src import utils
30 class DAdaptAdamWState(NamedTuple):

File /usr/local/lib/python3.10/site-packages/optax/_src/utils.py:22
21 import jax.numpy as jnp
---> 22 import jax.scipy.stats.norm as multivariate_normal
24 from optax._src import linear_algebra

ImportError: cannot import name 'stats' from 'jax.scipy' (/usr/local/lib/python3.10/site-packages/jax/scipy/__init__.py)

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last)
Cell In[8], line 3
1 import transformers
----> 3 from easydel import (
4 AutoEasyDeLModelForCausalLM,
5 TrainArguments,
6 EasyDeLOptimizers,
7 EasyDeLSchedulers,
8 EasyDeLGradientCheckPointers,
9 EasyDeLState,
10 EasyDeLXRapTureConfig,
11 CausalLanguageModelTrainer,
12 get_modules_by_type,
13 easystate_to_huggingface_model,
14 )
15 from datasets import load_dataset
16 from flax.core import FrozenDict

File <frozen importlib._bootstrap>:1075, in _handle_fromlist(module, fromlist, import_, recursive)

File /usr/local/lib/python3.10/site-packages/easydel/utils/lazy_import.py:48, in _LazyModule.__getattr__(self, name)
46 value = self._get_module(name)
47 elif name in self._class_to_module.keys():
---> 48 module = self._get_module(self._class_to_module[name])
49 value = getattr(module, name)
50 else:

File /usr/local/lib/python3.10/site-packages/easydel/utils/lazy_import.py:60, in _LazyModule._get_module(self, module_name)
58 return importlib.import_module("." + module_name, self.name)
59 except Exception as e:
---> 60 raise RuntimeError(
61 f"Failed to import {self.name}.{module_name} because of the following error (look up to see its"
62 f" traceback):\n{e}"
63 ) from e

RuntimeError: Failed to import easydel.modules.auto_easydel_model because of the following error (look up to see its traceback):
cannot import name 'stats' from 'jax.scipy' (/usr/local/lib/python3.10/site-packages/jax/scipy/__init__.py)

To Reproduce

!pip install fjformer datasets gradio wandb sentencepiece git+https://github.com/huggingface/transformers -U -q #transformers=4.41.0 
!pip install jax[tpu]==0.4.23 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html #current version 0.4.28 but using .22 for stability
!pip install tensorflow --upgrade
HF_TOKEN = "HF_TOKEN"  # Replace with your actual Hugging Face token
!python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('{HF_TOKEN}')"
!apt-get update && apt-get upgrade -y && apt-get install golang -y
!pip install git+https://github.com/erfanzar/EasyDeL.git

import transformers

from easydel import (
    AutoEasyDeLModelForCausalLM,
    TrainArguments,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    EasyDeLState,
    EasyDeLXRapTureConfig,
    CausalLanguageModelTrainer,
    get_modules_by_type,
    easystate_to_huggingface_model,
)
from datasets import load_dataset
from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp, sharding
import jax
from transformers import AutoConfig
from huggingface_hub import HfApi
from typing import Literal

PartitionSpec = sharding.PartitionSpec
api = HfApi()

# Mesh axis sizes (dp, fsdp, tp, sp): -1 places all remaining devices on the
# last ("sp") axis — presumably the 8 cores of a TPU-v3; confirm on the host.
sharding_axis_dims = (1, 1, 1, -1)
max_length = 4096
# (batch, sequence) shape used to trace and shard the model at load time.
input_shape = (1, max_length)
# input_shape = (8, 8) second try
training_run = 1

pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"
pretrained_model_name_or_path_tokenizer = pretrained_model_name_or_path
# Plain string: the original f-string had no placeholders (ruff F541).
new_repo_id = "ssmits/Falcon2-5.5B-Dutch-Chat-cp0"


dtype = jnp.bfloat16
use_lora = False
lora_dim = 16
fully_fine_tune_parameters = False
lora_fine_tune_parameters = False
block_size = 512
attn_mechanism = "sharded_vanilla"

# Query/key/value/attention outputs all share one layout; the bias spec drops
# the trailing "tp" axis. PartitionSpec is an immutable value type, so reusing
# one instance is equivalent to constructing it four times.
_qkv_attn_spec = PartitionSpec(("dp", "fsdp"), "sp", None, "tp")
attention_partitions = dict(
    query_partition_spec=_qkv_attn_spec,
    key_partition_spec=_qkv_attn_spec,
    value_partition_spec=_qkv_attn_spec,
    bias_partition_spec=PartitionSpec(("dp", "fsdp"), "sp", None, None),
    attention_partition_spec=_qkv_attn_spec,
)

# Load the HF checkpoint on CPU first, then shard across the device mesh.
# The partition specs are passed both via config_kwargs and as top-level
# kwargs, matching the EasyDeL from_pretrained signature used here.
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device=jax.devices("cpu")[0],
    input_shape=input_shape,
    device_map="auto",
    sharding_axis_dims=sharding_axis_dims,
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism=attn_mechanism,
        **attention_partitions,
    ),
    **attention_partitions,
)

config = model.config

model_use_tie_word_embedding = config.tie_word_embeddings

# Flax expects parameters wrapped under the "params" collection.
model_parameters = FrozenDict({"params": params})

hi
update your jax version to 0.4.28

After doing that it installs correctly, but it fails when importing the easydel libraries:

Loaded pretrained model: ssmits/Falcon2-5.5B-Dutch
Input shape: (1, 4096)
Attention partitions: {'query_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'key_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'value_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'bias_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, None), 'attention_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp')}
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[2], line 32
     22 attention_partitions = dict(
     23     query_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
     24     key_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
   (...)
     27     attention_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
     28 )
     30 print(f"Attention partitions: {attention_partitions}")
---> 32 model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
     33     pretrained_model_name_or_path,
     34     device=jax.devices('cpu')[0],
     35     input_shape=input_shape,
     36     device_map="auto",
     37     sharding_axis_dims=sharding_axis_dims,
     38     config_kwargs=dict(
     39         use_scan_mlp=False,
     40         attn_mechanism=attn_mechanism,
     41         **attention_partitions
     42     ),
     43     **attention_partitions
     44 )
     46 print(f"Loaded model with params shape: {jax.tree_util.tree_map(lambda x: x.shape, params)}")
     48 config = model.config

NameError: name 'AutoEasyDeLModelForCausalLM' is not defined

Should I make a separate issue for this?

Yes, thank you, it's working!