GPU OOM if model is run in Python multithreading
xcharleslin opened this issue · 0 comments
xcharleslin commented
Minimum repro:
import torch
from min_dalle import MinDalle
from concurrent.futures import ThreadPoolExecutor
USE_GPU = True
def f(text: str, root: str):
return MinDalle(
models_root=f'./{root}',
dtype=torch.float32,
device="cuda",
is_mega=False,
is_reusable=True,
).generate_image(
text,
seed=-1,
grid_size=1,
is_seamless=False,
temperature=1,
top_k=256,
supercondition_factor=32,
)
# Without threading, this works
f("hello", "root1")
# With threading, this does not work
tpe = ThreadPoolExecutor()
tpe.submit(f, "hello2", "root2").result() # GPU OOMs here
The last line fails with OutOfMemoryError: CUDA out of memory.
(click for full stack trace)
using device cuda
downloading tokenizer params
intializing TextTokenizer
downloading encoder params
initializing DalleBartEncoder
downloading decoder params
initializing DalleBartDecoder
downloading detokenizer params
initializing VQGanDetokenizer
---------------------------------------------------------------------------
OutOfMemoryError Traceback (most recent call last)
<ipython-input-11-7b2fe8b33ac0> in <module>
5 fut = tpe.submit(f, "abc", "def")
6
----> 7 fut.result()
12 frames
/usr/lib/python3.8/concurrent/futures/_base.py in result(self, timeout)
442 raise CancelledError()
443 elif self._state == FINISHED:
--> 444 return self.__get_result()
445 else:
446 raise TimeoutError()
/usr/lib/python3.8/concurrent/futures/_base.py in __get_result(self)
387 if self._exception:
388 try:
--> 389 raise self._exception
390 finally:
391 # Break a reference cycle with the exception in self._exception
/usr/lib/python3.8/concurrent/futures/thread.py in run(self)
55
56 try:
---> 57 result = self.fn(*self.args, **self.kwargs)
58 except BaseException as exc:
59 self.future.set_exception(exc)
<ipython-input-9-7e6467e7527a> in f(text, dir)
5 USE_GPU = True
6 def f(text: str, dir: str) -> PIL.Image.Image:
----> 7 return MinDalle(
8 models_root=f'./{dir}',
9 dtype=torch.float32,
/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image(self, *args, **kwargs)
279 progressive_outputs=False
280 )
--> 281 return next(image_stream)
282
283
/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_image_stream(self, *args, **kwargs)
259 def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
260 image_stream = self.generate_raw_image_stream(*args, **kwargs)
--> 261 for image in image_stream:
262 image = image.to(torch.uint8).to('cpu').numpy()
263 yield Image.fromarray(image)
/usr/local/lib/python3.8/dist-packages/min_dalle/min_dalle.py in generate_raw_image_stream(self, text, seed, grid_size, progressive_outputs, is_seamless, temperature, top_k, supercondition_factor, is_verbose)
238 torch.cuda.empty_cache()
239 with torch.cuda.amp.autocast(dtype=self.dtype):
--> 240 image_tokens[:, i + 1], attention_state = self.decoder.sample_tokens(
241 settings=settings,
242 attention_mask=attention_mask,
/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in sample_tokens(self, settings, **kwargs)
175
176 def sample_tokens(self, settings, **kwargs) -> Tuple[LongTensor, FloatTensor]:
--> 177 logits, attention_state = self.forward(**kwargs)
178 image_count = logits.shape[0] // 2
179 temperature = settings[[0]]
/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, attention_mask, encoder_state, attention_state, prev_tokens, token_index)
162 decoder_state = self.layernorm_embedding.forward(decoder_state)
163 for i in range(self.layer_count):
--> 164 decoder_state, attention_state[i] = self.layers[i].forward(
165 decoder_state,
166 encoder_state,
/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, encoder_state, attention_state, attention_mask, token_index)
88 residual = decoder_state
89 decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
---> 90 decoder_state, attention_state = self.self_attn.forward(
91 decoder_state=decoder_state,
92 attention_state=attention_state,
/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_decoder.py in forward(self, decoder_state, attention_state, attention_mask, token_index)
43 values = attention_state[batch_count:]
44
---> 45 decoder_state = super().forward(keys, values, queries, attention_mask)
46 return decoder_state, attention_state
47
/usr/local/lib/python3.8/dist-packages/min_dalle/models/dalle_bart_encoder.py in forward(self, keys, values, queries, attention_mask)
47 queries /= queries.shape[-1] ** 0.5
48 attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
---> 49 attention_weights: FloatTensor = torch.einsum(
50 'bqhc,bkhc->bhqk',
51 queries,
/usr/local/lib/python3.8/dist-packages/torch/functional.py in einsum(*args)
376 # the path for contracting 0 or 1 time(s) is already optimized
377 # or the user has disabled using opt_einsum
--> 378 return _VF.einsum(equation, operands) # type: ignore[attr-defined]
379
380 path = None
OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 14.76 GiB total capacity; 13.67 GiB already allocated; 17.88 MiB free; 13.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF