facebookresearch/ijepa

is batch size per gpu important to re-implement the accuracy reported in the paper?

Gus-Guo opened this issue · 4 comments

Hi, I recently ran the ViT-H/14 ep300 config on 16 A100 GPUs. But since the GPUs would run out of memory in the middle of training, I decreased the batch size from 128 to 112 per GPU. However, I obtained a lower linear-probe accuracy (about -2%). Is it important to preserve the batch size of 128 per GPU? Thank you very much!

Hi @Gus-Guo, batch size is not important for the method, since there is no batch component in the loss. That said, changing the batch size may require you to change other parameters accordingly, such as the learning rate or the EMA range. Running out of memory in the middle of training is not great, though. Would you mind sharing some logs, including:

  • std-out
  • config

Thank you
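
For what it's worth, dropping the per-GPU batch from 128 to 112 on 16 GPUs shrinks the effective global batch from 16 × 128 = 2048 to 16 × 112 = 1792 (about 12.5% smaller). Below is a minimal sketch of the common linear-scaling heuristic for adjusting the learning rate to a new batch size; the 0.001 reference value is the `lr` from the config further down, and whether any rescaling is actually needed here is an open question, as noted above.

```python
# Common linear-scaling heuristic: scale the learning rate with global batch size.
# This is a general convention, not something the I-JEPA repo prescribes.
ref_global_batch = 16 * 128   # paper setup: 16 GPUs x 128 per GPU = 2048
new_global_batch = 16 * 112   # modified setup: 16 GPUs x 112 per GPU = 1792

ref_lr = 1.0e-3               # 'lr' from the config below
scaled_lr = ref_lr * new_global_batch / ref_global_batch
print(f"scaled lr: {scaled_lr:.2e}")  # -> scaled lr: 8.75e-04
```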

Mido, thank you very much for your reply. My training config is as follows:


data:
  batch_size: 112  # bs is 128 in the paper
  color_jitter_strength: 0.0
  crop_scale:
  - 0.3
  - 1.0
  crop_size: 224
  image_folder: imagenet_full_size/061417/
  num_workers: 2
  pin_mem: true
  use_color_distortion: false
  use_gaussian_blur: false
  use_horizontal_flip: false
  use_fake_data: false
  hdfs_data_path: /workspace/Dataset/imagenet1k/train
  train_dataset_size: 1281167
logging:
  folder: experiments/ijepa/in1k_vith14_epo300_bs112_16gpu_a100_amp_fp16
  write_tag: jepa
mask:
  allow_overlap: false
  aspect_ratio:
  - 0.75
  - 1.5
  enc_mask_scale:
  - 0.85
  - 1.0
  min_keep: 10
  num_enc_masks: 1
  num_pred_masks: 4
  patch_size: 14
  pred_mask_scale:
  - 0.15
  - 0.2
meta:
  copy_data: false
  load_checkpoint: false
  model_name: vit_huge
  pred_depth: 12
  pred_emb_dim: 384
  read_checkpoint: null
  use_bfloat16: true
optimization:
  ema:
  - 0.996
  - 1.0
  epochs: 300
  final_lr: 1.0e-06
  final_weight_decay: 0.4
  ipe_scale: 1.0
  lr: 0.001
  start_lr: 0.0002
  warmup: 40
  weight_decay: 0.04
  num_accumulate_iter: 1
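
As a side note on the `ema` range above, here is a minimal sketch of how a target-encoder momentum schedule can be built from it, assuming a linear ramp from `ema[0]` to `ema[1]` over all training updates; the iterations-per-epoch value below is an estimate for this 16 × 112 setup, not something taken from the logs.

```python
# Sketch (assumption): target-encoder EMA momentum ramps linearly from
# ema[0] to ema[1] over the whole run, scaled by ipe_scale.
ema = (0.996, 1.0)                  # 'ema' from the config above
epochs = 300                        # 'epochs' from the config above
ipe_scale = 1.0                     # 'ipe_scale' from the config above
ipe = 1281167 // (16 * 112)         # rough iterations per epoch at global batch 1792

total_updates = int(ipe * epochs * ipe_scale)
momentum_schedule = (
    ema[0] + step * (ema[1] - ema[0]) / total_updates
    for step in range(total_updates + 1)
)
print(next(momentum_schedule))      # first momentum value: 0.996
```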

Here is the part of my training log where it ran out of memory:

INFO:root:rank [0],[Mon Jul 3 14:54:28 2023]: [2, 190] loss: 0.104 masks: 68.8 44.1 [wd: 4.00e-02] [lr: 2.26e-04] [mem: 7.07e+04] (5109.5 ms)
INFO:root:[2, 190] grad_stats: [6.88e-02 2.97e-04] (2.59e-04, 9.89e-02)
INFO:root:rank [0],[Mon Jul 3 14:55:18 2023]: [2, 200] loss: 0.103 masks: 68.7 44.2 [wd: 4.00e-02] [lr: 2.26e-04] [mem: 7.07e+04] (5105.9 ms)
INFO:root:[2, 200] grad_stats: [4.58e-02 1.88e-04] (1.67e-04, 7.15e-02)
INFO:root:rank [0],[Mon Jul 3 14:56:08 2023]: [2, 210] loss: 0.102 masks: 68.6 44.2 [wd: 4.00e-02] [lr: 2.27e-04] [mem: 7.07e+04] (5097.9 ms)
INFO:root:[2, 210] grad_stats: [1.13e-02 7.24e-05] (6.50e-05, 1.72e-02)
INFO:root:rank [0],[Mon Jul 3 14:56:59 2023]: [2, 220] loss: 0.101 masks: 68.7 44.2 [wd: 4.00e-02] [lr: 2.27e-04] [mem: 7.07e+04] (5096.2 ms)
INFO:root:[2, 220] grad_stats: [2.06e-02 1.49e-04] (1.11e-04, 3.02e-02)
INFO:root:rank [0],[Mon Jul 3 14:57:50 2023]: [2, 230] loss: 0.100 masks: 68.6 44.1 [wd: 4.00e-02] [lr: 2.27e-04] [mem: 7.07e+04] (5100.3 ms)
INFO:root:[2, 230] grad_stats: [1.29e-01 5.10e-04] (4.46e-04, 1.94e-01)
INFO:root:rank [0],[Mon Jul 3 14:58:38 2023]: [2, 240] loss: 0.099 masks: 68.6 44.1 [wd: 4.00e-02] [lr: 2.28e-04] [mem: 7.07e+04] (5086.8 ms)
INFO:root:[2, 240] grad_stats: [5.05e-02 2.08e-04] (1.88e-04, 6.79e-02)
INFO:root:rank [0],[Mon Jul 3 14:59:30 2023]: [2, 250] loss: 0.099 masks: 68.6 44.2 [wd: 4.00e-02] [lr: 2.28e-04] [mem: 7.07e+04] (5090.7 ms)
INFO:root:[2, 250] grad_stats: [4.58e-02 2.69e-04] (1.77e-04, 8.33e-02)
INFO:root:rank [0],[Mon Jul 3 15:00:16 2023]: [2, 260] loss: 0.098 masks: 68.5 44.2 [wd: 4.00e-02] [lr: 2.28e-04] [mem: 7.07e+04] (5071.0 ms)
INFO:root:[2, 260] grad_stats: [2.65e-02 2.61e-04] (1.72e-04, 4.26e-02)
INFO:root:rank [0],[Mon Jul 3 15:01:05 2023]: [2, 270] loss: 0.097 masks: 68.5 44.2 [wd: 4.00e-02] [lr: 2.29e-04] [mem: 7.07e+04] (5062.7 ms)
INFO:root:[2, 270] grad_stats: [2.16e-02 1.40e-04] (1.14e-04, 3.57e-02)
INFO:root:rank [0],[Mon Jul 3 15:01:56 2023]: [2, 280] loss: 0.096 masks: 68.4 44.3 [wd: 4.00e-02] [lr: 2.29e-04] [mem: 7.07e+04] (5064.2 ms)
INFO:root:[2, 280] grad_stats: [2.36e-02 2.30e-04] (1.44e-04, 2.88e-02)
INFO:root:rank [0],[Mon Jul 3 15:02:51 2023]: [2, 290] loss: 0.095 masks: 68.3 44.3 [wd: 4.00e-02] [lr: 2.29e-04] [mem: 7.07e+04] (5080.0 ms)
INFO:root:[2, 290] grad_stats: [1.22e-02 1.14e-04] (9.00e-05, 2.19e-02)

File "/opt/tiger/test_merlin_demo/src/train_iter.py", line 419, in forward_context
z = predictor(z, masks_enc, masks_pred)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/distributed.py", line 892, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/tiger/test_merlin_demo/src/models/vision_transformer.py", line 319, in forward
x = blk(x)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/tiger/test_merlin_demo/src/models/vision_transformer.py", line 170, in forward
x = x + self.drop_path(self.mlp(self.norm2(x)))
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/tiger/test_merlin_demo/src/models/vision_transformer.py", line 118, in forward
x = self.fc1(x)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py", line 103, in forward
return F.linear(input, self.weight, self.bias)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 1848, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: CUDA out of memory. Tried to allocate 204.00 MiB (GPU 7; 79.35 GiB total capacity; 70.82 GiB already allocated; 180.19 MiB free; 76.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
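
The last line of the traceback already hints at one mitigation: reserved memory (76.69 GiB) is well above allocated memory (70.82 GiB), which suggests fragmentation. A small sketch of setting `max_split_size_mb` via `PYTORCH_CUDA_ALLOC_CONF` follows; the 128 MiB split size is an arbitrary starting point, and this only helps if fragmentation is indeed the problem.

```python
# Must run before the first CUDA allocation, so put it at the very top of the
# launch script (or export the variable in the shell before starting training).
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

import torch  # the caching allocator picks up the setting on first CUDA use
```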

It's surprising that this is happening when you're already more than an epoch into training. Not sure if there are other processes running on your GPU, but could you try changing enc_mask_scale to the range (0.7, 0.8)? This change should result in smaller context blocks, which use less memory. I'm somewhat hopeful that this won't require any hyperparameter changes.
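
To make the memory argument concrete (back-of-the-envelope only, ignoring the pruning of context patches that overlap the prediction targets): at 224 px with 14 px patches each image has 16 × 16 = 256 patches, and `enc_mask_scale` controls roughly what fraction of them the context encoder sees, so narrowing the range shortens the sequence.

```python
# Rough context-block sizes for the two enc_mask_scale ranges; transformer
# memory grows with sequence length, so fewer kept patches means less memory.
patches = (224 // 14) ** 2          # 16 x 16 = 256 patches per image

for lo, hi in [(0.85, 1.0), (0.70, 0.80)]:
    print(f"enc_mask_scale ({lo}, {hi}): "
          f"~{int(lo * patches)}-{int(hi * patches)} context patches")
# enc_mask_scale (0.85, 1.0): ~217-256 context patches
# enc_mask_scale (0.7, 0.8):  ~179-204 context patches
```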

Thank you very much. I will give it a try.