AttributeError: module 'jax.random' has no attribute 'KeyArray'
Closed this issue · 8 comments
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /content/kohya-trainer/train_network.py:15 in <module> │
│ │
│ 12 from tqdm import tqdm │
│ 13 import torch │
│ 14 from accelerate.utils import set_seed │
│ ❱ 15 from diffusers import DDPMScheduler │
│ 16 │
│ 17 import library.train_util as train_util │
│ 18 from library.train_util import ( │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/__init__.py:38 in <module> │
│ │
│ 35 │ │ get_polynomial_decay_schedule_with_warmup, │
│ 36 │ │ get_scheduler, │
│ 37 │ ) │
│ ❱ 38 │ from .pipeline_utils import DiffusionPipeline │
│ 39 │ from .pipelines import ( │
│ 40 │ │ DanceDiffusionPipeline, │
│ 41 │ │ DDIMPipeline, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/pipeline_utils.py:38 in <module> │
│ │
│ 35 from .dynamic_modules_utils import get_class_from_dynamic_module │
│ 36 from .hub_utils import http_user_agent, send_telemetry │
│ 37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT │
│ ❱ 38 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME │
│ 39 from .utils import ( │
│ 40 │ CONFIG_NAME, │
│ 41 │ DIFFUSERS_CACHE, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/__init__.py:50 in <module> │
│ │
│ 47 │ from ..utils.dummy_flax_objects import * # noqa F403 │
│ 48 else: │
│ 49 │ from .scheduling_ddim_flax import FlaxDDIMScheduler │
│ ❱ 50 │ from .scheduling_ddpm_flax import FlaxDDPMScheduler │
│ 51 │ from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler │
│ 52 │ from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler │
│ 53 │ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:80 in <module> │
│ │
│ │
│ 77 │ state: DDPMSchedulerState │
│ 78 │
│ 79 │
│ ❱ 80 class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): │
│ 81 │ """ │
│ 82 │ Denoising diffusion probabilistic models (DDPMs) explores the connections between de │
│ 83 │ Langevin dynamics sampling. │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:216 in │
│ FlaxDDPMScheduler │
│ │
│ 213 │ │ model_output: jnp.ndarray, │
│ 214 │ │ timestep: int, │
│ 215 │ │ sample: jnp.ndarray, │
│ ❱ 216 │ │ key: random.KeyArray, │
│ 217 │ │ return_dict: bool = True, │
│ 218 │ │ **kwargs, │
│ 219 │ ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: │
│ │
│ /usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py:54 in __getattr__ │
│ │
│ 51 │ │ raise AttributeError(message) │
│ 52 │ warnings.warn(message, DeprecationWarning, stacklevel=2) │
│ 53 │ return fn │
│ ❱ 54 │ raise AttributeError(f"module {module!r} has no attribute {name!r}") │
│ 55 │
│ 56 return getattr │
│ 57 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: module 'jax.random' has no attribute 'KeyArray'
please help 😌
Same here, please help =(
This is happening with literally every Kohya LoRA training Colab out there.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮ │ /content/kohya-trainer/train_network.py:15 in │ │ │ │ 12 from tqdm import tqdm │ │ 13 import torch │ │ 14 from accelerate.utils import set_seed │ │ ❱ 15 from diffusers import DDPMScheduler │ │ 16 │ │ 17 import library.train_util as train_util │ │ 18 from library.train_util import ( │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/init.py:38 in │ │ │ │ 35 │ │ get_polynomial_decay_schedule_with_warmup, │ │ 36 │ │ get_scheduler, │ │ 37 │ ) │ │ ❱ 38 │ from .pipeline_utils import DiffusionPipeline │ │ 39 │ from .pipelines import ( │ │ 40 │ │ DanceDiffusionPipeline, │ │ 41 │ │ DDIMPipeline, │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/pipeline_utils.py:38 in │ │ │ │ 35 from .dynamic_modules_utils import get_class_from_dynamic_module │ │ 36 from .hub_utils import http_user_agent, send_telemetry │ │ 37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT │ │ ❱ 38 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME │ │ 39 from .utils import ( │ │ 40 │ CONFIG_NAME, │ │ 41 │ DIFFUSERS_CACHE, │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/init.py:50 in │ │ │ │ 47 │ from ..utils.dummy_flax_objects import * # noqa F403 │ │ 48 else: │ │ 49 │ from .scheduling_ddim_flax import FlaxDDIMScheduler │ │ ❱ 50 │ from .scheduling_ddpm_flax import FlaxDDPMScheduler │ │ 51 │ from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler │ │ 52 │ from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler │ │ 53 │ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:80 in │ │ │ │ │ │ 77 │ state: DDPMSchedulerState │ │ 78 │ │ 79 │ │ ❱ 80 class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): │ │ 81 │ """ │ │ 82 │ Denoising diffusion probabilistic models (DDPMs) explores the connections between de │ │ 83 │ 
Langevin dynamics sampling. │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:216 in │ │ FlaxDDPMScheduler │ │ │ │ 213 │ │ model_output: jnp.ndarray, │ │ 214 │ │ timestep: int, │ │ 215 │ │ sample: jnp.ndarray, │ │ ❱ 216 │ │ key: random.KeyArray, │ │ 217 │ │ return_dict: bool = True, │ │ 218 │ │ **kwargs, │ │ 219 │ ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: │ │ │ │ /usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py:54 in getattr │ │ │ │ 51 │ │ raise AttributeError(message) │ │ 52 │ warnings.warn(message, DeprecationWarning, stacklevel=2) │ │ 53 │ return fn │ │ ❱ 54 │ raise AttributeError(f"module {module!r} has no attribute {name!r}") │ │ 55 │ │ 56 return getattr │ │ 57 │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ AttributeError: module 'jax.random' has no attribute 'KeyArray'
please help 😌
Hey friend, I finally solved this problem! Try this code; it should fix it:
!pip install timm==0.6.12 fairscale==0.4.13 transformers==4.26.0 requests==2.28.2 accelerate==0.15.0 diffusers[torch]==0.10.2 einops==0.6.0 safetensors==0.2.6 jax==0.4.23 jaxlib==0.4.23
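Run this cell before starting section 1.1. Once it finishes, you can use the notebook normally. Pinning jax/jaxlib to 0.4.23 works because later JAX releases removed `jax.random.KeyArray`, which this older diffusers version (0.10.2) still references. Hopefully this helps!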
Wow, that works. Thank you!
This was already fixed, please use the latest trainer.
https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Lora_Trainer.ipynb
Thanks a lot! =)
Was it fixed on your XL trainer as well? I still get the error from the copy of the Colab I saved to my Drive.
Delete the saved copy and make a new one from the original. Or just use the original and don't make a copy, whichever you prefer. All 3 colabs were fixed at the same time yesterday.