Out of Memory when used with Tiled VAE
Haoming02 opened this issue · 2 comments
I have a RTX 3070 Ti with 8GB VRAM
Using Automatic1111 Webui
In img2img, without the ToMe patch, I was able to upscale a 1024x1024 image to 2048x2048 using Tiled VAE, with `Encoder Tile Size` set to 1024 and `Decoder Tile Size` set to 96. The VRAM usage was around 6~7 GB.
However, with the ToMe patch applied, regular generation does become faster, but when I try to upscale a 1024x1024 image again, it throws an Out of Memory error, even when I lower `Encoder Tile Size` to 512 and `Decoder Tile Size` to 64.
The implementation I used was this, which simply calls `tomesd.apply_patch(sd_model, ratio=0.3)` inside `on_model_loaded(sd_model)`.
Is this a problem on my part? Did I write the implementation wrong?
Or is it something else?
Full Error Below:
Traceback (most recent call last):
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\call_queue.py", line 56, in f
res = list(func(*args, **kwargs))
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\call_queue.py", line 37, in f
res = func(*args, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\img2img.py", line 172, in img2img
processed = process_images(p)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\processing.py", line 503, in process_images
res = process_images_inner(p)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\extensions\sd-webui-controlnet\scripts\batch_hijack.py", line 42, in processing_process_images_hijack
return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\processing.py", line 653, in process_images_inner
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\processing.py", line 1087, in sample
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 336, in sample_img2img
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 239, in launch_sampling
return func()
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 336, in <lambda>
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\sampling.py", line 553, in sample_dpmpp_sde
denoised = model(x, sigmas[i] * s_in, **extra_args)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 127, in forward
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\external.py", line 112, in forward
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\external.py", line 138, in get_eps
return self.inner_model.apply_model(*args, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_hijack_utils.py", line 17, in <lambda>
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\modules\sd_hijack_utils.py", line 28, in __call__
return self.__orig_func(*args, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\models\diffusion\ddpm.py", line 858, in apply_model
x_recon = self.model(x_noisy, t, **cond)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\models\diffusion\ddpm.py", line 1335, in forward
out = self.diffusion_model(x, t, context=cc)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1212, in _call_impl
result = forward_call(*input, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\diffusionmodules\openaimodel.py", line 802, in forward
h = module(h, emb, context)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\diffusionmodules\openaimodel.py", line 84, in forward
x = layer(x, context)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\attention.py", line 334, in forward
x = block(x, context=context[i])
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\attention.py", line 269, in forward
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\diffusionmodules\util.py", line 121, in checkpoint
return CheckpointFunction.apply(func, len(inputs), *args)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\repositories\stable-diffusion-stability-ai\ldm\modules\diffusionmodules\util.py", line 136, in forward
output_tensors = ctx.run_function(*ctx.input_tensors)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\tomesd\patch.py", line 51, in _forward
m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(x, self._tome_info)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\tomesd\patch.py", line 24, in compute_merge
m, u = merge.bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], r, not use_rand)
File "C:\Users\admin\Documents\GitHub\stable-diffusion-webui\venv\lib\site-packages\tomesd\merge.py", line 85, in bipartite_soft_matching_random2d
scores = a @ b.transpose(-1, -2)
torch.cuda.OutOfMemoryError: CUDA out of memory.
Additionally: with the ToMe patch applied, I was able to upscale from 1024x1024 to 1536x1536 with `Encoder Tile Size` set to 1024 and `Decoder Tile Size` set to 128. Just not to 2048x2048, no matter the settings.
It's unfortunate, but the current implementation of ToMe uses more memory when also using xformers / flash attention / torch 2.0 SDP attention.
Without those implementations, ToMe reduces memory usage by shrinking the attention matrices (which were absolutely massive to begin with). But flash attn-like methods already make computing attention a (linear?) space operation, because they never materialize the whole attention matrix at once.
That leaves ToMe in an awkward spot, because it computes similarities for merging all at once, creating a `(3 * #tokens / 4) x (#tokens / 4)` matrix before immediately argmaxing it down to a `(3 * #tokens / 4)` vector. I think that first matrix is the problem here. Normally, the smaller attention matrices more than make up for the extra space taken by that similarity matrix, but flash attn-like methods mean that's no longer the case.
Now, ToMe doesn't actually need to compute this whole matrix, so there is hope. We only need the argmax over the similarities, not the similarities themselves. I'm just not sure how to implement that in native PyTorch (flash attn et al. implement it with custom CUDA kernels, which I don't want to use because that's what makes them require compilation).
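Short of a fused custom kernel, one framework-native workaround is to compute the score matrix in row chunks and keep only the running argmax/max, so peak scratch memory is `chunk x dst` instead of `src x dst`. This is my own sketch (in NumPy for illustration; it is not tomesd code, and a real patch would do the same thing on torch tensors):

```python
import numpy as np

def chunked_argmax_sim(a, b, chunk=4096):
    """For each row of `a` (src tokens), find the index and value of the
    most similar row of `b` (dst tokens) without ever materializing the
    full (len(a), len(b)) score matrix."""
    n = a.shape[0]
    best_idx = np.empty(n, dtype=np.int64)
    best_val = np.empty(n, dtype=a.dtype)
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        scores = a[start:stop] @ b.T      # only a (chunk, dst) slice lives at once
        best_idx[start:stop] = scores.argmax(axis=1)
        best_val[start:stop] = scores.max(axis=1)
    return best_idx, best_val
```

This trades one big allocation for several small ones at the cost of a Python-level loop; it still computes every similarity, just never all of them simultaneously, so it caps memory rather than reducing total work.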