Source code for the paper "Transformer Uncertainty Estimation with Hierarchical Stochastic Attention" (AAAI 2022)
This repository releases the source code of the stochastic models proposed in the paper "Transformer Uncertainty Estimation with Hierarchical Stochastic Attention", accepted at AAAI 2022. We implement stochastic transformer models for the following two NLP tasks (a simplified sketch of the stochastic attention idea follows the list):
- Sentiment Analysis (`code/IMDB`)
- Linguistic Acceptability (`code/CoLA`)
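At the core of the release is a stochastic attention mechanism. Below is a heavily simplified sketch of the general flavor, sampling attention weights with Gumbel-softmax instead of computing a deterministic softmax, so that repeated forward passes differ and can be aggregated into an uncertainty estimate; see the paper and the code under `code/` for the actual hierarchical formulation.

```python
# Simplified sketch only: single-level Gumbel-softmax attention, not the
# paper's hierarchical formulation (see code/ for the released implementation).
import torch
import torch.nn.functional as F

def stochastic_attention(q, k, v, tau=1.0):
    # q, k, v: (batch, seq_len, d_model)
    scores = q @ k.transpose(-2, -1) / k.size(-1) ** 0.5
    # Sample (differentiable) attention weights instead of taking a softmax.
    weights = F.gumbel_softmax(scores, tau=tau, dim=-1)
    return weights @ v

q = k = v = torch.randn(2, 5, 16)
out1, out2 = stochastic_attention(q, k, v), stochastic_attention(q, k, v)
print(torch.allclose(out1, out2))  # False: two passes give different outputs
```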
The models are implemented with PyTorch 1.8.1 and Python 3.7.6. Configure the experimental environment with the following steps:
```bash
conda create -n pytorch_latest_p37 python=3.7 anaconda  # create the virtual environment
source activate pytorch_latest_p37                      # activate the environment
sh setup.sh                                             # install all dependent packages
```
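Once the environment is created, a quick sanity check (our suggestion, not part of the release scripts) confirms the expected versions:

```python
# Sanity check for the expected environment (Python 3.7.6, PyTorch 1.8.1).
import sys
import torch

print(sys.version.split()[0])  # expect 3.7.x
print(torch.__version__)       # expect 1.8.1
```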
We run experiments on two datasets:
- IMDB: https://pytorch.org/text/_modules/torchtext/datasets/imdb.html
- CoLA: https://nyu-mll.github.io/CoLA/
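For reference, the IMDB splits can also be loaded directly through torchtext. The sketch below assumes the torchtext 0.9.x release that pairs with PyTorch 1.8.1; in this repository, preprocessing is instead driven by `Run.py` as shown next.

```python
# Sketch, assuming torchtext 0.9.x (the release paired with PyTorch 1.8.1):
# each item of the raw IMDB iterators is a (label, text) pair.
from torchtext.datasets import IMDB

train_iter, test_iter = IMDB(split=('train', 'test'))
label, text = next(iter(train_iter))
print(label, text[:80])
```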
For IMDB:
- Preprocessing:
```bash
python code/IMDB/Run.py --mode=pre --model_name=IMDB --model_type=tf-sto --exp_name=default --job_id=123456 --debug=0
```
- Train & test `N_RUN` times with uncertainty (a generic sketch of the N-run aggregation follows this subsection):
```bash
python code/IMDB/Run.py --mode=uncertain-train-test --model_name=IMDB --model_type=tf-sto --exp_name=single_t1 --debug=0
```
More details can be found in `code/IMDB/README.md`.
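The `uncertain-train-test` mode repeats evaluation to obtain an uncertainty estimate. The sketch below illustrates the general recipe of aggregating N stochastic forward passes into a predictive mean and variance; `model` and `inputs` are placeholders, not this repository's API.

```python
# Generic sketch of N-run uncertainty estimation: aggregate the class
# probabilities of N stochastic forward passes into a predictive mean
# (the prediction) and a variance (the uncertainty). `model` and `inputs`
# are placeholders, not this repository's API.
import torch

def predict_with_uncertainty(model, inputs, n_runs=10):
    model.train()  # keep stochastic components (e.g. sampled attention) active
    with torch.no_grad():
        probs = torch.stack(
            [torch.softmax(model(inputs), dim=-1) for _ in range(n_runs)]
        )                        # (n_runs, batch, n_classes)
    mean = probs.mean(dim=0)     # predictive distribution
    var = probs.var(dim=0)       # per-class uncertainty
    return mean, var
```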
For CoLA:
- Download the CoLA dataset from the repository (https://github.com/pranavajitnair/CoLA).
- To train and validate the model, run:
```bash
python train.py --model_type sto_transformer --inference True --sto_transformer True --model_name dual --dual True
```
More details can be found in `code/CoLA/README.md`.
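CoLA is conventionally scored with the Matthews correlation coefficient (MCC). Below is a minimal scoring sketch with scikit-learn, where `y_true`/`y_pred` are placeholder 0/1 acceptability labels; the repository's own evaluation pipeline may differ.

```python
# Score CoLA predictions with the Matthews correlation coefficient,
# the task's standard metric. y_true/y_pred are placeholder labels.
from sklearn.metrics import matthews_corrcoef

y_true = [1, 0, 1, 1, 0]
y_pred = [1, 0, 0, 1, 0]
print(matthews_corrcoef(y_true, y_pred))
```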
- Emails:
- Jiahuan Pei, jpei@amazon.com
- Cheng Wang, cwngam@amazon.com
- György Szarvas, szarvasg@amazon.com
- Paper
- Citation (BibTeX):
```bibtex
@inproceedings{pei2022transformer,
  title     = {Transformer uncertainty estimation with hierarchical stochastic attention},
  author    = {Pei, Jiahuan and Wang, Cheng and Szarvas, Gy{\"o}rgy},
  booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
  volume    = {36},
  number    = {10},
  pages     = {11147--11155},
  year      = {2022}
}
```
See CONTRIBUTING for more information.
This project is licensed under the Apache-2.0 License.