Here is a new boy 🙇 who wants to become an NLPer, and this is his repository for text classification. Besides TextCNN and TextAttnBiLSTM, more models will be added in the near future.
Thanks for your Star :star:, Fork and Watch!
- Stanford Sentiment Treebank (SST)
- SST-1: 5 classes (fine-grained); SST-2: 2 classes (binary)
- Preprocess
- Map sentiment values to labels
- Remove tokens consisting entirely of non-alphanumeric characters, such as
...
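As a rough illustration of these two preprocessing steps, here is a minimal sketch. It assumes the conventional SST cut-offs ([0, 0.2], (0.2, 0.4], (0.4, 0.6], (0.6, 0.8], (0.8, 1.0] for the five classes, with neutral phrases dropped for SST-2). The function names are hypothetical, and `preprocess.py` may implement this differently.

```python
import re

# Upper bounds of the 5 fine-grained bins for SST sentiment values in [0, 1]
# (assumed conventional split): 0 = very negative ... 4 = very positive.
FINE_GRAINED_BOUNDS = [0.2, 0.4, 0.6, 0.8]


def to_sst1_label(value: float) -> int:
    """Map a sentiment value in [0, 1] to a 5-class label."""
    for label, upper in enumerate(FINE_GRAINED_BOUNDS):
        if value <= upper:
            return label
    return 4


def to_sst2_label(value: float):
    """Map to 0 (negative) / 1 (positive), or None for neutral phrases."""
    label = to_sst1_label(value)
    if label == 2:
        return None  # neutral phrases are dropped in the binary setting
    return 0 if label < 2 else 1


# A token is kept only if it contains at least one ASCII letter or digit.
_ALNUM = re.compile(r"[A-Za-z0-9]")


def clean_tokens(tokens):
    """Drop tokens made up entirely of non-alphanumeric characters."""
    return [tok for tok in tokens if _ALNUM.search(tok)]
```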
- Word2Vec: `GoogleNews-vectors-negative300.bin`
- GloVe: `glove.840B.300d.txt`
- Because GloVe has a lower OOV rate than Word2Vec on these datasets and also gave better results in our experiments, we use GloVe as the pre-trained word vectors.
- Loading options for both word-vector formats are still kept in the code.
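A minimal sketch of how the OOV rate can be measured and how an embedding matrix can be built from GloVe's text format (one token followed by its vector values on each line). The helper names and the U(-0.25, 0.25) initialization for OOV words are assumptions, not necessarily what this repository does. The binary Word2Vec file would instead be read with gensim, e.g. `KeyedVectors.load_word2vec_format(path, binary=True)`.

```python
import random


def load_glove_lines(lines, dim):
    """Parse GloVe text format: a token followed by `dim` floats per line."""
    vectors = {}
    for line in lines:
        parts = line.rstrip().split(" ")
        if len(parts) <= dim:
            continue  # skip malformed or empty lines
        # Take the last `dim` fields as the vector so that tokens which
        # themselves contain spaces are still parsed correctly.
        word = " ".join(parts[:-dim])
        vectors[word] = [float(v) for v in parts[-dim:]]
    return vectors


def oov_rate(vocab, vectors):
    """Fraction of dataset vocabulary words without a pre-trained vector."""
    if not vocab:
        return 0.0
    missing = sum(1 for word in vocab if word not in vectors)
    return missing / len(vocab)


def build_embedding_matrix(itos, vectors, dim, seed=0):
    """One row per vocabulary word (in `itos` order).

    Words found in `vectors` reuse the pre-trained vector; OOV words get a
    small random vector (hypothetical choice: uniform in [-0.25, 0.25]).
    """
    rng = random.Random(seed)
    matrix = []
    for word in itos:
        if word in vectors:
            matrix.append(list(vectors[word]))
        else:
            matrix.append([rng.uniform(-0.25, 0.25) for _ in range(dim)])
    return matrix
```

Comparing `oov_rate(vocab, glove_vectors)` with `oov_rate(vocab, word2vec_vectors)` over the same dataset vocabulary is one way to reproduce the comparison above.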
- TextCNN
  - Paper: Convolutional Neural Networks for Sentence Classification
  - See: `models/TextCNN.py`
- TextAttnBiLSTM
  - Paper: Attention-Based Bidirectional LSTM for Text Classification
  - See: `models/TextAttnBiLSTM.py`
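To make the difference between the two models concrete, here is a framework-free sketch of their core pooling steps: a single TextCNN-style filter (n-gram convolution, ReLU, max-over-time pooling) and attention pooling over BiLSTM hidden states (score each step with w·tanh(h_t), softmax the scores, take the weighted sum, then apply tanh). This is an illustrative assumption of a common formulation; the actual layers in `models/TextCNN.py` and `models/TextAttnBiLSTM.py` are PyTorch modules and may differ in detail.

```python
import math


def conv_maxpool(embeddings, kernel, bias=0.0):
    """One TextCNN-style filter over a sentence.

    embeddings: list of word vectors (sentence length x embedding dim)
    kernel:     list of `width` weight rows, one per word in the window
    Slides the window over every n-gram, applies ReLU, then keeps the
    maximum activation (max-over-time pooling).
    """
    width = len(kernel)
    if len(embeddings) < width:
        raise ValueError("sentence shorter than filter width; pad it first")
    activations = []
    for start in range(len(embeddings) - width + 1):
        window = embeddings[start:start + width]
        s = bias + sum(k * x
                       for k_row, x_row in zip(kernel, window)
                       for k, x in zip(k_row, x_row))
        activations.append(max(0.0, s))  # ReLU
    return max(activations)


def attention_pool(hidden_states, w):
    """Attention pooling over BiLSTM outputs (one vector per time step).

    score_t = w . tanh(h_t);  alpha = softmax(scores)
    r = sum_t alpha_t * h_t;  returns (tanh(r), alpha)
    """
    scores = [sum(wi * math.tanh(hi) for wi, hi in zip(w, h))
              for h in hidden_states]
    top = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]
    dim = len(hidden_states[0])
    r = [sum(a * h[j] for a, h in zip(alphas, hidden_states))
         for j in range(dim)]
    return [math.tanh(x) for x in r], alphas
```

In a full TextCNN, many such filters of several widths are pooled and their outputs concatenated before the fully-connected classifier; in the attention model, the pooled vector plays the same role.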
- Baseline results from the paper (accuracy, %)
model | SST-1 | SST-2 |
---|---|---|
CNN-rand | 45.0 | 82.7 |
CNN-static | 45.5 | 86.8 |
CNN-non-static | 48.0 | 87.2 |
CNN-multichannel | 47.4 | 88.1 |
- Re-implementation results (accuracy, %)
model | SST-1 | SST-2 |
---|---|---|
CNN-rand | 34.841 | 74.500 |
CNN-static | 45.056 | 84.125 |
CNN-non-static | 46.974 | 85.886 |
CNN-multichannel | 45.129 | 85.993 |
Attention + BiLSTM | 47.015 | 85.632 |
Attention + BiGRU | 47.854 | 85.102 |
Please install the following libraries first:
pandas==0.24.2
torch==1.1.0
fire==0.1.3
numpy==1.16.2
gensim==3.7.3
│ .gitignore
│ config.py # Global Configuration
│ datasets.py # Create Dataloader
│ main.py
│ preprocess.py
│ README.md
│ requirements.txt
│ utils.py
│
├─checkpoints # Save checkpoint and best model
│
├─data # pretrained word vectors and datasets
│ │ glove.6B.300d.txt
│ │ GoogleNews-vectors-negative300.bin
│ └─stanfordSentimentTreebank # datasets folder
│
├─models
│ TextAttnBiLSTM.py
│ TextCNN.py
│ __init__.py
│
└─output_data # Preprocessed data and vocabulary, etc.
- Set global configuration parameters in `config.py`
- Preprocess the datasets: `$ python preprocess.py`
- Train: `$ python main.py run`
  Parameters defined in `config.py` and in `models/TextCNN.py` or `models/TextAttnBiLSTM.py` can also be set from the command line: `$ python main.py run [--option=VALUE]`
  For example: `$ python main.py run --status='train' --use_model="TextAttnBiLSTM"`
- Test: `$ python main.py run --status='test' --best_model="checkpoints/BEST_checkpoint_SST-2_TextCNN.pth"`
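The `--option=VALUE` override can be pictured with the following sketch. It assumes `main.py` exposes `run` through `fire` (listed in the requirements), which passes each flag to `run` as a keyword argument. `DefaultConfig`, its default values and the validation rules are hypothetical stand-ins for `config.py`; only the option names `status`, `use_model` and `best_model` come from the commands above.

```python
class DefaultConfig:
    """Hypothetical stand-in for the global settings kept in config.py."""
    status = "train"        # "train" or "test"
    use_model = "TextCNN"   # or "TextAttnBiLSTM"
    best_model = None       # checkpoint path, needed when status == "test"

    def parse(self, kwargs):
        """Override defaults with keyword arguments, rejecting unknown options."""
        for key, value in kwargs.items():
            if not hasattr(self, key):
                raise AttributeError(f"unknown option: --{key}")
            # Sets an instance attribute, so the class defaults stay untouched.
            setattr(self, key, value)
        return self


def run(**kwargs):
    """Entry point; with fire, each --option=VALUE flag arrives here as a kwarg."""
    opt = DefaultConfig().parse(kwargs)
    if opt.status == "test" and not opt.best_model:
        raise ValueError("--best_model is required when --status='test'")
    # Training or evaluation of opt.use_model would start here.
    return opt
```

With `fire`, the training command above would end up calling `run(status='train', use_model='TextAttnBiLSTM')`.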
- The `TextCNN` model uses n-gram-like convolution kernels to extract features, while the `TextAttnBiLSTM` model uses a BiLSTM to capture semantics and long-term dependencies, combined with an attention mechanism for classification.
- TextCNN parameter tuning:
  - GloVe works better than Word2Vec
  - Use a smaller batch size
  - Add weight decay ($l_2$ constraint), learning rate decay, early stopping, etc.
  - Do not set `padding_idx=0` in the embedding layer
- TextAttnBiLSTM parameter tuning:
  - Apply dropout on the embedding layer, the LSTM layer, and the fully-connected layer
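Early stopping and learning-rate decay are simple to express; below is a minimal, framework-free sketch. The patience, step size and decay factor are hypothetical values, not the repository's settings. In PyTorch itself, weight decay is usually passed to the optimizer (e.g. `torch.optim.Adam(params, lr=..., weight_decay=...)`) and step decay is available as `torch.optim.lr_scheduler.StepLR`.

```python
class EarlyStopping:
    """Signal a stop when validation accuracy has not improved for `patience` epochs."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_acc):
        """Record one epoch's validation accuracy; return True if training should stop."""
        if val_acc > self.best:
            self.best = val_acc
            self.bad_epochs = 0  # improvement: reset the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def step_decay(base_lr, epoch, step_size=5, gamma=0.5):
    """Multiply the learning rate by `gamma` once every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)
```

A training loop would call `step()` after each validation pass and save a checkpoint whenever `best` improves, so the best model (not the last one) is kept.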
- Motivated by https://github.com/TobiasLee/Text-Classification
- Thanks to https://github.com/bigboNed3/chinese_text_cnn
- Thanks to https://github.com/ShawnyXiao/TextClassification-Keras
[1] Convolutional Neural Networks for Sentence Classification
[3] Attention-Based Bidirectional LSTM for Text Classification