It goes as follows. Since we want an attention that is stable against attacks, we craft one by attacking it and making it stronger :p. Similarly, to keep it close to the vanilla attention, we penalize our proposed attention whenever it fails to mimic the original. In short, we define an objective function and train our proposed attention until it behaves the way we want.
It is a whole bunch of maths so let's dumb it down!
- Wo_bar is basically our SEAT; to make it stronger we find a perturbation delta using the PGD attack
- Then we add the delta to Wo_bar, compare how it performs against the vanilla attention, and penalize it accordingly
- At the same time we penalize our attention for failing to mimic both the predictions and the attention weights of the vanilla attention
Finally we update our Wo_bar via a standard SGD procedure!
To put it a bit more formally: our goal is to find an attention that is stable to perturbations while still retaining the way the original attention works. On one hand, we look for a perturbation so large that even when it is added to our attention we still get pretty good results; on the other, even though our attention is now very stable, the way it highlights the important regions should not drift far from the original. That's why the objective function contains a PGD attack that maximizes the perturbation (which in turn makes the attention more stable), while two similarity terms minimize the dissimilarity with the original attention.
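In rough symbols (my notation, not the paper's exact formulation), with W the vanilla attention parameters and W-bar ours, this is a min-max problem:

```latex
\min_{\bar{W}} \;
\underbrace{\max_{\|\delta\| \le \epsilon} \mathcal{L}\!\left(f(\bar{W} + \delta)\right)}_{\text{stability: survive the worst perturbation (PGD inner max)}}
\;+\;
\underbrace{D_{\text{pred}}\!\left(f(\bar{W}),\, f(W)\right) + D_{\text{attn}}\!\left(A(\bar{W}),\, A(W)\right)}_{\text{closeness: mimic the predictions and the attention map}}
```

where f gives predictions, A gives attention weights, and D_pred, D_attn are dissimilarity measures between predictions and attention maps respectively.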
To verify this I tried two things: first, comparing SEAT with the vanilla attention directly, and second, comparing performance after adding a perturbation, not to the attention but to the embeddings. I used a BiLSTM with attention trained for 10 epochs on the IMDB (Reviews) dataset.
- Comparison with the vanilla attention (no perturbation):
  - Jensen-Shannon Divergence (comparing the attention weights): 0.00135
  - Total Variation Distance (comparing the difference in predictions): 0.318
- After adding a perturbation to the embeddings:
  - Jensen-Shannon Divergence (comparing the attention weights): 0.00136
  - Total Variation Distance (comparing the difference in predictions): 0.318
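For reference, the two metrics above are straightforward to compute. This is a plain NumPy sketch, independent of the repo's own implementation:

```python
import numpy as np

def js_divergence(p, q, eps=1e-12):
    """Jensen-Shannon divergence between two attention distributions
    (symmetrized, smoothed KL; 0 means identical)."""
    p, q = np.asarray(p, float) + eps, np.asarray(q, float) + eps
    p, q = p / p.sum(), q / q.sum()
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(a * np.log(a / b))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def tv_distance(p, q):
    """Total variation distance between two prediction distributions:
    half the L1 distance."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    return 0.5 * np.abs(p - q).sum()

# Identical distributions give 0 for both metrics.
a = np.array([0.1, 0.2, 0.7])
print(js_divergence(a, a))  # → 0.0
print(tv_distance(a, a))    # → 0.0
```

So the tiny JSD values above say the SEAT attention weights stay extremely close to the vanilla ones, while the TVD measures how much the predictions move.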
- Trying out more models like BERT
- More illustrations to demonstrate this better
- Exploring more perturbation styles, e.g. perturbing the sentence by swapping in a similar word instead
- Improving the quality of SEAT to match the results reported in the referenced paper
git clone git@github.com:lazyCodes7/SEAT.git
pip install -r requirements.txt
cd seat
python train.py -d 'cuda'
# yay
@misc{hu2022seat,
    title={SEAT: Stable and Explainable Attention},
    author={Lijie Hu and Yixin Liu and Ninghao Liu and Mengdi Huai and Lichao Sun and Di Wang},
    year={2022},
    eprint={2211.13290},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}