OOM for imagenet_gpt_vitvq_base and a 100M-param GPT on an A100 40GB
BlinkDL opened this issue · 3 comments
BlinkDL commented
Config:

```yaml
model:
  target: enhancing.modules.stage2.transformer.CondTransformer
  params:
    cond_key: class
    cond:
      target: enhancing.modules.cond.dummycond.ClassCond
      params:
        image_size: 256
        class_name: assets/class/imagenet.txt
    stage1:
      target: enhancing.modules.stage1.vitvqgan.ViTVQ
      params:
        image_key: image
        path: weight/imagenet_vitvq_base.ckpt
        image_size: 256
        patch_size: 8
        encoder:
          dim: 768
          depth: 12
          heads: 12
          mlp_dim: 3072
        decoder:
          dim: 768
          depth: 12
          heads: 12
          mlp_dim: 3072
        quantizer:
          embed_dim: 32
          n_embed: 8192
        loss:
          target: enhancing.losses.vqperceptual.DummyLoss
    transformer:
      target: enhancing.modules.stage2.layers.GPT
      params:
        vocab_cond_size: 1000
        vocab_img_size: 8192
        embed_dim: 768
        cond_num_tokens: 1
        img_num_tokens: 1024
        n_heads: 12
        n_layers: 12

dataset:
  target: enhancing.dataloader.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 4
    train:
      target: enhancing.dataloader.imagenet.ImageNetTrain
      params:
        root: /fsx/ilsvrc2012
        resolution: 256
    validation:
      target: enhancing.dataloader.imagenet.ImageNetValidation
      params:
        root: /fsx/ilsvrc2012
        resolution: 256
```
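For context on the title's numbers: with patch size 8 on 256×256 images the stage-1 tokenizer emits 32×32 = 1024 codes, and the GPT config above works out to roughly 100M parameters. A quick sanity check (a sketch assuming a standard GPT block with fused QKV and a 4x MLP; the exact layer shapes in `enhancing.modules.stage2.layers.GPT` may differ):

```python
# Back-of-the-envelope check of the config above.
image_size, patch_size = 256, 8
img_tokens = (image_size // patch_size) ** 2   # 32 * 32 = 1024, matches img_num_tokens

d, n_layers = 768, 12
per_block = 4 * d * d + 8 * d * d              # attention (QKV + proj) + MLP (in + out)
embeddings = (8192 + 1000) * d + (img_tokens + 1) * d  # token + positional embeddings
head = d * 8192                                # output projection to the image vocab
total = n_layers * per_block + embeddings + head
print(f"~{total / 1e6:.0f}M parameters")       # ≈ 99M, i.e. the "100M params GPT"
```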
Launch arguments (from `main.py`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, required=True)
parser.add_argument('-s', '--seed', type=int, default=0)
parser.add_argument('-nn', '--num_nodes', type=int, default=1)
parser.add_argument('-ng', '--num_gpus', type=int, default=1)
parser.add_argument('-u', '--update_every', type=int, default=1)
parser.add_argument('-e', '--epochs', type=int, default=100)
parser.add_argument('-lr', '--base_lr', type=float, default=4.5e-4)
parser.add_argument('-a', '--use_amp', default=False, action='store_true')
parser.add_argument('-b', '--batch_frequency', type=int, default=750)
parser.add_argument('-m', '--max_images', type=int, default=4)
args = parser.parse_args()
```
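A hedged sketch of how these flags presumably feed the Lightning Trainer (PL 1.6-era API, judging by the traceback below); the actual wiring in `main.py` may differ:

```python
# Sketch only: assumed mapping of the CLI flags above onto the Trainer.
import pytorch_lightning as pl

trainer = pl.Trainer(
    num_nodes=args.num_nodes,
    gpus=args.num_gpus,                        # per-node GPU count
    max_epochs=args.epochs,
    accumulate_grad_batches=args.update_every,
    precision=16 if args.use_amp else 32,      # AMP via -a/--use_amp
)
```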
Crash log:

```
Sanity Checking DataLoader 0:   0%| | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/fsx/BlinkDL/CODE/enhancing-transformers/main.py", line 61, in <module>
    trainer.fit(model, data)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
    results = self._run_stage()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
    return self._run_train()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1345, in _run_train
    self._run_sanity_check()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1413, in _run_sanity_check
    val_loop.run()
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 155, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 134, in advance
    self._on_evaluation_batch_end(output, **kwargs)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 267, in _on_evaluation_batch_end
    self.trainer._call_callback_hooks(hook_name, output, *kwargs.values())
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1636, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/utils/callback.py", line 141, in on_validation_batch_end
    self.log_img(pl_module, batch, batch_idx, split="val")
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/utils/callback.py", line 108, in log_img
    images = pl_module.log_images(batch, split=split, pl_module=pl_module)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/transformer.py", line 204, in log_images
    log["first samples"] = self.sample(cond_codes, return_pixels=True)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/transformer.py", line 87, in sample
    logits, codes = self.transformer.sample(conds=conds, top_k=top_k, top_p=top_p,
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 214, in sample
    logits_, presents = self.sample_step(codes_, conds, pos_code, use_fp16, past)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 278, in sample_step
    x, present = block.sample(x, layer_past=past[i])
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 122, in sample
    attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/BlinkDL/CODE/enhancing-transformers/enhancing/modules/stage2/layers.py", line 70, in forward
    att = F.softmax(att, dim=-1)
  File "/fsx/BlinkDL/conda/lib/python3.9/site-packages/torch/nn/functional.py", line 1834, in softmax
    ret = input.softmax(dim)
RuntimeError: CUDA out of memory. Tried to allocate 11.21 GiB (GPU 0; 39.59 GiB total capacity; 24.12 GiB already allocated; 6.96 GiB free; 30.91 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
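As a stopgap, the allocator hint mentioned in the error message itself can be set before the first CUDA allocation. This only mitigates fragmentation; it cannot help if the ~11 GiB softmax buffer is genuinely required:

```python
# Set the allocator hint suggested by the error message before any
# CUDA memory is allocated (e.g. at the very top of main.py).
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"  # 128 is a common starting value

import torch  # import after setting the env var, before any CUDA work
```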
thuanz123 commented
Hi @BlinkDL, looks like the code for sampling is so inefficient that even an A100 cannot run it 🤣 I will investigate this
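For a sense of scale: if sampling materializes a full T×T attention matrix per head for every image drawn in parallel, the softmax buffer alone grows quadratically in sequence length. A rough estimate (the parallel-sample count here is hypothetical; the real batch shape inside `log_images` may differ):

```python
# Rough fp32 memory estimate for one materialized attention map during
# sampling. n_samples is hypothetical; the actual batch shape used by
# log_images / sample may differ.
n_heads = 12
seq_len = 1 + 1024             # class token + 1024 image tokens
n_samples = 256                # hypothetical images sampled in parallel
bytes_fp32 = 4

att_bytes = n_samples * n_heads * seq_len * seq_len * bytes_fp32
print(f"{att_bytes / 2**30:.1f} GiB")   # ~12 GiB, the same order as the OOM above
```

With a working key/value cache (the `layer_past`/`presents` plumbing visible in the traceback), each step should only need a 1×T attention row per head, which is presumably where the fix lies.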
thuanz123 commented
The issue seems to be fixed: I can now train a very big GPT and generate images without getting OOM, so I will close this for now