thuanz123/enhancing-transformers

OOM for imagenet_gpt_vitvq_base and a 100M-parameter GPT on an A100 40GB

BlinkDL opened this issue · 3 comments

model:
    target: enhancing.modules.stage2.transformer.CondTransformer
    params:
        cond_key: class
        cond: 
            target: enhancing.modules.cond.dummycond.ClassCond
            params:
                image_size: 256
                class_name: assets/class/imagenet.txt
        stage1:
            target: enhancing.modules.stage1.vitvqgan.ViTVQ
            params:
                image_key: image
                path: weight/imagenet_vitvq_base.ckpt
                image_size: 256
                patch_size: 8
                encoder:
                    dim: 768
                    depth: 12
                    heads: 12
                    mlp_dim: 3072
                decoder:
                    dim: 768
                    depth: 12
                    heads: 12
                    mlp_dim: 3072
                quantizer:
                    embed_dim: 32
                    n_embed: 8192
                loss:
                    target: enhancing.losses.vqperceptual.DummyLoss
        transformer:
            target: enhancing.modules.stage2.layers.GPT
            params:
                vocab_cond_size: 1000
                vocab_img_size: 8192
                embed_dim: 768
                cond_num_tokens: 1
                img_num_tokens: 1024 
                n_heads: 12
                n_layers: 12
                
dataset:
    target: enhancing.dataloader.DataModuleFromConfig
    params:
        batch_size: 1
        num_workers: 4
        train:
            target: enhancing.dataloader.imagenet.ImageNetTrain
            params:
                root: /fsx/ilsvrc2012
                resolution: 256
                
        validation:
            target: enhancing.dataloader.imagenet.ImageNetValidation
            params:
                root: /fsx/ilsvrc2012
                resolution: 256
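
For scale, here is a rough, purely illustrative estimate of the attention-score tensor that the softmax materializes with this config. It assumes fp32 scores and a full context of 1 conditioning token plus img_num_tokens = 1024 image tokens, and it ignores KV caches and every other activation:

# Back-of-the-envelope size of the (batch, n_heads, T, T) attention scores
# for the GPT above (n_heads = 12, T = 1 cond token + 1024 image tokens).
# Illustrative only; the real footprint also includes KV caches, MLP
# activations and the ViT-VQ decoder.
def attn_score_bytes(batch_size: int,
                     n_heads: int = 12,
                     seq_len: int = 1 + 1024,
                     bytes_per_el: int = 4) -> int:  # fp32
    return batch_size * n_heads * seq_len * seq_len * bytes_per_el

for b in (1, 16, 128):
    print(f"batch {b:>3}: ~{attn_score_bytes(b) / 2**30:.2f} GiB per attention layer")

Per training sample this is small, so the 11.21 GiB request in the traceback below points at the sampling path rather than the training forward pass.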

# Training flags parsed in main.py
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, required=True)
parser.add_argument('-s', '--seed', type=int, default=0)
parser.add_argument('-nn', '--num_nodes', type=int, default=1)
parser.add_argument('-ng', '--num_gpus', type=int, default=1)
parser.add_argument('-u', '--update_every', type=int, default=1)
parser.add_argument('-e', '--epochs', type=int, default=100)
parser.add_argument('-lr', '--base_lr', type=float, default=4.5e-4)
parser.add_argument('-a', '--use_amp', default=False, action='store_true')
parser.add_argument('-b', '--batch_frequency', type=int, default=750)
parser.add_argument('-m', '--max_images', type=int, default=4)
args = parser.parse_args()
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/fsx/BlinkDL/CODE/enhancing-transformers/main.py", line 61, in <module>
    trainer.fit(model, data)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
    results = self._run_stage()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
    return self._run_train()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1345, in _run_train
    self._run_sanity_check()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1413, in _run_sanity_check
    val_loop.run()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 155, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 134, in advance
    self._on_evaluation_batch_end(output, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 267, in _on_evaluation_batch_end
    self.trainer._call_callback_hooks(hook_name, output, *kwargs.values())
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1636, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/utils/callback.py", line 141, in on_validation_batch_end
    self.log_img(pl_module, batch, batch_idx, split="val")
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/utils/callback.py", line 108, in log_img
    images = pl_module.log_images(batch, split=split, pl_module=pl_module)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/transformer.py", line 204, in log_images
    log["first samples"] = self.sample(cond_codes, return_pixels=True)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/transformer.py", line 87, in sample
    logits, codes = self.transformer.sample(conds=conds, top_k=top_k, top_p=top_p,
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 214, in sample
    logits_, presents = self.sample_step(codes_, conds, pos_code, use_fp16, past)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 278, in sample_step
    x, present = block.sample(x, layer_past= past[i])
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 122, in sample
    attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 70, in forward
    att = F.softmax(att, dim=-1)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/torch/nn/functional.py", line 1834, in softmax
    ret = input.softmax(dim)
RuntimeError: CUDA out of memory. Tried to allocate 11.21 GiB (GPU 0; 39.59 GiB total capacity; 24.12 GiB already allocated; 6.96 GiB free; 30.91 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
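
The allocator hint at the end of the message refers to the PYTORCH_CUDA_ALLOC_CONF environment variable. Setting it is unlikely to rescue a single 11 GiB request, but it is cheap to try against fragmentation; the 128 MB value below is just an example:

# Cap the caching allocator's split size to reduce fragmentation.
# Set this before any CUDA work (simplest: before importing torch),
# or export PYTORCH_CUDA_ALLOC_CONF in the shell instead.
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

import torch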

Hi @BlinkDL, it looks like the sampling code is so inefficient that even an A100 cannot run it 🤣 I will investigate this

Hi @BlinkDL, I think the code is fine for training, but the sampling code has a memory-leak problem. For now, I think you can make this function return an empty dict {} to skip sampling during the training phase
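
A minimal sketch of that workaround, assuming the function meant here is the log_images method of CondTransformer in enhancing/modules/stage2/transformer.py (the one in the traceback); the exact signature may differ:

# Hypothetical stopgap inside the stage-2 model: return no images, so the
# logging callback in enhancing/utils/callback.py has nothing to render
# and sample() is never called during training.
def log_images(self, batch, **kwargs):
    return {}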

The issue seems to be fixed, as I can now train a very large GPT and generate images without getting OOM, so I will close this for now