kuprel/min-dalle

Out of CUDA memory? How much GPU memory does it need?

York1996OutLook opened this issue · 30 comments

Out of CUDA memory? How much GPU memory does it need?

The mini version of the model takes about 2.4 GB of VRAM to generate one picture with is_reusable=False. So you should be able to run the model with the --no-mega --grid-size=1 parameters, like this:
python image_from_text.py --no-mega --grid-size=1 --text="Succubus from world of warcraft, fantasy art."
With --grid-size=3 it takes about 4.1 GB of VRAM.
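If you want to check peak usage on your own GPU, here is a rough sketch using PyTorch's memory stats (it assumes a CUDA device and the MinDalle / generate_image signatures shown elsewhere in this thread; it only counts tensor allocations, so the CUDA context adds some overhead on top):

import torch
from min_dalle import MinDalle

# Rough peak-VRAM check (sketch): reset the counter, generate once, read the peak.
torch.cuda.reset_peak_memory_stats()
model = MinDalle(is_mega=False, is_reusable=False)
image = model.generate_image("Succubus from world of warcraft, fantasy art.", -1, 1)
print("peak VRAM: {:.2f} GiB".format(torch.cuda.max_memory_allocated() / 2**30))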

The mega version of this model takes 7.6 GB of VRAM with is_reusable=False for one picture. It's possible to run it even with ~5 GB of VRAM (at the cost of some reduction in generation quality), but that requires some code modifications.

The mega version of this model takes 7.6 GB of VRAM with is_reusable=False for one picture. It's possible to run it even with ~5 GB of VRAM (at the cost of some reduction in generation quality), but that requires some code modifications.

Can you please tell me what modifications I need to make to the code to run the mega version?

I'm guessing converting everything to float16? Wonder how that would look

I'm guessing converting everything to float16?

Yes, you are right. You need to use float16 instead of float32 wherever possible. Fortunately, torch contains an autocast class that does most of the work for you. Only in a few places do you need to convert manually between float32 and float16.
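The general pattern, as a minimal self-contained sketch (a toy torch.nn.Linear standing in for the encoder/decoder; the exact places in min-dalle are listed below):

import torch
from torch.cuda.amp import autocast

# Sketch of the pattern: keep the module weights in float16 and wrap the forward
# pass in autocast so ops run in half precision where it is safe to do so.
model = torch.nn.Linear(64, 64).cuda().half().eval()
x = torch.randn(1, 64, device="cuda")

with autocast():
    y = model(x)
print(y.dtype)  # torch.float16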

List of modifications

Modifications in min_dalle.py

  1. Add an import of autocast on the line right after import torch:
    from torch.cuda.amp import autocast
  2. In the init_encoder function, switch the encoder to half precision right after it is created. To the end of this code:
self.encoder = DalleBartEncoder(
    attention_head_count = self.attention_head_count,
    embed_count = self.embed_count,
    glu_embed_count = self.glu_embed_count,
    text_token_count = self.text_token_count, 
    text_vocab_count = self.text_vocab_count,
    layer_count = self.layer_count
)

add the suffix .half().eval(), so the code should look like this:

self.encoder = DalleBartEncoder(
    attention_head_count = self.attention_head_count,
    embed_count = self.embed_count,
    glu_embed_count = self.glu_embed_count,
    text_token_count = self.text_token_count, 
    text_vocab_count = self.text_vocab_count,
    layer_count = self.layer_count
).half().eval()
  3. Do the same thing (adding .half().eval()) in the init_decoder function, so this part of the code should look like this:
self.decoder = DalleBartDecoder(
    sample_token_count = self.sample_token_count,
    image_token_count = self.image_token_count,
    image_vocab_count = self.image_vocab_count,
    attention_head_count = self.attention_head_count,
    embed_count = self.embed_count,
    glu_embed_count = self.glu_embed_count,
    layer_count = self.layer_count,
    start_token = self.image_vocab_count
).half().eval()
  4. In the generate_image_tokens function, you need to use autocast when running the models. So, change this line
    encoder_state = self.encoder.forward(text_tokens)
    to this:
with autocast():
    encoder_state = self.encoder.forward(text_tokens)

and this code:

image_tokens = self.decoder.forward(
    image_count,
    text_tokens, 
    encoder_state
)

to this:

with autocast():
    image_tokens = self.decoder.forward(
        image_count,
        text_tokens, 
        encoder_state
    )

Modifications in models/dalle_bart_encoder.py

  1. Just after
attention_bias = torch.where(
    attention_mask,
    self.one * 0,
    self.one * (-torch.inf),
)

add a manual conversion (you need the extra precision here to avoid NaNs; see the sketch after this list):

queries = queries.to(torch.float32)
keys = keys.to(torch.float32)
  2. After attention_weights = torch.softmax(attention_weights, -1) add these two lines:
attention_weights = attention_weights.to(torch.float32)
values = values.to(torch.float32)
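To see why the extra precision matters, here is a small standalone illustration (not code from the repo) of how attention scores overflow in float16:

import torch

# Attention scores are dot products over many elements. Sums like this easily exceed
# float16's maximum (~65504) and overflow to inf; an inf score then turns the softmax
# into NaNs, which is why queries/keys are cast back to float32 above.
q = torch.full((64,), 40.0)
k = torch.full((64,), 40.0)

score = (q * k).sum()   # 64 * 1600 = 102400
print(score)            # tensor(102400.) -- representable in float32
print(score.half())     # tensor(inf, dtype=torch.float16) -- overflows in float16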

Modifications in models/dalle_bart_decoder.py

  1. In the forward method of the DecoderSelfAttention class, change this starting code:
) -> Tuple[FloatTensor, FloatTensor]:
    keys = self.k_proj.forward(decoder_state)
    values = self.v_proj.forward(decoder_state)
    queries = self.q_proj.forward(decoder_state)

to this:

) -> Tuple[FloatTensor, FloatTensor]:
    attention_state = attention_state.to(torch.float32)
    keys = self.k_proj.forward(decoder_state).to(torch.float32)
    values = self.v_proj.forward(decoder_state).to(torch.float32)
    queries = self.q_proj.forward(decoder_state).to(torch.float32)

  2. Then, in the decode_step method, after these lines:

logits: FloatTensor = (
    logits[:image_count, -1] * (1 - a) + 
    logits[image_count:, -1] * a
)

add this:

logits = logits.to(torch.float32)

Conclusion

Be sure that you are importing your modified code, not the default unmodified min-dalle installed by pip. You can add something like print("This is my code") to your modified files to make sure of it.
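A quicker check is to print the path of the imported module, for example:

import min_dalle
# Should point at your modified copy, not the pip-installed package in site-packages.
print(min_dalle.__file__)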

Then, run your code like this:
python image_from_text.py --mega --grid-size=1 --text="Twilight Sparkle from MLP, fantasy art."

Half precision greatly reduces VRAM usage, and it also speeds up model loading and inference significantly. But it slightly reduces the quality of the generated images (this can be tuned with top_k or other filtering methods).

Thanks for this! Do you have some example generated images from float16?

Request "Skinny withered green dragon in the rocky wasteland. She has big knotty clawed paws. Knobby big feet. skin folds. fantasy art."
Model "mega". Half precision (float16):
half precision
Mega. Same request, full precision (float32):
Full precision

Wow they seem the same to me. I might update the model to use float16

Is it any faster with float16?

Wow they seem the same to me.

Not entirely the same. With the half-precision model, the chance of getting garbage or weird images is slightly higher. If you compare 5x5 grids for the same request, you can easily notice that the float32 model gives you more brilliant, consistent pictures in a set than the half model.

Is it any faster with float16?

Yes, it loads and generates significantly faster. I did not take exact measurements, but the speed-up is noticeable even on the mini model.

and this code:

image_tokens = self.decoder.forward(
    image_count,
    text_tokens, 
    encoder_state
)

to this:

with autocast():
    image_tokens = self.decoder.forward(
        image_count,
        text_tokens, 
        encoder_state
    )

I couldn't find these lines and I get an error "RuntimeError: expected scalar type Float but found Half"

Cool maybe I'll make it a flag then. Have you tried bfloat16 by chance?
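A hypothetical sketch of what such a flag could look like (the flag name and wiring are illustrative only, not the actual implementation):

import argparse
import torch
from min_dalle import MinDalle

# Illustrative --dtype flag: map a string choice to a torch dtype and pass it through.
parser = argparse.ArgumentParser()
parser.add_argument('--dtype', choices=['float32', 'float16', 'bfloat16'], default='float32')
args = parser.parse_args()

model = MinDalle(is_mega=True, is_reusable=True, dtype=getattr(torch, args.dtype))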

I couldn't find these lines and I get an error "RuntimeError: expected scalar type Float but found Half"

Try modifying commit 9bdba57 - that's the one I use.

Have you tried bfloat16 by chance?

No, I haven't.

I couldn't find these lines and I get an error "RuntimeError: expected scalar type Float but found Half"

Try modifying commit 9bdba57 - that's the one I use.

Thanks, man, you're a freaking wizard!

I made a float16 diff patch for the current latest commit 309cef6; you can apply it:

diff --git a/image_from_text.py b/image_from_text.py
index 7495bc9..652f5ce 100644
--- a/image_from_text.py
+++ b/image_from_text.py
@@ -1,8 +1,11 @@
 import argparse
 import os
 from PIL import Image
-from min_dalle import MinDalle
 
+import sys
+sys.path.insert(0, "./min_dalle")
+
+from min_dalle import MinDalle
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--mega', action='store_true')
@@ -13,6 +16,8 @@ parser.add_argument('--seed', type=int, default=-1)
 parser.add_argument('--grid-size', type=int, default=1)
 parser.add_argument('--image-path', type=str, default='generated')
 parser.add_argument('--models-root', type=str, default='pretrained')
+parser.add_argument('--token-count', type=int, default=256) # for debugging
+parser.add_argument('--gen-count', type=int, default=1)
 
 
 def ascii_from_image(image: Image.Image, size: int = 128) -> str:
@@ -23,11 +28,11 @@ def ascii_from_image(image: Image.Image, size: int = 128) -> str:
     return '\n'.join(''.join(row) for row in chars)
 
 
-def save_image(image: Image.Image, path: str):
+def save_image(image: Image.Image, path: str, num: int):
     if os.path.isdir(path):
-        path = os.path.join(path, 'generated.jpg')
-    elif not path.endswith('.jpg'):
-        path += '.jpg'
+        path = os.path.join(path, 'generated{0:03}.png'.format(num))
+    elif not path.endswith('.png'):
+        path += '{0:03}.png'.format(num)
     print("saving image to", path)
     image.save(path)
     return image
@@ -39,7 +44,9 @@ def generate_image(
     seed: int,
     grid_size: int,
     image_path: str,
-    models_root: str
+    models_root: str,
+    token_count: int,
+    gen_count: int
 ):
     model = MinDalle(
         is_mega=is_mega, 
@@ -48,9 +55,18 @@ def generate_image(
         is_verbose=True
     )
 
-    image = model.generate_image(text, seed, grid_size, is_verbose=True)
-    save_image(image, image_path)
-    print(ascii_from_image(image, size=128))
+    for i in range(gen_count):
+        if token_count < 256:
+            image_tokens = model.generate_image_tokens(text, seed, grid_size ** 2)
+            tokens_array = image_tokens.to('cpu').detach().numpy()
+            print('image tokens', tokens_array)
+            print('Number of img tokens=',len(tokens_array))
+        else:
+            image = model.generate_image(text, seed, grid_size)
+            #image = model.generate_image(text, i, grid_size)
+            save_image(image, image_path, i)
+            #print(ascii_from_image(image, size=128))
+            print("generated")
 
 
 if __name__ == '__main__':
@@ -62,5 +78,7 @@ if __name__ == '__main__':
         seed=args.seed,
         grid_size=args.grid_size,
         image_path=args.image_path,
-        models_root=args.models_root
-    )
\ No newline at end of file
+        models_root=args.models_root,
+        token_count=args.token_count,
+        gen_count = args.gen_count
+    )
diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py
index b0661d0..9eb2574 100644
--- a/min_dalle/min_dalle.py
+++ b/min_dalle/min_dalle.py
@@ -9,8 +9,10 @@ from typing import Iterator
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())
 
-from .text_tokenizer import TextTokenizer
-from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
+from torch.cuda.amp import autocast
+
+from text_tokenizer import TextTokenizer
+from models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
 
 MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
 
@@ -105,7 +107,7 @@ class MinDalle:
             text_token_count = self.text_token_count,
             text_vocab_count = self.text_vocab_count,
             layer_count = self.layer_count
-        )
+        ).half().eval()
         params = torch.load(self.encoder_params_path)
         self.encoder.load_state_dict(params, strict=False)
         del params
@@ -123,7 +125,7 @@ class MinDalle:
             glu_embed_count = self.glu_embed_count,
             layer_count = self.layer_count,
             start_token = self.image_vocab_count
-        )
+        ).half().eval()
         params = torch.load(self.decoder_params_path)
         self.decoder.load_state_dict(params, strict=False)
         del params
@@ -174,6 +176,7 @@ class MinDalle:
         tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
         if len(tokens) > self.text_token_count: 
             tokens = tokens[:self.text_token_count]
+        print("Number of text tokens=",len(tokens))
         if is_verbose: print("text tokens", tokens)
         text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
         text_tokens[0, :2] = [tokens[0], tokens[-1]]
@@ -183,35 +186,38 @@ class MinDalle:
         if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
 
         if not self.is_reusable: self.init_encoder()
-        if is_verbose: print("encoding text tokens")
-        encoder_state = self.encoder.forward(text_tokens)
+        if self.is_verbose: print("encoding text tokens")
+        with autocast():
+            encoder_state = self.encoder.forward(text_tokens)
         if not self.is_reusable: del self.encoder
         if torch.cuda.is_available(): torch.cuda.empty_cache()
 
         if not self.is_reusable: self.init_decoder()
 
-        encoder_state, attention_mask, attention_state, image_tokens = ( 
-            self.decoder.decode_initial(
-                seed, 
-                grid_size ** 2, 
-                text_tokens, 
-                encoder_state
+        with autocast():
+            encoder_state, attention_mask, attention_state, image_tokens = ( 
+                self.decoder.decode_initial(
+                    seed, 
+                    grid_size ** 2, 
+                    text_tokens, 
+                    encoder_state
+                )
             )
-        )
 
         row_count = 16
         for row_index in range(row_count):
             if is_verbose: 
                 print('sampling row {} of {}'.format(row_index + 1, row_count))
-            attention_state, image_tokens = self.decoder.decode_row(
-                row_index,
-                log2_k,
-                log2_supercondition_factor,
-                encoder_state,
-                attention_mask,
-                attention_state,
-                image_tokens
-            )
+            with autocast():
+                attention_state, image_tokens = self.decoder.decode_row(
+                    row_index,
+                    log2_k,
+                    log2_supercondition_factor,
+                    encoder_state,
+                    attention_mask,
+                    attention_state,
+                    image_tokens
+                )
             if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
                 tokens = image_tokens[:, 1:]
                 image = self.image_from_tokens(grid_size, tokens, is_verbose)
@@ -237,4 +243,4 @@ class MinDalle:
             log2_supercondition_factor,
             is_verbose
         )
-        return next(image_stream)
\ No newline at end of file
+        return next(image_stream)
diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py
index 393618a..0a7bd4b 100644
--- a/min_dalle/models/dalle_bart_decoder.py
+++ b/min_dalle/models/dalle_bart_decoder.py
@@ -35,8 +35,8 @@ class DecoderSelfAttention(AttentionBase):
         attention_state: FloatTensor,
         token_index: LongTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
-        keys = self.k_proj.forward(decoder_state)
-        values = self.v_proj.forward(decoder_state)
+        keys = self.k_proj.forward(decoder_state).to(torch.float32)
+        values = self.v_proj.forward(decoder_state).to(torch.float32)
         queries = self.q_proj.forward(decoder_state)
         attn_mask = self.token_indices < token_index + 1
         attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
@@ -170,6 +170,7 @@ class DalleBartDecoder(nn.Module):
             logits[:image_count, -1] * (1 - a) + 
             logits[image_count:, -1] * a
         )
+        logits = logits.to(torch.float32)
 
         top_logits, _ = logits.topk(2 ** log2_k, dim=-1)
         probs = torch.where(
@@ -238,4 +239,4 @@ class DalleBartDecoder(nn.Module):
 
         if seed > 0: torch.manual_seed(seed)
 
-        return encoder_state, attention_mask, attention_state, image_tokens_sequence
\ No newline at end of file
+        return encoder_state, attention_mask, attention_state, image_tokens_sequence
diff --git a/min_dalle/models/dalle_bart_encoder.py b/min_dalle/models/dalle_bart_encoder.py
index a96cd6b..fdb3c82 100644
--- a/min_dalle/models/dalle_bart_encoder.py
+++ b/min_dalle/models/dalle_bart_encoder.py
@@ -54,6 +54,8 @@ class AttentionBase(nn.Module):
             self.one * 0,
             self.one * (-torch.inf),
         )
+        queries = queries.to(torch.float32)
+        keys = keys.to(torch.float32)
         attention_weights: FloatTensor = torch.einsum(
             'bqhc,bkhc->bhqk',
             queries, 
@@ -61,6 +63,8 @@ class AttentionBase(nn.Module):
         )
         attention_weights += attention_bias[:, None, None, :]
         attention_weights = torch.softmax(attention_weights, -1)
+        attention_weights = attention_weights.to(torch.float32)
+        values = values.to(torch.float32)
         attention_output: FloatTensor = torch.einsum(
             "bhqk,bkhc->bqhc",
             attention_weights, 
@@ -146,4 +150,4 @@ class DalleBartEncoder(nn.Module):
         for layer in self.layers:
             encoder_state = layer.forward(encoder_state, attention_mask)
         encoder_state = self.final_ln.forward(encoder_state)
-        return encoder_state
\ No newline at end of file
+        return encoder_state

Interestingly, it may or may not be a coincidence, but after switching to fp16 even the no-mega version has produced better results for me.

Yes, I noticed that the mini version, unlike the mega one, gives more pleasant results in fp16. I thought it was my imagination, because it should be worse, not better. But strangely, it really does seem that the mini version "likes" low precision.

@iScriptLex I added support for bfloat16 and float16. For float16 though some of the images are black. Did you encounter this problem at all? I tried casting all the things you pointed out to float32 but that didn't solve it.

Also, both bfloat16 and float16 are a little slower than float32 when I test the mega model on an A100.

No, I have never experienced anything like this. Could you provide the running options that reproduce the problem?

In the colab, the only change would be:

model = MinDalle(is_mega=True, is_reusable=True, dtype=torch.float16)

It seems that you converted your detokenizer (VQGAN) too. That model doesn't work with half precision; it should stay in float32 to work well (in my code I keep the VQGAN as float32). Only bart_encoder and bart_decoder should be converted.
Maybe I should make a pull request with my float16 code.
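For reference, a rough sketch of the split (attribute names assumed to match min_dalle.py; only the two BART modules go to half precision):

import torch

# Sketch: convert only the BART encoder/decoder to float16 and keep the VQGAN
# detokenizer in float32. The image tokens passed between them are integer tensors,
# so no extra dtype conversion is needed at that boundary.
def convert_to_half(model):
    model.encoder = model.encoder.half().eval()            # DalleBartEncoder -> float16
    model.decoder = model.decoder.half().eval()            # DalleBartDecoder -> float16
    model.detokenizer = model.detokenizer.float().eval()   # VQGanDetokenizer stays float32
    return model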

Ah that makes sense, I'll try it

That solved it, thanks!!

any idea why it might be slower than float32 though?

any idea why it might be slower than float32 though?

Because that's an A100. Try a V100 or (even better) a desktop GPU like a GeForce 1080.
The A100 has very fast fp32 tensor cores, and its fp16 arithmetic is optimized for a very specific memory model. So you should not just convert the model to float16, but also change how tensors are packed in memory to get an fp16 speed-up on convolution layers with the A100.
I don't think it's worth optimizing your code for specific GPU architectures...
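For what it's worth, a hedged sketch of the kind of memory-layout change meant here, shown on a toy convolution rather than min-dalle itself (whether it helps depends on the GPU and the layers involved; requires a CUDA device):

import torch

# NHWC ("channels_last") memory format usually lets fp16 convolutions use the
# tensor cores more efficiently on Ampere GPUs.
conv = torch.nn.Conv2d(3, 64, 3).cuda().half().to(memory_format=torch.channels_last)
x = torch.randn(1, 3, 256, 256, device="cuda").half().to(memory_format=torch.channels_last)
y = conv(x)
print(y.shape)  # torch.Size([1, 64, 254, 254])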

float16 is slower on a P100 too. Would you expect that?

Would you expect that?

Probably.

Ok cool

How about a uint8 model? Would that take even less memory?