TPU-v3 Kaggle not working after update
s-smits opened this issue · 5 comments
Describe the bug
It would be great to keep support for TPU-v3's on Kaggle.
After 0.0.66 I get this error:
ImportError Traceback (most recent call last)
File /usr/local/lib/python3.10/site-packages/easydel/utils/lazy_import.py:58, in _LazyModule._get_module(self, module_name)
57 try:
---> 58 return importlib.import_module("." + module_name, self.name)
59 except Exception as e:File /usr/local/lib/python3.10/importlib/init.py:126, in import_module(name, package)
125 level += 1
--> 126 return _bootstrap._gcd_import(name[level:], package, level)File :1050, in _gcd_import(name, package, level)
File :1027, in find_and_load(name, import)
File :992, in find_and_load_unlocked(name, import)
File :241, in _call_with_frames_removed(f, *args, **kwds)
File :1050, in _gcd_import(name, package, level)
File :1027, in find_and_load(name, import)
File :1006, in find_and_load_unlocked(name, import)
File :688, in _load_unlocked(spec)
File :883, in exec_module(self, module)
File :241, in _call_with_frames_removed(f, *args, **kwds)
File /usr/local/lib/python3.10/site-packages/easydel/modules/init.py:35
1 from . import (
2 llama,
3 deepseek_v2,
(...)
32 mistral
33 )
---> 35 from .auto_easydel_model import (
36 AutoEasyDeLModelForCausalLM as AutoEasyDeLModelForCausalLM,
37 AutoEasyDeLConfig as AutoEasyDeLConfig,
38 AutoShardAndGatherFunctions as AutoShardAndGatherFunctions
39 )File /usr/local/lib/python3.10/site-packages/easydel/modules/auto_easydel_model.py:12
11 import jax.numpy
---> 12 from fjformer import match_partition_rules, make_shard_and_gather_fns
14 from flax.traverse_util import unflatten_dictFile /usr/local/lib/python3.10/site-packages/fjformer/init.py:68
67 from . import pallas_operations as pallas_operations
---> 68 from . import optimizers as optimizers
69 from . import linen as linenFile /usr/local/lib/python3.10/site-packages/fjformer/optimizers/init.py:1
----> 1 from .adamw import (
2 get_adamw_with_cosine_scheduler as get_adamw_with_cosine_scheduler,
3 get_adamw_with_warm_up_cosine_scheduler as get_adamw_with_warm_up_cosine_scheduler,
4 get_adamw_with_warmup_linear_scheduler as get_adamw_with_warmup_linear_scheduler,
5 get_adamw_with_linear_scheduler as get_adamw_with_linear_scheduler
6 )
7 from .lion import (
8 get_lion_with_cosine_scheduler as get_lion_with_cosine_scheduler,
9 get_lion_with_with_warmup_linear_scheduler as get_lion_with_with_warmup_linear_scheduler,
10 get_lion_with_warm_up_cosine_scheduler as get_lion_with_warm_up_cosine_scheduler,
11 get_lion_with_linear_scheduler as get_lion_with_linear_scheduler
12 )File /usr/local/lib/python3.10/site-packages/fjformer/optimizers/adamw.py:3
2 import chex
----> 3 import optax
6 def get_adamw_with_cosine_scheduler(
7 steps: int,
8 learning_rate: float = 5e-5,
(...)
16
17 ):File /usr/local/lib/python3.10/site-packages/optax/init.py:17
15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
18 from optax import lossesFile /usr/local/lib/python3.10/site-packages/optax/contrib/init.py:21
20 from optax.contrib.complex_valued import SplitRealAndImaginaryState
---> 21 from optax.contrib.dadapt_adamw import dadapt_adamw
22 from optax.contrib.dadapt_adamw import DAdaptAdamWStateFile /usr/local/lib/python3.10/site-packages/optax/contrib/dadapt_adamw.py:27
26 from optax._src import base
---> 27 from optax._src import utils
30 class DAdaptAdamWState(NamedTuple):File /usr/local/lib/python3.10/site-packages/optax/_src/utils.py:22
21 import jax.numpy as jnp
---> 22 import jax.scipy.stats.norm as multivariate_normal
24 from optax._src import linear_algebraImportError: cannot import name 'stats' from 'jax.scipy' (/usr/local/lib/python3.10/site-packages/jax/scipy/init.py)
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
Cell In[8], line 3
1 import transformers
----> 3 from easydel import (
4 AutoEasyDeLModelForCausalLM,
5 TrainArguments,
6 EasyDeLOptimizers,
7 EasyDeLSchedulers,
8 EasyDeLGradientCheckPointers,
9 EasyDeLState,
10 EasyDeLXRapTureConfig,
11 CausalLanguageModelTrainer,
12 get_modules_by_type,
13 easystate_to_huggingface_model,
14 )
15 from datasets import load_dataset
16 from flax.core import FrozenDictFile :1075, in handle_fromlist(module, fromlist, import, recursive)
File /usr/local/lib/python3.10/site-packages/easydel/utils/lazy_import.py:48, in _LazyModule.getattr(self, name)
46 value = self._get_module(name)
47 elif name in self._class_to_module.keys():
---> 48 module = self._get_module(self._class_to_module[name])
49 value = getattr(module, name)
50 else:File /usr/local/lib/python3.10/site-packages/easydel/utils/lazy_import.py:60, in _LazyModule._get_module(self, module_name)
58 return importlib.import_module("." + module_name, self.name)
59 except Exception as e:
---> 60 raise RuntimeError(
61 f"Failed to import {self.name}.{module_name} because of the following error (look up to see its"
62 f" traceback):\n{e}"
63 ) from eRuntimeError: Failed to import easydel.modules.auto_easydel_model because of the following error (look up to see its traceback):
cannot import name 'stats' from 'jax.scipy' (/usr/local/lib/python3.10/site-packages/jax/scipy/init.py)
To Reproduce
!pip install fjformer datasets gradio wandb sentencepiece git+https://github.com/huggingface/transformers -U -q #transformers=4.41.0
!pip install jax[tpu]==0.4.23 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html #current version 0.4.28 but using .22 for stability
!pip install tensorflow --upgrade
HF_TOKEN = "HF_TOKEN" # Replace with your actual Hugging Face token
!python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<HF_TOKEN>')"
!apt-get update && apt-get upgrade -y && apt-get install golang -y
!pip install git+https://github.com/erfanzar/EasyDeL.git
import transformers
from easydel import (
AutoEasyDeLModelForCausalLM,
TrainArguments,
EasyDeLOptimizers,
EasyDeLSchedulers,
EasyDeLGradientCheckPointers,
EasyDeLState,
EasyDeLXRapTureConfig,
CausalLanguageModelTrainer,
get_modules_by_type,
easystate_to_huggingface_model,
)
from datasets import load_dataset
from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp, sharding
import jax
from transformers import AutoConfig
from huggingface_hub import HfApi
from typing import Literal
PartitionSpec = sharding.PartitionSpec
api = HfApi()
sharding_axis_dims = (1, 1, 1, -1)
max_length = 4096
input_shape = (1, max_length)
# input_shape = (8, 8) second try
training_run = 1
pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"
pretrained_model_name_or_path_tokenizer = pretrained_model_name_or_path
new_repo_id = f"ssmits/Falcon2-5.5B-Dutch-Chat-cp0"
dtype = jnp.bfloat16
use_lora = False
lora_dim = 16
fully_fine_tune_parameters = False
lora_fine_tune_parameters = False
block_size = 512
attn_mechanism = "sharded_vanilla"
attention_partitions = dict(
query_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
key_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
value_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
bias_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, None),
attention_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
)
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
device=jax.devices('cpu')[0],
input_shape=input_shape,
device_map="auto",
sharding_axis_dims=sharding_axis_dims,
config_kwargs=dict(
use_scan_mlp=False,
attn_mechanism=attn_mechanism,
**attention_partitions
),
**attention_partitions
)
config = model.config
model_use_tie_word_embedding = config.tie_word_embeddings
model_parameters = FrozenDict({"params": params})
Hi,
update your jax version to 0.4.28
After doing that, it installs correctly, but it freezes when importing the easydel libraries:
Loaded pretrained model: ssmits/Falcon2-5.5B-Dutch
Input shape: (1, 4096)
Attention partitions: {'query_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'key_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'value_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'bias_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, None), 'attention_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp')}
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[2], line 32
22 attention_partitions = dict(
23 query_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
24 key_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
(...)
27 attention_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
28 )
30 print(f"Attention partitions: {attention_partitions}")
---> 32 model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
33 pretrained_model_name_or_path,
34 device=jax.devices('cpu')[0],
35 input_shape=input_shape,
36 device_map="auto",
37 sharding_axis_dims=sharding_axis_dims,
38 config_kwargs=dict(
39 use_scan_mlp=False,
40 attn_mechanism=attn_mechanism,
41 **attention_partitions
42 ),
43 **attention_partitions
44 )
46 print(f"Loaded model with params shape: {jax.tree_util.tree_map(lambda x: x.shape, params)}")
48 config = model.config
NameError: name 'AutoEasyDeLModelForCausalLM' is not defined
Should I make a separate issue for this?
No, it's fine — take a look at this:
https://www.kaggle.com/citifer/easydel-causal-language-model-trainer-example
Yes, thank you, it's working!