bghira/SimpleTuner

Cannot resume training using optimi-stableadamw

Closed · 1 comment

When running a config with optimi-stableadamw and lr_scheduler: constant, stopping training and then resuming errors out with the stack trace below.

Expected: with a constant scheduler and no config changes, training should resume from the last checkpoint when "resume_from_checkpoint": "latest" is set.
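
For context, what "latest" resumption relies on is essentially accelerate's save_state / load_state round trip. The sketch below is a generic illustration of that pattern under assumptions (stand-in model, optimizer, and checkpoint path), not SimpleTuner's actual trainer code.

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Stand-ins mirroring the config below: constant schedule, lr=8e-5.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-5)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

# Checkpointing stores model, optimizer, scheduler and RNG state together.
accelerator.save_state("output/checkpoint-110")

# On restart, the newest checkpoint-* directory is loaded back before
# optimizer.step() is ever called again.
accelerator.load_state("output/checkpoint-110")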

Branch:
commit a5ca5a2daeb78a477f0c4da77c703595ed121e87 (HEAD -> release, tag: v1.0.1, origin/release)

Relevant config settings:

"model_family": "flux",
"model_type": "lora",
"lora_type": "lycoris",
"lora_rank": 32,
"flux_lora_target": "all+ffs",
"optimizer": "optimi-lion",
"lr_scheduler": "constant",
"learning_rate": "8e-5",
"lr_warmup_steps": 0,
"optimizer_config":"weight_decay=0.001",
"train_batch_size": 1,
"gradient_accumulation_steps": 1,
"init_lokr_norm": "1e-3",
"max_train_steps": 10000,

Stack trace

Traceback (most recent call last):
  File "/mnt/code/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/mnt/code/SimpleTuner/helpers/training/trainer.py", line 2194, in train
    self.optimizer.step()
  File "/mnt/code/SimpleTuner/.venv/lib/python3.10/site-packages/accelerate/optimizer.py", line 172, in step
    self.optimizer.step(closure)
  File "/mnt/code/SimpleTuner/.venv/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 130, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
  File "/mnt/code/SimpleTuner/.venv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
  File "/mnt/code/SimpleTuner/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/code/SimpleTuner/.venv/lib/python3.10/site-packages/optimi/stableadamw.py", line 155, in step
    stableadamw(
  File "/mnt/code/SimpleTuner/.venv/lib/python3.10/site-packages/optimi/stableadamw.py", line 264, in stableadamw
    func(
  File "/mnt/code/SimpleTuner/.venv/lib/python3.10/site-packages/optimi/stableadamw.py", line 425, in _foreach_stableadamw
    torch._foreach_mul_(dev_params, scalars=new_wds)
TypeError: _foreach_mul_() received an invalid combination of arguments - got (list, scalars=list), but expected one of:
 * (tuple of Tensors self, tuple of Scalars scalars)
      didn't match because some of the arguments have invalid types: (list of [Parameter, Parameter, Parameter, …
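
The failing line is the fused weight-decay update in optimi's foreach path: each parameter is multiplied in place by its own decay factor through the ScalarList overload of torch._foreach_mul_, which only matches when the scalars are plain Python numbers. A minimal sketch of the working call, with assumed illustrative values rather than the trainer's actual state:

import torch

# Stand-ins for dev_params and new_wds from the trace above.
params = [torch.ones(4), torch.ones(4)]
lr, weight_decay = 8e-5, 0.001
new_wds = [1 - lr * weight_decay for _ in params]  # plain Python floats

# ScalarList overload: each tensor is scaled in place by its own scalar.
torch._foreach_mul_(params, new_wds)
print(params[0][0])  # slightly below 1.0 after the decoupled decay step

The TypeError above means none of the listed overloads matched the arguments passed in after resuming; per the closing comment, the fix landed on main.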

***** Running training *****
- Num batches = 60
- Num Epochs = 3
- Current Epoch = 2
- Total train batch size (w. parallel, distributed & accumulation) = 4
- Instantaneous batch size per device = 4
- Gradient Accumulation steps = 1
- Total optimization steps = 180
- Steps completed: 110
- Total optimization steps remaining = 70
Epoch 2/3, Steps:  62%|█████████████████████████▉                | 111/180 [00:14<15:31, 13.50s/it, lr=8e-5, step_loss=0.125]

Thanks, fixed on main.