dbolya/tomesd

Out of Memory when used with Tiled VAE

Haoming02 opened this issue · 2 comments

I have an RTX 3070 Ti with 8 GB of VRAM

Using Automatic1111 Webui

In img2img, without the ToMe patch, I was able to upscale a 1024x1024 image to 2048x2048 using Tiled VAE, with Encoder Tile Size set to 1024 and Decoder Tile Size set to 96. The VRAM usage was around 6~7 GB.

However, with the ToMe patch applied, regular generation does become faster. But when I try to upscale the 1024x1024 image again, it throws an Out of Memory error, even when I lower the Encoder Tile Size to 512 and the Decoder Tile Size to 64.

The implementation I used was this; it simply calls tomesd.apply_patch(sd_model, ratio=0.3) inside the on_model_loaded(sd_model) callback.
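
Roughly, the script boils down to something like this (a simplified sketch; the file name is illustrative):

```python
# scripts/tome.py (simplified sketch of the extension script)
import tomesd
from modules import script_callbacks


def on_model_loaded(sd_model):
    # Patch the loaded SD model with ToMe at a 30% merging ratio
    tomesd.apply_patch(sd_model, ratio=0.3)


script_callbacks.on_model_loaded(on_model_loaded)
```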

Is this a problem on my part? Did I write the implementation wrong?
Or is it something else?

Full Error Below:

Traceback (most recent call last):
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\call_queue.py", line 56, in f
    res = list(func(*args, **kwargs))
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\call_queue.py", line 37, in f
    res = func(*args, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\img2img.py", line 172, in img2img
    processed = process_images(p)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\processing.py", line 503, in process_images
    res = process_images_inner(p)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\extensions\sd-webui-controlnet\scripts\batch_hijack.py", line 42, in processing_process_images_hijack
    return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\processing.py", line 653, in process_images_inner
    samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\processing.py", line 1087, in sample
    samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 336, in sample_img2img
    samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 239, in launch_sampling
    return func()
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 336, in <lambda>
    samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\sampling.py", line 553, in sample_dpmpp_sde
    denoised = model(x, sigmas[i] * s_in, **extra_args)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 127, in forward
    x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\external.py", line 112, in forward
    eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\external.py", line 138, in get_eps
    return self.inner_model.apply_model(*args, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_hijack_utils.py", line 17, in <lambda>
    setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_hijack_utils.py", line 28, in __call__
    return self.__orig_func(*args, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\models\diffusion\ddpm.py", line 858, in apply_model
    x_recon = self.model(x_noisy, t, **cond)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\models\diffusion\ddpm.py", line 1335, in forward
    out = self.diffusion_model(x, t, context=cc)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\diffusionmodules\openaimodel.py", line 802, in forward
    h = module(h, emb, context)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\diffusionmodules\openaimodel.py", line 84, in forward
    x = layer(x, context)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\attention.py", line 334, in forward
    x = block(x, context=context[i])
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\attention.py", line 269, in forward
    return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\diffusionmodules\util.py", line 121, in checkpoint
    return CheckpointFunction.apply(func, len(inputs), *args)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\diffusionmodules\util.py", line 136, in forward
    output_tensors = ctx.run_function(*ctx.input_tensors)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\tomesd\patch.py", line 51, in _forward
    m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(x, self._tome_info)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\tomesd\patch.py", line 24, in compute_merge
    m, u = merge.bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], r, not use_rand)
  File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\tomesd\merge.py", line 85, in bipartite_soft_matching_random2d
    scores = a @ b.transpose(-1, -2)
torch.cuda.OutOfMemoryError: CUDA out of memory. 

Additionally, with the ToMe patch applied, I was able to upscale from 1024x1024 to 1536x1536 using the 1024 / 128 tile settings, just not to 2048x2048, no matter the settings.

dbolya commented

It's unfortunate, but the current implementation of ToMe uses more memory when combined with xformers / flash attn / torch 2.0 SDP attention.
Without those implementations, ToMe reduces memory usage by shrinking the attention matrices (which are absolutely massive to begin with). But flash attn-like methods already make attention a (roughly linear) space operation, because they never compute the whole attention matrix at once.

That leaves ToMe in an awkward spot, because it computes the similarities for merging all at once, creating a (3 * #tokens / 4) x (#tokens / 4) matrix before immediately argmaxing it down to a (3 * #tokens / 4)-long vector. I think that matrix is the problem here. Normally, the smaller attention matrices more than make up for the extra space taken by the similarity matrix, but flash attn-like methods mean that's no longer the case.
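
For rough scale (assuming SD 1.x with an 8x VAE downsample, tomesd's default sx = sy = 2, and fp16 activations), the similarity matrix at the highest-resolution blocks for a 2048x2048 image works out to:

```python
# Back-of-the-envelope size of the similarity matrix at the
# highest-resolution transformer blocks for a 2048x2048 image
# (assumes SD 1.x with an 8x VAE downsample, sx = sy = 2, fp16).
latent_side = 2048 // 8                  # 256x256 latent
tokens = latent_side * latent_side       # 65,536 tokens
dst = tokens // 4                        # one dst token per 2x2 window
src = tokens - dst                       # the other 3/4 are src tokens

entries = src * dst                      # 49,152 x 16,384 ≈ 8.1e8
print(f"{entries * 2 / 2**30:.1f} GiB")  # ~1.5 GiB per batch element in fp16
# With cond + uncond batched together that's ~3 GiB for this one
# intermediate, on top of everything else the UNet is holding.
```

That alone is a plausible tipping point for an 8 GB card at 2048x2048, and it's roughly consistent with 1536x1536 still fitting while 2048x2048 doesn't.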

Now, ToMe doesn't actually need to compute this whole matrix, so there is hope. We only need the (arg)max over the similarities, not the similarities themselves. I'm just not sure how to implement that efficiently in native pytorch (flash attn et al. implement this kind of thing with custom cuda kernels, which I don't want to use, because that's exactly what would make tomesd require compilation).
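
For illustration, one way to at least avoid materializing the full matrix in plain pytorch is to chunk the matmul and keep only the per-row max and argmax (a rough sketch of the idea, not what tomesd currently does):

```python
import torch

def chunked_max_similarity(a: torch.Tensor, b: torch.Tensor, chunk: int = 4096):
    """For each src token in `a`, find the value and index of its most similar
    dst token in `b`, without materializing the full (src x dst) score matrix.

    a: (B, N_src, C) src tokens, b: (B, N_dst, C) dst tokens, assumed already
    normalized as in the cosine-similarity setup. Sketch only, not tomesd code.
    """
    max_vals, max_idxs = [], []
    for start in range(0, a.shape[1], chunk):
        # Only a (B, chunk, N_dst) slice of the scores exists at any one time.
        scores = a[:, start:start + chunk] @ b.transpose(-1, -2)
        vals, idxs = scores.max(dim=-1)
        max_vals.append(vals)
        max_idxs.append(idxs)
    # Per-src max values (to rank which src tokens get merged) and the index
    # of the dst token each one would merge into.
    return torch.cat(max_vals, dim=1), torch.cat(max_idxs, dim=1)
```

The result matches scores.max(dim=-1) on the full matrix; the cost is a Python-level loop and extra kernel launches, so it trades speed for memory rather than being a free win.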