kuprel/min-dalle

GPU OOM when the model is run with Python multithreading

xcharleslin opened this issue · 0 comments

Minimal repro:

import torch
from min_dalle import MinDalle
from concurrent.futures import ThreadPoolExecutor

USE_GPU = True
def f(text: str, root: str):
    return MinDalle(
        models_root=f'./{root}',
        dtype=torch.float32,
        device="cuda",
        is_mega=False, 
        is_reusable=True,
    ).generate_image(
        text,
        seed=-1,
        grid_size=1,
        is_seamless=False,
        temperature=1,
        top_k=256,
        supercondition_factor=32,
    )

# Works without threading
f("hello", "root1")

# Fails when the same call is submitted to a thread pool
tpe = ThreadPoolExecutor()
tpe.submit(f, "hello2", "root2").result()  # GPU OOMs here

The last line fails with OutOfMemoryError: CUDA out of memory.
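One possible mitigation, as a sketch only: it assumes the OOM comes from each MinDalle(...) call keeping a full copy of the model weights resident on the GPU, so that the second construction cannot fit. Under that assumption, constructing the model once and sharing it across threads avoids the duplicate allocation. The names f_shared and the lock are illustrative; generate_image is not documented as thread-safe, so calls are serialized here.

import torch
from min_dalle import MinDalle
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

# One model instance for the whole process instead of one per call.
model = MinDalle(
    models_root='./root1',
    dtype=torch.float32,
    device="cuda",
    is_mega=False,
    is_reusable=True,
)
lock = Lock()  # serialize GPU access; thread safety of generate_image is not guaranteed

def f_shared(text: str):
    with lock:
        return model.generate_image(
            text,
            seed=-1,
            grid_size=1,
            is_seamless=False,
            temperature=1,
            top_k=256,
            supercondition_factor=32,
        )

f_shared("hello")                                          # main thread
ThreadPoolExecutor().submit(f_shared, "hello2").result()   # worker thread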

Full stack trace:
using device cuda
downloading tokenizer params
intializing TextTokenizer
downloading encoder params
initializing DalleBartEncoder
downloading decoder params
initializing DalleBartDecoder
downloading detokenizer params
initializing VQGanDetokenizer
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
<ipython-input-11-7b2fe8b33ac0> in <module>
      5 fut = tpe.submit(f, "abc", "def")
      6 
----> 7 fut.result()

12 frames
/usr/lib/python3.8/concurrent/futures/_base.py in result(self, timeout)
    442                     raise CancelledError()
    443                 elif self._state == FINISHED:
--> 444                     return self.__get_result()
    445                 else:
    446                     raise TimeoutError()

/usr/lib/python3.8/concurrent/futures/_base.py in __get_result(self)
    387         if self._exception:
    388             try:
--> 389                 raise self._exception
    390             finally:
    391                 # Break a reference cycle with the exception in self._exception

/usr/lib/python3.8/concurrent/futures/thread.py in run(self)
     55 
     56         try:
---> 57             result = self.fn(*self.args, **self.kwargs)
     58         except BaseException as exc:
     59             self.future.set_exception(exc)

<ipython-input-9-7e6467e7527a> in f(text, dir)
      5 USE_GPU = True
      6 def f(text: str, dir: str) -> PIL.Image.Image:
----> 7     return MinDalle(
      8         models_root=f'./{dir}',
      9         dtype=torch.float32,

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image(self, *args, **kwargs)
    279             progressive_outputs=False
    280         )
--> 281         return next(image_stream)
    282 
    283 

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image_stream(self, *args, **kwargs)
    259     def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
    260         image_stream = self.generate_raw_image_stream(*args, **kwargs)
--> 261         for image in image_stream:
    262             image = image.to(torch.uint8).to('cpu').numpy()
    263             yield Image.fromarray(image)

/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_raw_image_stream(self, text, seed, grid_size, progressive_outputs, is_seamless, temperature, top_k, supercondition_factor, is_verbose)
    238             torch.cuda.empty_cache()
    239             with torch.cuda.amp.autocast(dtype=self.dtype):
--> 240                 image_tokens[:, i + 1], attention_state = self.decoder.sample_tokens(
    241                     settings=settings,
    242                     attention_mask=attention_mask,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in sample_tokens(self, settings, **kwargs)
    175 
    176     def sample_tokens(self, settings, **kwargs) -> Tuple[LongTensor, FloatTensor]:
--> 177         logits, attention_state = self.forward(**kwargs)
    178         image_count = logits.shape[0] // 2
    179         temperature = settings[[0]]

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, attention_mask, encoder_state, attention_state, prev_tokens, token_index)
    162         decoder_state = self.layernorm_embedding.forward(decoder_state)
    163         for i in range(self.layer_count):
--> 164             decoder_state, attention_state[i] = self.layers[i].forward(
    165                 decoder_state,
    166                 encoder_state,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, encoder_state, attention_state, attention_mask, token_index)
     88         residual = decoder_state
     89         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
---> 90         decoder_state, attention_state = self.self_attn.forward(
     91             decoder_state=decoder_state,
     92             attention_state=attention_state,

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, attention_state, attention_mask, token_index)
     43             values = attention_state[batch_count:]
     44 
---> 45         decoder_state = super().forward(keys, values, queries, attention_mask)
     46         return decoder_state, attention_state
     47 

/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_encoder.py in forward(self, keys, values, queries, attention_mask)
     47         queries /= queries.shape[-1] ** 0.5
     48         attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
---> 49         attention_weights: FloatTensor = torch.einsum(
     50             'bqhc,bkhc->bhqk',
     51             queries,

/usr/local/lib/python3.8/dist-packages/torch/functional.py in einsum(*args)
    376         # the path for contracting 0 or 1 time(s) is already optimized
    377         # or the user has disabled using opt_einsum
--> 378         return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    379 
    380     path = None

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 14.76 GiB total capacity; 13.67 GiB already allocated; 17.88 MiB free; 13.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
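The error message itself suggests setting max_split_size_mb via PYTORCH_CUDA_ALLOC_CONF. A minimal sketch of that setting is below; the value 128 is an arbitrary example, and this only addresses allocator fragmentation, so it is unlikely to help if two full copies of the model genuinely do not fit in the 14.76 GiB of GPU memory.

import os
# Must be set before the first CUDA allocation in the process,
# i.e. before the model is constructed or any tensor is moved to the GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"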