GuangtaoLyu/FETNet

CUDA OOM


Hello. Thank you for your work and contribution.
I'm using this to remove text and noticed that when I feed in large or medium-size images, I get the error: OutOfMemoryError: CUDA out of memory. Tried to allocate 485.16 GiB. GPU 0 has a total capacity of 23.50 GiB of which 11.37 GiB is free.

Your model is quite small, so I'm surprised that it requires so much memory (more than 10 GB).

I tried downsizing the input image to 1024 x 1024, but I still got an OOM error.
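The parameter count only bounds the memory taken by the weights; activation memory grows with the input size. If some layer builds a dense attention map over all spatial positions of a feature map (an assumption here, not verified against FETNet's source), that single tensor grows with the square of the pixel count, i.e. with the fourth power of the image side, which would fit a single huge allocation request like the one above. The hypothetical helpers below only do the arithmetic; the stride of 4 and the 10M parameter count are illustrative guesses, not FETNet's actual values.

```python
import math

GIB = 1024 ** 3
MIB = 1024 ** 2


def attention_map_bytes(height, width, stride, bytes_per_elem, batch=1, heads=1):
    """Size of one dense (HW x HW) attention matrix computed on a feature map
    downsampled by `stride` from a height x width input (hypothetical layer)."""
    positions = math.ceil(height / stride) * math.ceil(width / stride)
    return batch * heads * positions * positions * bytes_per_elem


def weight_bytes(num_params, bytes_per_elem):
    """Memory taken by the parameters alone; independent of the input size."""
    return num_params * bytes_per_elem


# Assumed stride of 4; bfloat16 = 2 bytes per element, float32 = 4 bytes.
cases = [(512, 1024, 2), (1024, 1024, 2), (1024, 1024, 4), (2048, 2048, 4)]
for h, w, nbytes in cases:
    size = attention_map_bytes(h, w, 4, nbytes) / GIB
    print(f"{h}x{w}, {nbytes} B/elem: {size:.1f} GiB")

# Hypothetical 10M-parameter model stored in float32.
print(f"10M params, float32: {weight_bytes(10_000_000, 4) / MIB:.1f} MiB")
```

This prints 2.0, 8.0, 16.0 and 256.0 GiB for the four attention cases and about 38.1 MiB for the weights: doubling both sides multiplies the attention term by 16, while switching from float32 to bfloat16 only halves it.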

At the moment I can run inference in bfloat16 with an image size of 1024 x 512. In this setup the model uses 8 GB, which is still quite a lot.
My code:

```python
import torch
from modules.Losses import *
from torchvision.utils import make_grid
from torchvision import transforms as T
from PIL import Image
from modules.FETNet import FETNet
from utils.erode import *
from matplotlib import pyplot as plt

to_tensor = T.ToTensor()
to_pil_image = T.ToPILImage()


def plot(image, si=(12, 12)):
    fig, ax = plt.subplots(figsize=si)
    ax.imshow(image, cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.show()


# Downsample the input image to reduce memory usage.
# PIL's resize takes (width, height), so (512, 1024) means 512 wide x 1024 tall.
def load_and_preprocess_image(image_path, size=(512, 1024)):  # NOTE: set (1024, 1024) to get OOM
    img = Image.open(image_path).convert('RGB')
    img = img.resize(size)
    img = to_tensor(img).float()
    img = torch.unsqueeze(img, 0)  # add batch dimension: 1 x 3 x H x W
    return img


G = FETNet(3)
ckpt_dict = torch.load("scut_enstext.pth", map_location="cpu")  # load weights on the CPU first
G.load_state_dict(ckpt_dict)
G = G.to("cuda").to(torch.bfloat16)
G.eval()

total_params = sum(p.numel() for p in G.parameters())
print(total_params)

img = load_and_preprocess_image("2.jpg")

with torch.no_grad():
    img = img.to("cuda").to(torch.bfloat16)
    fake_B, masks_out = G(img)
    # Composite: generator output where the mask is 0, original input where it is 1
    comp_B = fake_B * (1 - masks_out) + img * masks_out

# Back to float32 in [0, 1] on the CPU so the PIL conversion does not wrap out-of-range values
comp_B = comp_B.float().clamp(0, 1).cpu()
for k in range(comp_B.size(0)):
    grid = make_grid(comp_B[k:k + 1])
    grid = to_pil_image(grid)
    plot(grid)
```

Test image: [image: 2]

So the issue remains. Don't you think the model consumes too much memory?
Could there be a memory leak somewhere in the source code?
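A leak would make the memory retained between calls grow from one run to the next, whereas a large activation peak is released after each forward pass. One way to tell them apart (a sketch, not something from the FETNet repository) is to run the same inference several times, delete the outputs, record `torch.cuda.memory_allocated()` after each run, and check whether the readings keep climbing. The helper `looks_like_leak` and its 16 MiB tolerance are hypothetical; the function only classifies such readings and does not call PyTorch itself.

```python
MIB = 1024 ** 2


def looks_like_leak(readings_bytes, tolerance_bytes=16 * MIB):
    """Return True if memory retained after repeated identical runs grows by
    more than `tolerance_bytes` from the first run to the last.

    `readings_bytes` is assumed to hold one torch.cuda.memory_allocated()
    value per run, taken after that run's outputs have been deleted.
    """
    if len(readings_bytes) < 2:
        raise ValueError("need at least two readings to detect growth")
    growth = readings_bytes[-1] - readings_bytes[0]
    return growth > tolerance_bytes


steady = [900 * MIB, 901 * MIB, 900 * MIB]     # memory is reused: no leak
growing = [900 * MIB, 1400 * MIB, 1900 * MIB]  # retained memory climbs each run
print(looks_like_leak(steady), looks_like_leak(growing))
```

Running it prints `False True`. A high but constant `torch.cuda.max_memory_allocated()` per call (after `torch.cuda.reset_peak_memory_stats()`) would point to large activations rather than a leak.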