/mup-vit

Everything you need to reproduce "Better plain ViT baselines for ImageNet-1k" in PyTorch, and more

Primary language: Jupyter Notebook. License: MIT.

The main branch of this repository aims to reproduce Better plain ViT baselines for ImageNet-1k in PyTorch, in particular the 76.7% top-1 validation set accuracy of the Head: MLP → linear variant after 90 epochs. This variant is not inferior to the default, and personally I have had better experience with a simpler prediction head. The changes I made to the big_vision reference implementation in my attempts to make the results converge reside in the grad_accum_wandb branch. In the rest of this README I would like to highlight some of the discrepancies I resolved.

mup-vit main branch

Training data and budget

In Better plain ViT baselines for ImageNet-1k only the first 99% of the training data is used for training while the remaining 1% is used for minival "to encourage the community to stop selecting design choices on the validation (de-facto test) set". This however is difficult to reproduce with torchvision.datasets since datasets.ImageNet() is ordered by class label, unlike tfds where the ordering is somewhat randomized:

>>> # Count how many examples of each class fall into the last 1% of the tfds train split.
>>> from collections import Counter
>>> import tensorflow_datasets as tfds
>>> ds = tfds.builder('imagenet2012').as_dataset(split='train[99%:]')
>>> c = Counter(int(e['label']) for e in ds)
>>> len(c)
999
>>> max(c.values())
27
>>> min(c.values())
3

Naively trying to do the same with torchvision.datasets prevented the model from learning the last few classes and resulted in near-random performance on the minival: of the classes in the last 1%, the model only learned the one that happened to straddle the boundary between the first 99% and the last 1%. Instead of randomly selecting 99% of the training data or copying the tfds 99% slice, I just fell back to training on 100% of the training data. ImageNet-1k has 1281167 training images, so a batch size of 1024 results in 1281167 // 1024 = 1251 steps per epoch if we drop the last odd lot. big_vision, however, doesn't train the model epoch by epoch. Instead, it makes the dataset iterator infinite and trains for the equivalent number of steps. Furthermore, it applies round() to the number of steps instead of dropping the last partial batch. The 90-epoch equivalent is therefore round(1281167 / 1024 * 90) = 112603 steps, and mup-vit main follows this practice.
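To make the difference between the two step counts concrete, here is a minimal arithmetic sketch. The constant names are mine and are not taken from either codebase.

```python
# Step budget for 90 epochs of ImageNet-1k at batch size 1024.
NUM_TRAIN_IMAGES = 1_281_167
BATCH_SIZE = 1024
EPOCHS = 90

# Epoch-by-epoch training that drops the last partial batch of every epoch.
steps_per_epoch_drop_last = NUM_TRAIN_IMAGES // BATCH_SIZE
total_steps_drop_last = steps_per_epoch_drop_last * EPOCHS

# Infinite-iterator training: round the fractional number of steps once.
total_steps_rounded = round(NUM_TRAIN_IMAGES / BATCH_SIZE * EPOCHS)

print(steps_per_epoch_drop_last, total_steps_drop_last, total_steps_rounded)
```

The rounded budget is 13 steps longer than 90 epochs with the last partial batch dropped.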

Warmup

big_vision warms up from 0 learning rate, but torch.optim.lr_scheduler.LinearLR() disallows starting from 0 learning rate (its start_factor must be greater than 0). I implemented warming up from 0 learning rate with torch.optim.lr_scheduler.LambdaLR() instead.
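A minimal sketch of such a warmup multiplier follows. The function name and the WARMUP_STEPS value are hypothetical, and the sketch covers only the warmup phase: whatever decay schedule follows it in the actual training script is replaced here by a constant 1.0.

```python
WARMUP_STEPS = 10_000  # hypothetical value, for illustration only


def warmup_multiplier(step, warmup_steps=WARMUP_STEPS):
    """LR multiplier that rises linearly from exactly 0 at step 0 to 1."""
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, step / warmup_steps)


# With PyTorch installed, the function can be passed as lr_lambda, e.g.:
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_multiplier)
# LambdaLR multiplies the optimizer's initial LR by the returned value.
print([warmup_multiplier(s) for s in (0, 2_500, 5_000, 10_000, 20_000)])
```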

Weight decay

In big_vision, config.wd is only scaled by the global LR schedule, whereas in torch.optim.AdamW() weight_decay is first multiplied by the LR. The correct equivalent value for weight_decay, given config.lr = 0.001 and config.wd = 0.0001, is therefore config.wd / config.lr = 0.1.
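The conversion can be checked with a small sketch that compares the per-step multiplicative shrinkage applied to a weight under both conventions, as described above. The function names are mine, and `schedule` stands for the global LR schedule value in [0, 1].

```python
import math

BIG_VISION_LR = 0.001   # config.lr
BIG_VISION_WD = 0.0001  # config.wd

# AdamW multiplies weight_decay by the current LR, so divide it out up front.
TORCH_WEIGHT_DECAY = BIG_VISION_WD / BIG_VISION_LR


def big_vision_shrink(schedule):
    """Fraction of a weight removed per step when wd is scaled by the schedule only."""
    return schedule * BIG_VISION_WD


def torch_shrink(schedule):
    """Fraction of a weight removed per step when weight_decay is scaled by the LR."""
    current_lr = schedule * BIG_VISION_LR
    return current_lr * TORCH_WEIGHT_DECAY


for schedule in (0.0, 0.25, 1.0):
    print(schedule, big_vision_shrink(schedule), torch_shrink(schedule))
```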

Model

The simplified ViT described in Better plain ViT baselines for ImageNet-1k is not readily available in PyTorch. E.g. vit_pytorch's simple_vit and simple_flash_attn_vit are rather dated and do not take advantage of torch.nn.MultiheadAttention(), so I rolled my own. I had to fix some of the parameter initialization, however:

  1. torch.nn.MultiheadAttention() comes with its own issues. When Q, K, and V are of the same dimension, their projection matrices are combined into self.in_proj_weight, whose initial values are set with xavier_uniform_(). Likely unintentionally, this means that the values are sampled from the uniform distribution U(−a, a) where a = sqrt(3 / (2 * hidden_dim)) instead of sqrt(3 / hidden_dim), because the fused matrix has shape (3 * hidden_dim, hidden_dim). Furthermore, the output projection is initialized as NonDynamicallyQuantizableLinear(), whose initial values are sampled from U(−sqrt(k), sqrt(k)), k = 1 / hidden_dim. Both are therefore re-initialized with U(−a, a) where a = sqrt(3 / hidden_dim) (see footnote 1) to conform with the jax.nn.initializers.xavier_uniform() used by the reference ViT from big_vision.
  2. PyTorch's own nn.init.trunc_normal_() doesn't take the effect of truncation on stddev into account, so I used the magic factor from the JAX repo to re-initialize the patchifying nn.Conv2d.

After 1 and 2, all of the summary statistics of the model parameters match those of the reference implementation at initialization.
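The numbers behind both fixes can be reproduced with the standard library alone. This is a sketch under assumptions: HIDDEN_DIM = 384 is a hypothetical width chosen for illustration, the helper names are mine, and the truncation interval of ±2 standard deviations is my reading of where the JAX factor comes from.

```python
import math

HIDDEN_DIM = 384  # hypothetical model width, for illustration only


def xavier_uniform_bound(fan_in, fan_out):
    """Bound a of U(-a, a) for Xavier/Glorot uniform initialization."""
    return math.sqrt(6.0 / (fan_in + fan_out))


def truncated_normal_std(lower=-2.0, upper=2.0):
    """Stddev of a standard normal truncated to [lower, upper]."""
    pdf = lambda x: math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    cdf = lambda x: 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    mass = cdf(upper) - cdf(lower)
    mean = (pdf(lower) - pdf(upper)) / mass
    var = 1.0 + (lower * pdf(lower) - upper * pdf(upper)) / mass - mean ** 2
    return math.sqrt(var)


# Fused QKV matrix of shape (3 * hidden_dim, hidden_dim) vs. a square projection.
fused_qkv_bound = xavier_uniform_bound(HIDDEN_DIM, 3 * HIDDEN_DIM)
square_bound = xavier_uniform_bound(HIDDEN_DIM, HIDDEN_DIM)
# Default nn.Linear weight bound, sqrt(k) with k = 1 / in_features.
default_linear_bound = math.sqrt(1.0 / HIDDEN_DIM)

print(fused_qkv_bound, square_bound, default_linear_bound)
print(truncated_normal_std())  # dividing the target stddev by this undoes the shrinkage
```

The fused bound is smaller than the square one by a factor of sqrt(2), and the default linear bound by a factor of sqrt(3).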

Data preprocessing and augmentation

Torchvision transforms such as v2.RandAugment() default to zero padding, whereas big_vision randaug() uses RGB values (128, 128, 128) as the replacement value. In both cases I have specified the latter to conform to the reference implementation. The model trained with all of the above for 90 epochs reached 76.91% top-1 validation set accuracy, but the loss curve and the gradient L2 norm clearly show that it deviates from the reference:

Screenshot 2024-07-01 at 10 28 58 PM

It turned out that RandAugment(num_ops=2, magnitude=10) means very different things in torchvision vs. big_vision. I created the following 224 × 224 black & white calibration grid consisting of 56 × 56 black & white squares:

calibration_grid

and applied both versions of RandAugment(2, 10) 100000 times to gather the stats. All of the resulting pixels remain colorless (i.e. for RGB values (r, g, b), r == g == b remains true), so we can sort them from black to white into a spectrum. For the following 2000 × 200 spectra, pixels are sorted top-down, left-right, and each pixel represents 224 * 224 * 100000 / (2000 * 200) = 112 * 112 = 12544 pixels of the aggregated output, i.e. 1/4 of one output image. Whenever one batch of 12544 pixels contains different values, I took the average. Here is the spectrum of torchvision RandAugment(2, 10):

torch_vision_randaugment_2_10
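The bucketing behind these spectra can be sketched on a toy scale as follows. The spectrum() helper is hypothetical and only illustrates the sort-then-average step described above, not the notebook's actual code.

```python
def spectrum(gray_values, num_buckets):
    """Sort gray values from black to white and average them bucket by bucket."""
    ordered = sorted(gray_values)
    bucket_size, remainder = divmod(len(ordered), num_buckets)
    if remainder:
        raise ValueError("number of pixels must be divisible by num_buckets")
    return [
        sum(ordered[i * bucket_size:(i + 1) * bucket_size]) / bucket_size
        for i in range(num_buckets)
    ]


# Aggregated output pixels represented by each pixel of a 2000 x 200 spectrum.
pixels_per_bucket = 224 * 224 * 100_000 // (2000 * 200)

toy = spectrum([255, 0, 0, 255, 128, 128, 0, 255], num_buckets=4)
print(pixels_per_bucket, toy)
```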

Here is the spectrum of torchvision RandAugment(2, 10, fill=[128] * 3). We can see that it just shifts the zero-padding part of the black into the (128, 128, 128) gray:

torch_vision_randaugment_2_10_mid_fill

And here is the spectrum of big_vision randaug(2, 10):

big_vision_randaugment_2_10

Digging into the codebase, we can see that while torchvision's v2.RandAugment() sticks with the original 14-transform lineup of RandAugment: Practical automated data augmentation with a reduced search space, big_vision's own randaug() omits the Identity no-op and adds 3 new transforms: Invert, SolarizeAdd, and Cutout. There are other, subtler discrepancies as well (e.g. Sharpness is considered "signed" in torchvision, so half of the time the transform blurs the image instead, while in big_vision it always sharpens the image). What I did then was to subclass torchvision's v2.RandAugment(), remove & add transforms accordingly, and use a variety of calibration grids to make sure that the outputs are within ±1 of the RGB values given by big_vision's counterpart. The sole exception is Contrast: more on that later. Even with that exception, the near-replication of big_vision's randaug(2, 10) results in a near-identical spectrum:

torch_vision_randaugment17_2_10

Training with the near-replication of big_vision randaug(2, 10) for 90 epochs reached 77.27% top-1 validation set accuracy and the gradient L2 norm looks the same, but the loss curve still differs:

Screenshot 2024-07-02 at 1 43 47 PM

Screenshot 2024-07-02 at 1 45 30 PM

It turned out that besides the default min scale (8% vs. 5%), the "Inception crop" implemented as torchvision v2.RandomResizedCrop() is not the same as calling tf.slice() with the bbox returned by tf.image.sample_distorted_bounding_box():

  1. They both rejection-sample the crop, but v2.RandomResizedCrop() is hardcoded to try 10 times while tf.image.sample_distorted_bounding_box() defaults to 100 attempts.
  2. v2.RandomResizedCrop() samples the aspect ratio uniformly in log space, while tf.image.sample_distorted_bounding_box() samples it uniformly in linear space.
  3. v2.RandomResizedCrop() samples the area cropped uniformly while tf.image.sample_distorted_bounding_box() samples the crop height uniformly given the aspect ratio and area range.
  4. If all attempts fail, v2.RandomResizedCrop() at least crops the image to make sure that the aspect ratio falls within range before resizing. tf.image.sample_distorted_bounding_box() just returns the whole image (to be resized).
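The four differences listed above can be sketched in plain Python. Both functions are simplified approximations written for this README: they return only the crop size (h, w) plus a flag for the fallback path, and the TensorFlow-style one omits the rounding corrections and area re-check of the real op. Neither is the actual library code.

```python
import math
import random


def torchvision_style_hw(height, width, scale=(0.05, 1.0), ratio=(3 / 4, 4 / 3), rng=random):
    """Crop size in the spirit of v2.RandomResizedCrop.get_params()."""
    area = height * width
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(10):  # hardcoded number of attempts
        target_area = area * rng.uniform(scale[0], scale[1])  # area: uniform
        aspect = math.exp(rng.uniform(log_lo, log_hi))  # ratio: log-uniform
        w = round(math.sqrt(target_area * aspect))
        h = round(math.sqrt(target_area / aspect))
        if 0 < w <= width and 0 < h <= height:
            return h, w, False
    # Fallback: largest crop whose aspect ratio is clamped into range.
    in_ratio = width / height
    if in_ratio < ratio[0]:
        h, w = round(width / ratio[0]), width
    elif in_ratio > ratio[1]:
        h, w = height, round(height * ratio[1])
    else:
        h, w = height, width
    return h, w, True


def tf_style_hw(height, width, area_range=(0.05, 1.0), ratio=(3 / 4, 4 / 3),
                max_attempts=100, rng=random):
    """Crop size in the spirit of tf.image.sample_distorted_bounding_box()."""
    min_area = area_range[0] * height * width
    max_area = area_range[1] * height * width
    for _ in range(max_attempts):
        aspect = rng.uniform(ratio[0], ratio[1])  # ratio: uniform in linear space
        min_h = round(math.sqrt(min_area / aspect))
        max_h = round(math.sqrt(max_area / aspect))
        max_h = min(max_h, height, int(width / aspect))  # crop must fit the image
        if min_h > max_h:
            continue
        h = rng.randint(min_h, max_h)  # height: uniform, so area is not
        w = round(h * aspect)
        if 0 < w <= width:
            return h, w, False
    return height, width, True  # fallback: the whole image


rng = random.Random(0)
tv_samples = [torchvision_style_hw(256, 512, rng=rng) for _ in range(20_000)]
tf_samples = [tf_style_hw(256, 512, rng=rng) for _ in range(20_000)]
mean_area = lambda samples: sum(h * w for h, w, _ in samples) / len(samples)
print(mean_area(tv_samples), mean_area(tf_samples))
print(sum(f for _, _, f in tv_samples), sum(f for _, _, f in tf_samples))
```

Because the height rather than the area is drawn uniformly, the TensorFlow-style sampler puts more mass on small crops, which shows up as a lower mean crop area for the same image and ranges.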

We can verify this by taking stats of the crop size given the same image. Here is the density plot of (h, w) returned by v2.RandomResizedCrop.get_params(..., scale=(0.05, 1.0), ratio=(3/4, 4/3)), given an image of (height, width) = (256, 512), N = 10000000:

torch_hw_counts

I got about 14340 crop failures, resulting in a bright pixel at the bottom right, but otherwise the density is roughly uniform. In comparison, here is what tf.image.sample_distorted_bounding_box(..., area_range=[0.05, 1]) returns:

tf_hw_counts

While cropping never failed, we can see clearly that it's oversampling smaller crop areas, as if there were light shining from top-left (notebook). The last discrepancy goes away after re-implementing tf.image.sample_distorted_bounding_box()'s sampling logic:

Screenshot 2024-07-28 at 2 50 01 PM Screenshot 2024-07-28 at 2 50 20 PM

This true reproduction model reached 76.94% top-1 validation set accuracy after 90 epochs. Now let's turn our attention to big_vision itself and double-check the effects of its bugs, inconsistencies, and unusual features.

big_vision grad_accum_wandb branch

I first bolted on wandb logging and revived utils.accumulate_gradient() to run 1024 batch size on my GeForce RTX 3080 Laptop GPU. The TensorBook is unable to handle shuffle_buffer_size = 250_000, so I shrank it to 150_000. Finally, I fell back to training on 100% of the training data to match what I had to do with PyTorch. This resulted in 76.74% top-1 validation set accuracy (the big-vision-repo-attempt run referenced above), consistent with the reported 76.7% top-1 validation set accuracy.

contrast() transform

It turned out that one of big_vision's randaug() transforms, contrast(), is broken. In short, the code that was meant to calculate the average grayscale of the image

  # Compute the grayscale histogram, then compute the mean pixel value,
  # and create a constant image size of that value.  Use that as the
  # blending degenerate target of the original image.
  hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
  mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0

is instead calculating image_area / 256, since the histogram counts always sum to the number of pixels. In our case of a 224 × 224 image, mean is therefore always 224 * 224 / 256 = 196. What it should do is the following:

  # Compute the grayscale histogram, then compute the mean pixel value,
  # and create a constant image size of that value.  Use that as the
  # blending degenerate target of the original image.
  hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
  mean = tf.reduce_sum(
      tf.cast(hist, tf.float32) * tf.linspace(0., 255., 256)) / float(image_height * image_width)
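The same bug can be reproduced without TensorFlow. The sketch below assumes integer gray values in 0..255, for which a 256-bin histogram over [0, 255] puts each value in its own bin. The function names are mine.

```python
def histogram_256(gray_values):
    """256-bin histogram of integer gray values in 0..255."""
    hist = [0] * 256
    for value in gray_values:
        hist[value] += 1
    return hist


def broken_mean(gray_values):
    """Sums the bin counts only, so the result is just num_pixels / 256."""
    return sum(histogram_256(gray_values)) / 256.0


def fixed_mean(gray_values):
    """Weights each bin count by its gray value and divides by the pixel count."""
    hist = histogram_256(gray_values)
    return sum(count * value for value, count in enumerate(hist)) / len(gray_values)


num_pixels = 224 * 224
all_black = [0] * num_pixels
half_white = [0] * (num_pixels // 2) + [255] * (num_pixels // 2)
for image in (all_black, half_white):
    print(broken_mean(image), fixed_mean(image))
```

The broken version returns 196.0 for any 224 × 224 input, even an all-black one.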

We can visualize this bug by using the following calibration grid as the input:

download (8)

and compare the output given by the broken contrast():

download (11)

vs. the output after the fix:

download (12)

Some CV people are aware of this bug (1, 2), but AFAIK it wasn't documented anywhere publicly. As an aside, the solarize() transform has its own integer overflow bug, but it just happens to have no effect when magnitude=_MAX_LEVEL as is the case here.

Inconsistent anti-aliasing between training vs. validation

decode_jpeg_and_inception_crop() used by the training data pipeline defaults to bilinear interpolation without anti-aliasing for resizing, but resize_small() used by the validation data pipeline defaults to area interpolation that "always anti-aliases". Furthermore, torchvision doesn't support resizing with area interpolation. For consistency, I changed both to bilinear interpolation with anti-aliasing.

JPEG decoding

tf.io.decode_jpeg() by default lets the system decide the JPEG decompression algorithm. Specifying dct_method="INTEGER_ACCURATE" makes it behave like the PIL/cv2/PyTorch counterpart (see also the last few cells of RandAugmentCalibration.ipynb). This option is exposed as decode(precise=True) in big_vision but is left unexposed for decode_jpeg_and_inception_crop(), so I added the precise argument to the latter.

Changing all of the above seems to have no apparent effect on the model, however (76.87% top-1 validation set accuracy).

Adam 1st order accumulator precision

optax.scale_by_adam() supports the unusual option of using a different dtype, mu_dtype, for the 1st order accumulator, and the reference implementation uses bfloat16 for it instead of float32 like the rest of the model. Changing it back to float32, however, still has no apparent effect (76.77% top-1 validation set accuracy).

Shuffle buffer size

Finally, back to shuffle_buffer_size. Unlike torch.utils.data.DataLoader(shuffle=True), which always fully shuffles by indices, tf.data.Dataset.shuffle(buffer_size) needs to load buffer_size's worth of training examples into main memory and fully shuffles iff buffer_size >= dataset.cardinality(). To test whether the incomplete shuffle so far had hurt performance, I launched an 8x A100-SXM4-40GB instance on Lambda and trained a big_vision model on it with all of the above and config.input.shuffle_buffer_size = 1281167, the size of the ImageNet-1k training set. It still has no apparent effect (76.85% top-1 validation set accuracy).
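A toy model of a fixed-size shuffle buffer shows why a small buffer cannot fully shuffle. This is a sketch in the spirit of tf.data.Dataset.shuffle(), not its actual implementation, and the sizes are toy values.

```python
import random


def buffered_shuffle(items, buffer_size, rng):
    """Yield items in the order a fixed-size shuffle buffer would emit them."""
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        slot = rng.randrange(buffer_size)
        yield buffer[slot]  # emit a random buffered item...
        buffer[slot] = item  # ...and replace it with the incoming one
    rng.shuffle(buffer)  # input exhausted: drain the rest in random order
    yield from buffer


NUM_EXAMPLES = 1000  # toy stand-in for the training set size
BUFFER_SIZE = 100    # toy stand-in for shuffle_buffer_size
indices = list(range(NUM_EXAMPLES))

partial = list(buffered_shuffle(indices, BUFFER_SIZE, random.Random(0)))
full = list(buffered_shuffle(indices, NUM_EXAMPLES, random.Random(0)))

# With a partial buffer, the k-th output can only come from the first
# k + BUFFER_SIZE inputs, so early outputs never contain late examples.
print(max(partial[:50]), max(full[:50]))
```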

As a by-product, this also indicates that big_vision gradient accumulation and multi-GPU training give equivalent results.

Conclusions

This is the true end of reproducing Better plain ViT baselines for ImageNet-1k in PyTorch. There are no ifs or buts, and no mystery left. It's rather ironic that after checking and fixing discrepancies for months, fixing the last discrepancy turned out to be a step down (77.27% vs. 76.7%-76.94%) in terms of model performance. I have therefore added --torchvision-inception-crop as an option to switch back to torchvision's Inception crop.

Postscript: Metrics of the models aside, in terms of training walltime, modern (2.2+) PyTorch with compile() and JAX are nearly identical on the same GPU. The tiny difference may well be fully explained by the overhead of transposing from channels-last to channels-first and converting from tf.Tensor to torch.Tensor. As for hardware comparison, here are the walltimes reported:

Hardware             Walltime
TPUv3-8 node         6h30 (99%)
8x A100-SXM4-40GB    5h41m
RTX 3080 Laptop      5d19h32m

8x A100-SXM4-40GB is comparable to, but faster than, a TPUv3-8 node. The RTX 3080 Laptop is unsurprisingly in a different league: 1 day on it is about the same as 1 hour on the other two.

1 See pytorch/pytorch#57109 (comment) for the origin of this discrepancy.