hollowstrawberry/kohya-colab

AttributeError: module 'jax.random' has no attribute 'KeyArray'

Closed this issue · 8 comments

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /content/kohya-trainer/train_network.py:15 in <module> │
│ │
│ 12 from tqdm import tqdm │
│ 13 import torch │
│ 14 from accelerate.utils import set_seed │
│ ❱ 15 from diffusers import DDPMScheduler │
│ 16 │
│ 17 import library.train_util as train_util │
│ 18 from library.train_util import ( │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/__init__.py:38 in <module> │
│ │
│ 35 │ │ get_polynomial_decay_schedule_with_warmup, │
│ 36 │ │ get_scheduler, │
│ 37 │ ) │
│ ❱ 38 │ from .pipeline_utils import DiffusionPipeline │
│ 39 │ from .pipelines import ( │
│ 40 │ │ DanceDiffusionPipeline, │
│ 41 │ │ DDIMPipeline, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/pipeline_utils.py:38 in <module> │
│ │
│ 35 from .dynamic_modules_utils import get_class_from_dynamic_module │
│ 36 from .hub_utils import http_user_agent, send_telemetry │
│ 37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT │
│ ❱ 38 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME │
│ 39 from .utils import ( │
│ 40 │ CONFIG_NAME, │
│ 41 │ DIFFUSERS_CACHE, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/__init__.py:50 in <module> │
│ │
│ 47 │ from ..utils.dummy_flax_objects import * # noqa F403 │
│ 48 else: │
│ 49 │ from .scheduling_ddim_flax import FlaxDDIMScheduler │
│ ❱ 50 │ from .scheduling_ddpm_flax import FlaxDDPMScheduler │
│ 51 │ from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler │
│ 52 │ from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler │
│ 53 │ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:80 in <module> │
│ │
│ │
│ 77 │ state: DDPMSchedulerState │
│ 78 │
│ 79 │
│ ❱ 80 class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): │
│ 81 │ """ │
│ 82 │ Denoising diffusion probabilistic models (DDPMs) explores the connections between de │
│ 83 │ Langevin dynamics sampling. │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:216 in │
│ FlaxDDPMScheduler │
│ │
│ 213 │ │ model_output: jnp.ndarray, │
│ 214 │ │ timestep: int, │
│ 215 │ │ sample: jnp.ndarray, │
│ ❱ 216 │ │ key: random.KeyArray, │
│ 217 │ │ return_dict: bool = True, │
│ 218 │ │ **kwargs, │
│ 219 │ ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: │
│ │
│ /usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py:54 in getattr │
│ │
│ 51 │ │ raise AttributeError(message) │
│ 52 │ warnings.warn(message, DeprecationWarning, stacklevel=2) │
│ 53 │ return fn │
│ ❱ 54 │ raise AttributeError(f"module {module!r} has no attribute {name!r}") │
│ 55 │
│ 56 return getattr │
│ 57 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: module 'jax.random' has no attribute 'KeyArray'

please help 😌

Same here, please help =(

This is literally happening with all the Kohya LoRA training Colabs out there.

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮ │ /content/kohya-trainer/train_network.py:15 in │ │ │ │ 12 from tqdm import tqdm │ │ 13 import torch │ │ 14 from accelerate.utils import set_seed │ │ ❱ 15 from diffusers import DDPMScheduler │ │ 16 │ │ 17 import library.train_util as train_util │ │ 18 from library.train_util import ( │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/init.py:38 in │ │ │ │ 35 │ │ get_polynomial_decay_schedule_with_warmup, │ │ 36 │ │ get_scheduler, │ │ 37 │ ) │ │ ❱ 38 │ from .pipeline_utils import DiffusionPipeline │ │ 39 │ from .pipelines import ( │ │ 40 │ │ DanceDiffusionPipeline, │ │ 41 │ │ DDIMPipeline, │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/pipeline_utils.py:38 in │ │ │ │ 35 from .dynamic_modules_utils import get_class_from_dynamic_module │ │ 36 from .hub_utils import http_user_agent, send_telemetry │ │ 37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT │ │ ❱ 38 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME │ │ 39 from .utils import ( │ │ 40 │ CONFIG_NAME, │ │ 41 │ DIFFUSERS_CACHE, │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/init.py:50 in │ │ │ │ 47 │ from ..utils.dummy_flax_objects import * # noqa F403 │ │ 48 else: │ │ 49 │ from .scheduling_ddim_flax import FlaxDDIMScheduler │ │ ❱ 50 │ from .scheduling_ddpm_flax import FlaxDDPMScheduler │ │ 51 │ from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler │ │ 52 │ from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler │ │ 53 │ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:80 in │ │ │ │ │ │ 77 │ state: DDPMSchedulerState │ │ 78 │ │ 79 │ │ ❱ 80 class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): │ │ 81 │ """ │ │ 82 │ Denoising diffusion probabilistic models (DDPMs) explores the connections between de │ │ 83 │ 
Langevin dynamics sampling. │ │ │ │ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:216 in │ │ FlaxDDPMScheduler │ │ │ │ 213 │ │ model_output: jnp.ndarray, │ │ 214 │ │ timestep: int, │ │ 215 │ │ sample: jnp.ndarray, │ │ ❱ 216 │ │ key: random.KeyArray, │ │ 217 │ │ return_dict: bool = True, │ │ 218 │ │ **kwargs, │ │ 219 │ ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: │ │ │ │ /usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py:54 in getattr │ │ │ │ 51 │ │ raise AttributeError(message) │ │ 52 │ warnings.warn(message, DeprecationWarning, stacklevel=2) │ │ 53 │ return fn │ │ ❱ 54 │ raise AttributeError(f"module {module!r} has no attribute {name!r}") │ │ 55 │ │ 56 return getattr │ │ 57 │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ AttributeError: module 'jax.random' has no attribute 'KeyArray'

please help 😌

Hey friend, I finally solved this problem! Try this code; it may fix it for you:

!pip install timm==0.6.12 fairscale==0.4.13 transformers==4.26.0 requests==2.28.2 accelerate==0.15.0 diffusers[torch]==0.10.2 einops==0.6.0 safetensors==0.2.6 jax==0.4.23 jaxlib==0.4.23

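Run this cell before starting section 1.1. It pins `jax` and `jaxlib` to 0.4.23, a version that still provides the `jax.random.KeyArray` attribute that this older `diffusers` release references. Once it finishes, you can use the notebook normally. Hopefully this helps!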

Wow, that works. Thank you.

Thanks a lot! =)

Was this also fixed in your XL trainer? I still get the error from the copy of the Colab saved in my Drive.

Delete the saved copy and make a new one from the original, or just use the original without making a copy, whichever you prefer. All three Colabs were fixed at the same time yesterday.