locuslab/convmixer

Cifar10 baseline doesn't reach 95%

K-H-Ismail opened this issue · 13 comments

Hello,
I tried ConvMixer-256 on CIFAR-10 with the same timm options specified for ImageNet (except for num_classes), and it doesn't go beyond 90% accuracy. Could you please specify the options used for the CIFAR-10 experiment?

Thanks for pointing this out!

I think the key parameter we didn't clearly specify for CIFAR-10 is the allowable "scale" for random cropping. The default setting in timm allows the crop to be as little as 8% of the original area of the image (before it is resized back to the original shape). We thought this didn't make sense for 32x32 CIFAR-10 images, so we changed it to 75%. It would probably also be a good idea to specify the CIFAR-10 mean and standard deviation, though I don't think this will change much.

In particular, try adding the following flags: --scale 0.75 1.0 --mean 0.4914 0.4822 0.4465 --std 0.2471 0.2435 0.2616.
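For intuition, here is a minimal torchvision sketch of roughly what those flags correspond to (illustrative only; timm builds its own pipeline from the flags, which also layers on RandAugment, mixup, etc. when enabled):

from torchvision import transforms

# --scale 0.75 1.0 means the random crop must cover at least 75% of the
# image area before being resized back to 32x32; normalization uses the
# CIFAR-10 channel statistics instead of the ImageNet defaults.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.75, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2471, 0.2435, 0.2616)),
])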

I've also updated the README to mention this.

Hello, thanks for your help. Indeed the crop parameter has an effect on the accuracy: with --scale 0.75 1.0 we could reach up to 93.94% accuracy with ConvMixer-256/16. This is still below the 96.74% reported in the paper. To reach that accuracy, I've tried different batch sizes, adding more epochs, and so on, but I still couldn't manage to get it. Could you please give me a hint?
This is the timm command I used, and the model is implemented as in the paper:

sh distributed_train.sh 2 \
    --dataset cifar10 \
    /path/CIFAR-10-images/ \
    --train-split /path/CIFAR-10-images/train \
    --val-split /path/CIFAR-10-images/test \
    --model convmixer_256_16 \
    -b 128 \
    -j 2 \
    --opt adamw \
    --epochs 200 \
    --amp \
    --input-size 3 32 32 \
    --lr 0.01 \
    --num-classes 10 \
    --warmup-epochs 0 \
    --weight-decay 0.01 \
    --sched onecycle \
    --opt-eps=1e-3 \
    --clip-grad 1.0 \
    --scale 0.75 1.0 \
    --mean 0.4914 0.4822 0.4465 \
    --std 0.2471 0.2435 0.2616 \
    --aa rand-m9-mstd0.5-inc1 \
    --cutmix 0.5 \
    --mixup 0.5 \
    --reprob 0.25 \
    --remode pixel

What patch size and kernel size are you using? I think this should easily achieve >96% with patch_size=1, and around 95% with patch_size=2, for large kernels (with patch_size=1 the internal representation stays at the full 32x32 resolution, versus 16x16 for patch_size=2). We used batch size 128 (whereas yours is effectively 2*128 across two GPUs), but I'm not sure that would cause such a big difference. I have trained a ConvMixer-256/16 with patch_size=1 and kernel_size=9 on CIFAR-10 with almost the same settings (except for batch size), and it reached 96% by epoch 140/200.

I'll see if increasing the batch size would actually have such a significant effect and get back to you.

Hello, I use a patch size of 1 and a kernel size of 9. I tried a smaller batch size (64) and it changed nothing.

import torch.nn as nn

class Residual(nn.Module):
    """Adds a skip connection around fn, as in the paper's reference code."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

def ConvMixer(dim, depth, kernel_size=9, dilation=1, patch_size=7, n_classes=1000):
    return nn.Sequential(
        # Patch embedding: a strided convolution that splits the image into patches.
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[nn.Sequential(
                # Depthwise convolution (spatial mixing) with a residual connection.
                Residual(nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size, dilation=dilation, groups=dim,
                              padding=dilation * (kernel_size - 1) // 2),
                    nn.GELU(),
                    nn.BatchNorm2d(dim)
                )),
                # Pointwise (1x1) convolution (channel mixing).
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.GELU(),
                nn.BatchNorm2d(dim)
        ) for _ in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
    )

and I use this model

@register_model
def convmixer_256_16(pretrained=False, **kwargs):
    # Registered with timm so it can be selected via --model convmixer_256_16.
    model = ConvMixer(256, 16, kernel_size=9, patch_size=1, n_classes=10)
    model.default_cfg = _cfg
    return model
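As a quick sanity check (an illustrative snippet, not from the original post), the model should map a batch of CIFAR-sized images to 10 logits:

import torch

model = convmixer_256_16()
x = torch.randn(2, 3, 32, 32)   # a dummy batch of 32x32 RGB images
print(model(x).shape)           # expected: torch.Size([2, 10])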

dmezh commented

Hi, the paper mentions using a "simple triangular learning rate schedule" - we're trying to replicate your work on CIFAR-10 (in TensorFlow) - we're wondering which LR schedule and parameters you used for the results in Table B. Thank you!

@K-H-Ismail I'm currently training the same model as you using a freshly-cloned version of this repo with the same parameters, other than the batch size which I have set to -b 64. It's on epoch 114/200 and has already reached 94.5% accuracy. The reason for the difference isn't clear to me...

@dmezh We included the learning rate schedule in this repo, though you kind of need to hunt through the code to find it. The most important line is this one, which I'll paste below:

# Piecewise-linear "triangular" schedule: ramp from 0 to lr_max over the first
# 40% of training, decay to lr_max/20 at 80%, then to 0 at the end.
sched = lambda t, lr_max: np.interp([t], [0, self.t_initial*2//5, self.t_initial*4//5, self.t_initial],
                                    [0, lr_max, lr_max/20.0, 0])[0]

Here, t_initial should be the total number of epochs you're going to train for, lr_max is your learning rate (we used 0.01 everywhere), and t should be current_epoch + (batch_idx + 1) / batches_per_epoch, with indices starting at 0.
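A minimal sketch of how this could drive a training loop (the function and variable names here are illustrative, not from the repo):

import numpy as np

def triangular_lr(t, t_initial, lr_max=0.01):
    # Same piecewise-linear schedule as above, as a standalone function.
    return np.interp([t], [0, t_initial * 2 // 5, t_initial * 4 // 5, t_initial],
                     [0, lr_max, lr_max / 20.0, 0])[0]

epochs, batches_per_epoch = 200, 391   # e.g. CIFAR-10 with batch size 128
for epoch in range(epochs):
    for batch_idx in range(batches_per_epoch):
        t = epoch + (batch_idx + 1) / batches_per_epoch
        lr = triangular_lr(t, t_initial=epochs)
        # for g in optimizer.param_groups:
        #     g["lr"] = lr
        # ...forward/backward/step...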

That said, I think you'll get approximately the same results using something more standard like cosine decay with one cycle.

Let me know if you have any other questions about your replication! Given the interest in our CIFAR-10 results, we'll try to release a more compact training script and model weights for it sometime soon (but in PyTorch).

Hello @tmp-iclr,
Thanks for your time and support. The only thing we haven't checked so far that may differ is the dataset itself: timm uses raw images for ImageNet, while the official PyTorch CIFAR-10 dataset comes as downloadable binary batches, so I downloaded the raw CIFAR-10 images from this repository:
https://github.com/YoongiKim/CIFAR-10-images

Is it the same for you?

Glad to help. I used https://github.com/knjcode/cifar2png to construct the dataset. I'll see if there's any difference with the one you linked.

By the way, the model I ran with your settings and -b 64 ended up getting 97% accuracy.

So I'm training the same model on the dataset you used, and it does indeed seem to be lagging behind by a few percent. It's too early to say for sure, but this might be the problem...

Upon inspection, it looks like the CIFAR dataset you used has substantial JPEG artifacts -- the images actually look noticeably less sharp and colorful. I'm now pretty sure the dataset discrepancy is, in fact, the problem.
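For anyone hitting the same issue, here is a hedged sketch of one way to build a lossless image-folder version of CIFAR-10 with torchvision (the helper name and folder layout are illustrative; cifar2png, linked above, does essentially this):

import os
from torchvision.datasets import CIFAR10

def export_cifar10_png(root, out_dir, split="train"):
    # Save CIFAR-10 as lossless PNGs in a class-per-folder layout,
    # avoiding the JPEG re-encoding artifacts discussed above.
    ds = CIFAR10(root, train=(split == "train"), download=True)
    for idx, (img, label) in enumerate(ds):
        cls_dir = os.path.join(out_dir, split, ds.classes[label])
        os.makedirs(cls_dir, exist_ok=True)
        img.save(os.path.join(cls_dir, f"{idx:05d}.png"))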

Hello,
Indeed, the dataset was the problem, sorry for that! Usually I am not very lucky when reproducing baselines 🤷‍♀️. I will open an issue on the https://github.com/YoongiKim/CIFAR-10-images repository. Thanks for your help, and very good article by the way!

No problem! Glad we figured it out, and thanks.

I'm going to go ahead and close this issue, but feel free to reopen it or open another if you have more questions (likewise, @dmezh).