mashrurmorshed/Torch-KWT

Why the test set and validation set reach an incredibly high accuracy of 99.6%.

qiangqiang-he opened this issue · 7 comments

Hello, ID56. I previously asked why the training accuracy is lower than the validation accuracy. Since then I have been trying to understand why your code reaches higher accuracy than the original author's, which confused me. After repeated experiments over the past few days, I think I know the reason, and it has two parts. First, my experiments show that it is better to use only SpecAugment, which is what your code does. Second, and this is the main reason the training accuracy is only 94%: at every epoch you augment the original training set and train on the augmented data. This means the training set is different at every epoch, so the model cannot reach very high accuracy on it. However, this approach greatly improves the model's accuracy on the validation and test sets. I ran a comparative experiment, and it shows that this method is very effective.
Thank you very much for your code, especially for your prompt reply.
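
To illustrate the point about per-epoch augmentation, here is a minimal sketch of applying SpecAugment-style masking inside a PyTorch Dataset, so that every epoch sees a differently-masked version of each clip. It only uses torchaudio's masking transforms under assumed feature shapes; it is an illustration, not the exact dataset code in this repository.

# Minimal sketch: random spectrogram masking inside __getitem__, so each epoch
# sees a different view of the same clip. Not the repository's actual code.
import torch
import torchaudio
from torch.utils.data import Dataset

class AugmentedSpecDataset(Dataset):
    def __init__(self, specs, labels, train=True):
        self.specs = specs    # assumed pre-computed features of shape (1, 40, 98)
        self.labels = labels
        self.train = train
        # mask widths roughly follow the spec_aug section of the config in this thread
        self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=25)
        self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=7)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.specs[idx].clone()
        if self.train:
            for _ in range(2):   # 2 time masks + 2 frequency masks per sample
                x = self.time_mask(x)
                x = self.freq_mask(x)
        return x, self.labels[idx]

Because the masks are re-drawn on every call, training accuracy is measured on a moving target, while validation and test accuracy are measured on clean, fixed data.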

Thank you very much for your reply. Even with the smallest model, KWT-1, the accuracy on the training and validation sets is 99.5%. Because of GPU memory limits, I set the batch size to 256 and the number of epochs to 140, so my total number of steps is about 58,000, roughly twice that of the original paper.
After further investigation, I think the first possibility you mentioned is the most likely. I used the model to predict on the original training set without data augmentation, and it still reaches 99% accuracy, which shows that the model has learned the features of the original data well. I also checked all of your code and found no bugs.
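
For anyone who wants to repeat that check, the evaluation itself is just a plain accuracy loop over a non-augmented loader. The sketch below is generic; `model` and `train_loader_no_aug` are placeholder names rather than objects from this repository.

# Sketch: measure accuracy of a trained model on the *un-augmented* training set.
# `model` and `train_loader_no_aug` are hypothetical placeholders.
import torch

@torch.no_grad()
def accuracy(model, loader, device="cuda"):
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(dim=-1)
        correct += (preds == y).sum().item()
        total += y.numel()
    return correct / total

# e.g. print(f"train (no aug): {accuracy(model, train_loader_no_aug):.4f}")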

Hello, I also set the batch size to 256 and the number of epochs to 140, which gives 48,000+ steps, but my KWT-1 model cannot reach 99% accuracy; it only gets about 95%. Do you know what might be wrong?

Hi @blessyyyu. KWT-1 is indeed supposed to achieve 95-96% accuracy, as per the paper. One of Mars234's observations is correct: validation and test accuracies are higher than training accuracy. You can see this in issue #6, where I shared a screenshot of my own training logs. I also emailed the author of the KWT paper, Axel Berg, and he sent a picture of his training curves as well, which look similar.

As a sanity check, I also re-ran training with KWT-1 with batch 256 and 140 epochs earlier today. You can see the W&B logs here: https://wandb.ai/dealer56/TorchKWT-1

The test accuracy is 96.12%, which is within the normally expected range.

Now, the reason why Mars234 found 99%+ accuracy can be narrowed down to either a problem in the data (i.e. the training, validation and test sets somehow got mixed) or an unusually good set of hyper-parameters.

I think an in-depth investigation of the 99% result is only possible if @mars234 shares his training logs, exact hyper-parameter settings, experiment environment, etc. with you! Otherwise, ~95% with KWT-1 is fine.
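
If anyone wants to rule out the data-mixing possibility quickly, a simple overlap check over the three list files is enough. The sketch below assumes the list files contain one wav path per line; the paths are just the defaults from the config quoted later in this thread.

# Sanity check: the training/validation/testing lists should not share any files.
# Assumes one file path per line; paths below are the config defaults.
def load_list(path):
    with open(path) as f:
        return {line.strip() for line in f if line.strip()}

train = load_list("./data/training_list.txt")
val = load_list("./data/validation_list.txt")
test = load_list("./data/testing_list.txt")

print("train/val overlap: ", len(train & val))
print("train/test overlap:", len(train & test))
print("val/test overlap:  ", len(val & test))  # all three should print 0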

@ID56
Thanks for your sincere reply! I also have a question about KWT-3: I have already set the parameters to match the KWT-3 configuration in the paper, but I only get 84% test accuracy. My training log and config are below:

The training log:

Step: 22600 | epoch: 136 | loss: 1.347745418548584 | lr: 2.1718275077009617e-06
Step: 22620 | epoch: 136 | loss: 1.3820128440856934 | lr: 2.0387097712025697e-06
Step: 22640 | epoch: 136 | loss: 1.3949477672576904 | lr: 1.909813434584174e-06
Step: 22660 | epoch: 136 | loss: 1.4540565013885498 | lr: 1.785139590536111e-06
Step: 22680 | epoch: 136 | loss: 1.4017443656921387 | lr: 1.6646892959535414e-06
Step: 22700 | epoch: 136 | loss: 1.3092560768127441 | lr: 1.5484635719274015e-06
Step: 22720 | epoch: 136 | loss: 1.3603119850158691 | lr: 1.4364634037356905e-06
Step: 22740 | epoch: 136 | loss: 1.3909404277801514 | lr: 1.3286897408352526e-06
Step: 22742 | epoch: 136 | time_per_epoch: 36.41216802597046 | train_acc: 0.7509871173815165 | avg_loss_per_ep: 1.3864288538335316
Step: 22760 | epoch: 137 | loss: 1.3362438678741455 | lr: 1.2251434968537285e-06
Step: 22780 | epoch: 137 | loss: 1.3984098434448242 | lr: 1.1258255495816172e-06
Step: 22800 | epoch: 137 | loss: 1.3996495008468628 | lr: 1.030736740964951e-06
Step: 22820 | epoch: 137 | loss: 1.4183783531188965 | lr: 9.398778770982974e-07
Step: 22840 | epoch: 137 | loss: 1.380089521408081 | lr: 8.532497282176561e-07
Step: 22860 | epoch: 137 | loss: 1.3470677137374878 | lr: 7.708530286942412e-07
Step: 22880 | epoch: 137 | loss: 1.443914532661438 | lr: 6.926884770278756e-07
Step: 22900 | epoch: 137 | loss: 1.4271886348724365 | lr: 6.187567358414398e-07
Step: 22908 | epoch: 137 | time_per_epoch: 35.87131881713867 | train_acc: 0.7505510177622197 | avg_loss_per_ep: 1.3891103942710232
Step: 22920 | epoch: 138 | loss: 1.334773063659668 | lr: 5.490584318750435e-07
Step: 22940 | epoch: 138 | loss: 1.3356997966766357 | lr: 4.83594155980696e-07
Step: 22960 | epoch: 138 | loss: 1.4480316638946533 | lr: 4.22364463117478e-07
Step: 22980 | epoch: 138 | loss: 1.3032946586608887 | lr: 3.65369872346655e-07
Step: 23000 | epoch: 138 | loss: 1.3263620138168335 | lr: 3.126108668272929e-07
Step: 23020 | epoch: 138 | loss: 1.480348825454712 | lr: 2.640878938123728e-07
Step: 23040 | epoch: 138 | loss: 1.3896548748016357 | lr: 2.1980136464468206e-07
Step: 23060 | epoch: 138 | loss: 1.3756014108657837 | lr: 1.7975165475359544e-07
Step: 23074 | epoch: 138 | time_per_epoch: 35.420955181121826 | train_acc: 0.7515057223341938 | avg_loss_per_ep: 1.3880014254386166
Step: 23080 | epoch: 139 | loss: 1.3905423879623413 | lr: 1.4393910365163307e-07
Step: 23100 | epoch: 139 | loss: 1.4403746128082275 | lr: 1.1236401493196271e-07
Step: 23120 | epoch: 139 | loss: 1.3234858512878418 | lr: 8.502665626540214e-08
Step: 23140 | epoch: 139 | loss: 1.419400691986084 | lr: 6.192725939836523e-08
Step: 23160 | epoch: 139 | loss: 1.368742227554321 | lr: 4.306602015097448e-08
Step: 23180 | epoch: 139 | loss: 1.456369400024414 | lr: 2.8443098415284763e-08
Step: 23200 | epoch: 139 | loss: 1.4342108964920044 | lr: 1.805861815378466e-08
Step: 23220 | epoch: 139 | loss: 1.3412597179412842 | lr: 1.1912667398730207e-08
Step: 23240 | epoch: 139 | time_per_epoch: 35.43096947669983 | train_acc: 0.7490187758565822 | avg_loss_per_ep: 1.3912603689963559
Step: 23240 | epoch: 139 | val_loss: 1.091590517759323 | val_acc: 0.8639414888287746
Saved ./runs/exp-0.0.3/last.pth with accuracy 0.8639414888287746.
Step: 23406 | test_loss_last: 1.135682306506417 | test_acc_last: 0.8426169922762381
Step: 23406 | test_loss_best: 1.1353701949119568 | test_acc_best: 0.8425261244888687

The config:

data_root: ./data/
train_list_file: ./data/training_list.txt
val_list_file: ./data/validation_list.txt
test_list_file: ./data/testing_list.txt
label_map: ./data/label_map.json

exp:
    wandb: False
    wandb_api_key: <path/to/api/key>
    proj_name: torch-kwt-3
    exp_dir: ./runs
    exp_name: exp-0.0.3
    device: auto
    log_freq: 20    # log every l_f steps
    log_to_file: True
    log_to_stdout: True
    val_freq: 5    # validate every v_f epochs
    n_workers: 1
    pin_memory: True
    cache: 2 # 0 -> no cache | 1 -> cache wavs | 2 -> cache specs; stops wav augments
    

hparams:
    seed: 0
    batch_size: 512
    n_epochs: 140
    l_smooth: 0.1

    audio:
        sr: 16000
        n_mels: 40
        n_fft: 480
        win_length: 480
        hop_length: 160
        center: False
    
    model:
        name: # if name is provided below settings will be ignored during model creation   
        input_res: [40, 98]
        patch_res: [40, 1]
        num_classes: 35
        mlp_dim: 768
        dim: 192
        heads: 3
        depth: 12
        dropout: 0.0
        emb_dropout: 0.1
        pre_norm: False

    optimizer:
        opt_type: adamw
        opt_kwargs:
          lr: 0.001
          weight_decay: 0.1
    
    scheduler:
        n_warmup: 10
        max_epochs: 140
        scheduler_type: cosine_annealing

    augment:
        resample:
            r_min: 0.85
            r_max: 1.15
        
        time_shift:
            s_min: -0.1
            s_max: 0.1

        bg_noise:
            bg_folder: ./data/_background_noise_/

        spec_aug:
            n_time_masks: 2
            time_mask_width: 25
            n_freq_masks: 2
            freq_mask_width: 7


I truly appreciate your help in resolving my many problems!!

Okay, that's interesting. I never tried training KWT-3 since it is a large model. KWT (and Vision Transformers in general) is prone to high variance and needs careful regularization or pretraining. My guess is that the training diverged because KWT-3 is too large (overparameterized) and the cache level was 2, which disables some of the augmentations.

When we use cache: 2, we preemptively convert all audio clips into MFCCs of shape (40, 98). However, this disables the wav-level augmentations (resample, time_shift and bg_noise); only spectral augmentation is applied. So one option is to use cache: 1 or cache: 0, which keeps all the augmentations. However, mode 1 needs a large amount of memory (instead of storing 40x98 MFCCs, you have to store 1x16000 audio arrays) and is still quite a bit slower than mode 2, since you have to convert audio to MFCC at every training step. Mode 0 is even slower, as you need to load many audio files from disk at every step and then convert them to MFCCs, which is not really feasible.
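
To make the trade-off concrete, here is a rough sketch of the per-clip shapes and memory involved, using torchaudio as an approximation of the feature pipeline. The settings mirror the audio section of the config above, but this is not necessarily the repository's exact extractor.

# Rough illustration of the cache: 1 vs cache: 2 trade-off (per clip).
# Approximation with torchaudio; not necessarily the repo's exact extractor.
import torch
import torchaudio

wav = torch.randn(1, 16000)  # one 1-second clip at 16 kHz (what cache: 1 stores)

to_mfcc = torchaudio.transforms.MFCC(
    sample_rate=16000,
    n_mfcc=40,
    melkwargs=dict(n_fft=480, win_length=480, hop_length=160,
                   n_mels=40, center=False),
)
feat = to_mfcc(wav)  # what cache: 2 stores; shape (1, 40, 98)

print(wav.shape, feat.shape)
print(wav.numel() * 4, "bytes vs", feat.numel() * 4, "bytes (float32)")
# cache: 2 keeps roughly 4x less data per clip and skips per-step feature
# extraction, but it freezes the waveform, so resample / time_shift / bg_noise
# can no longer vary between epochs.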

I suggest you try the following steps:

  1. Start by reducing the learning rate. Try 0.0001 or 0.0003.
    • If accuracy improves, we can reasonably conclude that the large KWT-3 model was overfitting and diverging.
    • I will start a training run with lr 0.0001 now to verify this.
  2. Assuming step 1 works, we can also try increasing some of the regularization parameters, like:
    • Use weight_decay: 0.15
    • Use dropout: 0.2 or dropout: 0.5
    • Use emb_dropout: 0.2 or emb_dropout: 0.5
  3. Try training with cache: 1 if possible

I will update and let you know if I find a good set of settings for KWT-3. If you find a config that works great, please let me know as well!


Also, you can try checking out KW-MLP if you're aiming for higher accuracy than KWT-1; you can easily reach 97% or higher. It's a variant of KWT that uses g-MLP blocks instead of attention. The repository is designed to be quite similar to Torch-KWT (i.e. almost the same config files), so you should find it easy to use if you've already used Torch-KWT.

@blessyyyu it was as I said.

  • Reducing the initial learning rate to 0.0003 brings KWT-3 to 93.86% accuracy. While still a bit short of the paper's number, this is a lot better than the worrying 84% that comes with lr: 0.001.
    • Further reducing the initial lr from 0.0003 to 0.0001 does not help; I get 91.73% accuracy, so 0.0001 is too small.
    • If you're curious why I tried exactly 0.0003 and 0.0001: it's what Andrew Ng suggests for manually tuning learning rates (https://youtu.be/zLRB4oupj6g?t=479), i.e., searching by factors of roughly 3.

I think you can achieve higher accuracy by slightly increasing the dropouts and/or the weight decay. Try seeing what happens if you use:

model:
        name:
        input_res: [40, 98]
        patch_res: [40, 1]
        num_classes: 35
        mlp_dim: 768
        dim: 192
        heads: 3
        depth: 12
        dropout: 0.2
        emb_dropout: 0.2
        pre_norm: False

optimizer:
        opt_type: adamw
        opt_kwargs:
          lr: 0.0003
          weight_decay: 0.1

And if it is possible for you to train with cache mode 1 (i.e. if you aren't running this on colab/kaggle), that should also help.

@ID56 Thanks very much for your prompt reply and for sharing your experiment results.
I think your experiment design is reasonable, and I will try it on my server (not on Colab or Kaggle).

By the way, I noticed that the original paper includes some experiments on model compression and latency testing, but I don't seem to see that code in this project.

@blessyyyu yes, the original paper did a latency experiment using the TensorFlow (TF) Lite format on a OnePlus 6 mobile phone (Snapdragon 845: 4x Arm Cortex-A75, 4x Arm Cortex-A55). As I don't have a OnePlus 6 device, and because those experiments were done in TensorFlow, I cannot replicate them.

The KWT paper also does some experiments with knowledge distillation, yes. However, while KD is typically used to compress models, in KWT no actual compression is done; rather, the distillation loss is only used to provide a small boost in accuracy. Implementing the distillation process would also require me to implement and train the Multi-Headed Attention RNN model, which is used as the teacher network, and I haven't found the time to do that yet!
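
For reference, a generic soft-distillation loss looks roughly like the sketch below: cross-entropy on the hard labels plus a temperature-scaled KL term against the teacher. This is only the general idea; the KWT paper's exact distillation setup may differ, and `student_logits` / `teacher_logits` are placeholders.

# Generic soft-distillation loss, shown only as a sketch of the idea;
# the KWT paper's exact formulation (teacher, loss weighting) may differ.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets,
                      alpha=0.5, temperature=2.0):
    # standard cross-entropy against the hard labels
    ce = F.cross_entropy(student_logits, targets)
    # KL divergence between softened student and teacher distributions
    kd = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    return alpha * ce + (1.0 - alpha) * kd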