bghira/SimpleTuner

Runtime error when using int8-quanto with SDXL full finetuning

a-l-e-x-d-s-9 opened this issue · 7 comments

can you try the new commit on main?

I updated and tried again, getting this error now:

got: c10::BFloat16 != float
Traceback (most recent call last):
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/trainer.py", line 2411, in train
    self.accelerator.backward(loss)
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/accelerate/accelerator.py", line 2196, in backward
    loss.backward(**kwargs)
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/quantisation/quanto_workarounds.py", line 95, in reshape_qlf_backward
    other_gO = torch.matmul(gO.reshape(-1, out_features).t(), input.reshape(-1, in_features))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float

progress. i can fix that..
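
For context, a minimal sketch of the kind of dtype alignment that fix might involve. The function name and arguments below just mirror the traceback; this is not the actual code in quanto_workarounds.py.

import torch

def reshape_qlf_backward_sketch(gO, input, in_features, out_features):
    # gO arrives in the autocast dtype (bf16 here) while the dequantised
    # input may still be float32, which the matmul refuses to mix.
    input_2d = input.reshape(-1, in_features)
    if input_2d.dtype != gO.dtype:
        input_2d = input_2d.to(gO.dtype)  # cast to a common dtype first
    return torch.matmul(gO.reshape(-1, out_features).t(), input_2d)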

Bumping this: I was also running into the exact same error when attempting a full fine-tune of Flux with int8-quanto. After updating to the most recent version, the error has changed to:

Traceback (most recent call last):
  File "/home/administrator/Documents/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/home/administrator/Documents/SimpleTuner/helpers/training/trainer.py", line 2265, in train
    model_pred = self.transformer(**flux_transformer_kwargs)[0]
  File "/home/administrator/Documents/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/administrator/Documents/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/administrator/Documents/SimpleTuner/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
    return model_forward(*args, **kwargs)
  File "/home/administrator/Documents/SimpleTuner/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/administrator/Documents/SimpleTuner/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/administrator/Documents/SimpleTuner/.venv/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 380, in forward
    ids = torch.cat((txt_ids, img_ids), dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 512 but got size 4096 for tensor number 1 in the list.

My current configuration is:

    "--resume_from_checkpoint": "latest",
    "--data_backend_config": "config/flux_stickers_config.json",
    "--aspect_bucket_rounding": 2,
    "--seed": 42,
    "--minimum_image_size": 0,
    "--disable_benchmark": false,
    "--output_dir": "/home/administrator/Documents/SimpleTuner/workdir/TrainedLoRAs/full-fine-tune-test",
    "--max_train_steps": 5000,
    "--num_train_epochs": 0,
    "--checkpointing_steps": 500,
    "--checkpoints_total_limit": 2,
    "--tracker_project_name": "training",
    "--tracker_run_name": "full-fine-tune-test",
    "--report_to": "wandb",
    "--model_type": "full",
    "--pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
    "--model_family": "flux",
    "--train_batch_size": 1,
    "--gradient_checkpointing": "true",
    "--caption_dropout_probability": 0.05,
    "--resolution_type": "pixel_area",
    "--resolution": 1024,
    "--validation_seed": 42,
    "--validation_steps": 500,
    "--validation_resolution": "1024",
    "--validation_guidance": 3.0,
    "--validation_guidance_rescale": "0.0",
    "--validation_num_inference_steps": "20",
    "--validation_prompt": "ohwx Man, walking on the beach, sunset, beautiful ocean in the background",
    "--disable_tf32": "true",
    "--mixed_precision": "bf16",
    "--optimizer": "adamw_bf16",
    "--learning_rate": "1e-6",
    "--lr_scheduler": "polynomial",
    "--lr_warmup_steps": 100,
    "--base_model_precision": "int8-quanto",
    "--text_encoder_1_precision": "no_change",
    "--text_encoder_2_precision": "no_change",
    "--validation_torch_compile": "false",
    "--user_prompt_library": "config/prompt_library.json",
    "--override_dataset_config": true
}

you probably can't do a full tune of a model using quantised weights; that is the issue here.

(or make sure you're on all of the latest dependencies. the handling of txt_ids and img_ids has changed.)
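
For reference, a rough illustration of the shape convention change (the exact diffusers release where it changed is assumed here, and the shapes are inferred from the traceback above):

import torch

batch, txt_len, img_len = 1, 512, 4096

# older diffusers: 3D position ids with a batch dimension, concatenated on dim=1
txt_ids_3d = torch.zeros(batch, txt_len, 3)
img_ids_3d = torch.zeros(batch, img_len, 3)
ids_old = torch.cat((txt_ids_3d, img_ids_3d), dim=1)  # (1, 4608, 3)

# newer diffusers: 2D position ids without the batch dimension, concatenated on dim=0
txt_ids_2d = torch.zeros(txt_len, 3)
img_ids_2d = torch.zeros(img_len, 3)
ids_new = torch.cat((txt_ids_2d, img_ids_2d), dim=0)  # (4608, 3)

# Feeding 2D ids to a diffusers version that still concatenates on dim=1
# reproduces the "Expected size 512 but got size 4096" error above.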

i will disable quantisation on full models instead. it isn't really supposed to be done.
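
A hypothetical sketch of what such a guard could look like (not the actual SimpleTuner implementation; the option names are taken from the config above):

def validate_precision(model_type: str, base_model_precision: str) -> None:
    # Full finetuning updates the base weights directly, so they cannot
    # stay quantised the way they can for LoRA training.
    if model_type == "full" and base_model_precision != "no_change":
        raise ValueError(
            "Full finetuning requires unquantised base weights; got "
            f"--base_model_precision={base_model_precision}. "
            "Use no_change, or switch to LoRA training."
        )

# validate_precision("full", "int8-quanto")  # would raise ValueError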