Semi-supervised text classification based on a BERT backbone. The project adapts the FixMatch algorithm (https://arxiv.org/abs/2001.07685) by introducing adaptive weak/strong augmentation selection among six basic NLP augmentations (from the nlpaug library):
- WordEmbsAug (random substitution with top-n similar words)
- BackTranslationAug (back translation)
- AbstSummAug (abstractive summarization)
- SynonymAug (random synonym substitution)
- ContextualWordEmbsAug (random substitution with contextual word embeddings)
- ContextualWordEmbsForSentenceAug (extra sentence generation)
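For context, the core FixMatch rule that the project builds on can be sketched in a few lines of plain Python: a weakly augmented example produces a pseudo-label only when the model's confidence clears a threshold, and the strongly augmented version is then trained against that pseudo-label. This is an illustrative sketch, not code from this repository; all names are made up.

```python
import math

def softmax(logits):
    """Convert raw logits to probabilities."""
    exps = [math.exp(x - max(logits)) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def pseudo_label(weak_logits, threshold=0.9):
    """FixMatch pseudo-labelling for one unlabelled example.

    Returns (predicted class, keep-mask) based on the prediction for the
    weakly augmented input.
    """
    probs = softmax(weak_logits)
    confidence = max(probs)
    label = probs.index(confidence)
    # The example contributes to the unsupervised loss only if the model
    # is sufficiently confident on the weak augmentation.
    return label, confidence >= threshold

# Confident prediction -> pseudo-label is kept
print(pseudo_label([0.1, 5.0, -2.0], threshold=0.9))  # (1, True)
# Uncertain prediction -> example is masked out
print(pseudo_label([0.2, 0.1, 0.0], threshold=0.9))   # (0, False)
```

The strong augmentation of a kept example is then pushed (via cross-entropy) towards the pseudo-label.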
The project is built on:
- PyTorch Lightning - deep learning models
- Hydra - command-line argument management
- MLflow - experiment tracking
- Make sure you have Python 3.7
- Create a virtual environment and install the requirements:
```
pip install virtualenv
virtualenv venv
source venv/bin/activate
pip3 install -r requirements.txt
```
One can either run their own MLflow server:
```
mlflow server --default-artifact-root='/home/ubuntu/semi-supervised-stance-detection/mlruns/'
```
or connect to an existing one (LRZ server: 10.195.1.127):
```
ssh -N -f -L localhost:5000:localhost:5000 <user>@10.195.1.127
```
While running scripts, one should indicate the path to the dataset. There are two possible scenarios:
- In-topic scenario. The train/test/validation split is done randomly, without considering the topics. The `data.test_id` run argument should be `None`, and the files should be structured in the following way:
```
├── data
│   ├── <dataset-name>                 <- Dataset name
│   │   ├── train.tsv                  <- Train labelled data
│   │   ├── augmentations_labelled     <- Train labelled data augmentations
│   │   │   ├── SynonymAug.tsv
│   │   │   ├── WordEmbsAug.tsv
│   │   │   └── ...
│   │   ├── unlabelled.tsv             <- Train unlabelled data
│   │   ├── augmentations_unlabelled   <- Train unlabelled data augmentations
│   │   │   ├── SynonymAug.tsv
│   │   │   ├── WordEmbsAug.tsv
│   │   │   └── ...
│   │   ├── test.tsv                   <- Test data
│   │   └── val.tsv                    <- Val data
...
```
To generate augmentations, see the Offline augmentations section.
To generate offline augmentations for the fully-supervised/semi-supervised settings:
```
PYTHONPATH=. python3 runnables/generate_augmentations.py
```
All configurations are in the .yaml format and can be found in the config/ folder.

Fully-supervised experiments (config/config.yaml and config/setting/supervised.yaml):
```
PYTHONPATH=. python3 ./runnables/train.py -m +setting=supervised \
    data.path='data/IMDB-clean' \
    optimizer.lr=1e-6 \
    exp.task_name=SL \
    data.labels_list=[neg,pos] \
    exp.gpus='-1' \
    exp.logging=True \
    exp.max_epochs=1000 \
    data.max_seq_length=512 \
    exp.early_stopping_patience=1000 \
    data.augment=True
```
```
PYTHONPATH=. python3 ./runnables/train.py -m +setting=supervised \
    data.path='data/in-topic/REVIEWS-clean' \
    optimizer.lr=1e-6 \
    exp.task_name=SL3 \
    exp.gpus='-1' \
    exp.logging=True \
    exp.max_epochs=1000 \
    data.max_seq_length=512 \
    exp.early_stopping_patience=1000 \
    data.augment=True
```
Semi-supervised setting (config/config.yaml and config/setting/ssl.yaml):
```
PYTHONPATH=. python3 ./runnables/train.py -m +setting=ssl \
    data.path='data/IMDB-clean' \
    exp.task_name=SSL \
    data.labels_list=[neg,pos] \
    exp.logging=True \
    exp.gpus="-1" \
    model.threshold=0.9 \
    model.lambda_u=0.01 \
    optimizer.lr=1e-6 \
    exp.early_stopping_patience=5000 \
    data.max_seq_length=512 \
    model.max_ul_batch_size_per_gpu=200 \
    model.choose_only_wrongly_predicted_branches=True \
    exp.tsa=False \
    exp.max_epochs=5000
```
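Two of these overrides map directly onto FixMatch's objective: `model.threshold` is the confidence cut-off for pseudo-labels, and `model.lambda_u` weights the unsupervised term. A toy sketch of the standard FixMatch combination, with illustrative names rather than the repository's actual API:

```python
def fixmatch_loss(sup_loss, ul_losses, ul_confidences,
                  threshold=0.9, lambda_u=0.01):
    """Total loss = supervised loss + lambda_u * masked unsupervised loss.

    ul_losses: per-example cross-entropy of strong augmentations against
    their pseudo-labels; ul_confidences: max softmax probability of the
    corresponding weak augmentations.
    """
    # Keep only unlabelled examples whose weak prediction clears the threshold.
    masked = [l for l, c in zip(ul_losses, ul_confidences) if c >= threshold]
    ul_loss = sum(masked) / len(ul_losses) if ul_losses else 0.0
    return sup_loss + lambda_u * ul_loss

# Only the second unlabelled example (confidence 0.95) passes the 0.9
# threshold, so the unsupervised term is 1.0 / 2 = 0.5, weighted by 0.01:
print(fixmatch_loss(0.5, [2.0, 1.0], [0.6, 0.95]))  # 0.505
```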
```
PYTHONPATH=. python3 ./runnables/train.py -m +setting=ssl \
    data.path='data/in-topic/REVIEWS-clean' \
    exp.task_name=SSL3 \
    exp.logging=True \
    exp.gpus="-1" \
    model.threshold=0.9 \
    model.lambda_u=0.01 \
    optimizer.lr=1e-5 \
    exp.early_stopping_patience=1000 \
    data.max_seq_length=512 \
    model.max_ul_batch_size_per_gpu=200 \
    model.choose_only_wrongly_predicted_branches=True \
    exp.tsa=False \
    exp.max_epochs=1000
```
See scripts/train.sh or scripts/hparam_search.sh:
```
sbatch train.sh
```
or
```
sbatch hparam_search.sh
```
We experiment with the OpenReview peer-review dataset with 3 classes. Reported: test accuracy and test F1 score (after hyperparameter tuning w.r.t. validation accuracy).
| Method | Number of classes | Test Accuracy | Test F1 (macro) score |
|---|---|---|---|
| Fully-Supervised | 3 (pos / neg / non-stance) | 71.6312 | 70.7114 |
| FixMatch (Semi-supervised) | 3 (pos / neg / non-stance) | 72.3404 | 72.2475 |
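For reference, the macro F1 reported above is the unweighted mean of per-class F1 scores, so all three classes count equally regardless of their size. A minimal sketch:

```python
def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores."""
    f1s = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the denominator is 0
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return sum(f1s) / num_classes
```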
The gain in accuracy is only marginal. FixMatch relies heavily on diverse augmentations, which are effective for image classification; unfortunately, augmentations for text data do not provide enough inductive bias for semi-supervised text classification.
Another issue we faced was the generation of augmentations, as it could not be tensorized. Thus, it was performed offline, and the augmentations were saved as temporary files.
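The offline workflow amounts to running each augmenter once over the corpus and caching the result as one TSV per augmenter, matching the dataset layout shown earlier. A sketch with a hypothetical augmenter callable (real augmenters would come from nlpaug):

```python
import csv
import os

def cache_augmentations(texts, augmenters, out_dir):
    """Write one <AugmenterName>.tsv per augmenter.

    augmenters: mapping from augmenter name (e.g. "SynonymAug") to a
    callable str -> str. Each row stores (original, augmented) text.
    """
    os.makedirs(out_dir, exist_ok=True)
    for name, augment in augmenters.items():
        with open(os.path.join(out_dir, f"{name}.tsv"), "w", newline="") as f:
            writer = csv.writer(f, delimiter="\t")
            for text in texts:
                writer.writerow([text, augment(text)])
```

At training time, the cached rows can then be loaded instead of invoking the augmenters on the fly.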
The project is based on the cookiecutter data science project template. #cookiecutterdatascience