
Implementation of papers for the text classification task on SST-1/SST-2


Text-Classification-PyTorch 🐋

Here is a new boy 🙇 who wants to become an NLPer, and this is his repository for text classification. Besides TextCNN and TextAttnBiLSTM, more models will be added in the near future.

Thanks for your Star :star:, Fork, and Watch!

Dataset

  • Stanford Sentiment Treebank (SST)
    • SST-1: 5 classes (fine-grained); SST-2: 2 classes (binary)
  • Preprocess
    • Map sentiment values to labels
    • Remove tokens consisting of all non-alphanumeric characters, such as ...
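The two preprocessing steps can be sketched as follows. This is a minimal sketch, not the code in preprocess.py: it assumes the raw SST sentiment values are floats in [0, 1] and bins them with the five-way cut-offs documented with the SST release ([0, 0.2], (0.2, 0.4], (0.4, 0.6], (0.6, 0.8], (0.8, 1.0]), dropping neutral sentences for SST-2. The function names are hypothetical.

```python
# Hypothetical helpers; the cut-offs follow the SST release notes,
# and preprocess.py may implement the mapping differently.

def sst1_label(value: float) -> int:
    """Map a sentiment value in [0, 1] to 5 classes (0 = very negative ... 4 = very positive)."""
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"sentiment value out of range: {value}")
    cutoffs = [0.2, 0.4, 0.6, 0.8]  # inclusive upper bounds of classes 0..3
    for label, upper in enumerate(cutoffs):
        if value <= upper:
            return label
    return 4


def sst2_label(value: float):
    """Binary label: 0 = negative, 1 = positive, None = neutral (dropped for SST-2)."""
    fine = sst1_label(value)
    if fine == 2:
        return None
    return 0 if fine < 2 else 1


def clean_tokens(tokens):
    """Drop tokens made up entirely of non-alphanumeric characters, e.g. '--' or '!'."""
    return [tok for tok in tokens if any(ch.isalnum() for ch in tok)]
```

For example, `clean_tokens(["it", "'s", "great", "!"])` keeps `"'s"` because it contains a letter, but drops `"!"`.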

Pre-trained Word Vectors

  • Word2Vec : GoogleNews-vectors-negative300.bin
  • GloVe : glove.840B.300d.txt
    • Because GloVe has a lower OOV rate than Word2Vec and also performed better in our experiments, we use GloVe as the pre-trained word vectors.
    • Options for loading word vectors in either format are still kept in the code.
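A sketch of how the OOV rate can be measured and the pre-trained vectors turned into an embedding matrix. It assumes a vocabulary dict `word2idx` (hypothetical name) and a plain GloVe text file with one `word v1 ... v_dim` entry per line and no header. Unknown words get random vectors from U(-0.25, 0.25); that range is an assumption, not a documented setting of this repo. `load_glove_txt` splits off the last `dim` fields because some tokens in glove.840B.300d.txt are reported to contain spaces.

```python
import numpy as np


def load_glove_txt(path, dim=300):
    """Read a GloVe .txt file (no header line) into a {word: vector} dict."""
    vectors = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split(" ")
            word = " ".join(parts[:-dim])  # tolerate tokens that contain spaces
            vectors[word] = np.asarray(parts[-dim:], dtype=np.float32)
    return vectors


def oov_rate(vocab, pretrained_words):
    """Fraction of vocabulary words with no pre-trained vector."""
    vocab = list(vocab)
    missing = sum(1 for w in vocab if w not in pretrained_words)
    return missing / len(vocab) if vocab else 0.0


def build_embedding_matrix(word2idx, pretrained, dim, scale=0.25, seed=0):
    """Rows follow vocabulary indices; OOV rows keep their U(-scale, scale) init."""
    rng = np.random.RandomState(seed)
    matrix = rng.uniform(-scale, scale, size=(len(word2idx), dim)).astype(np.float32)
    for word, idx in word2idx.items():
        if word in pretrained:  # works for dicts and gensim KeyedVectors alike
            matrix[idx] = pretrained[word]
    return matrix
```

For the Word2Vec binary, gensim's `KeyedVectors.load_word2vec_format(path, binary=True)` returns an object that supports `word in kv` and `kv[word]`, so it can be passed to `oov_rate` and `build_embedding_matrix` in place of the dict.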

Model
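As the Conclusion describes, TextCNN extracts n-gram-like features with convolution kernels, and TextAttnBiLSTM pools BiLSTM states with attention. Below is a minimal PyTorch sketch of both ideas, not the code in models/TextCNN.py or models/TextAttnBiLSTM.py. Kernel sizes (3, 4, 5) with 100 filters each and dropout 0.5 follow the settings reported in [1]; the attention form (a learned query vector scoring tanh(H)), the `freeze` and `rnn` arguments, and all hidden sizes are assumptions. The multichannel variant is not shown.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCNN(nn.Module):
    """CNN in the style of [1]: parallel convolutions + max-over-time pooling."""

    def __init__(self, vocab_size, embed_dim, num_classes, kernel_sizes=(3, 4, 5),
                 num_filters=100, dropout=0.5, pretrained=None, freeze=False):
        super().__init__()
        # No padding_idx, in line with the tuning tip in the Conclusion.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        if pretrained is not None:           # pretrained=None ~ CNN-rand
            self.embedding.weight.data.copy_(pretrained)
        # freeze=True ~ CNN-static, freeze=False ~ CNN-non-static (assumed mapping)
        self.embedding.weight.requires_grad = not freeze
        self.convs = nn.ModuleList(
            [nn.Conv1d(embed_dim, num_filters, k) for k in kernel_sizes])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(kernel_sizes), num_classes)

    def forward(self, x):                            # x: (batch, seq_len) token ids
        emb = self.embedding(x).transpose(1, 2)      # (batch, embed_dim, seq_len)
        # one bank of n-gram detectors per kernel size, max-pooled over time
        pooled = [F.relu(conv(emb)).max(dim=2)[0] for conv in self.convs]
        return self.fc(self.dropout(torch.cat(pooled, dim=1)))


class TextAttnBiLSTM(nn.Module):
    """BiLSTM (or BiGRU) whose outputs are pooled by a learned attention vector."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes,
                 dropout=0.5, rnn=nn.LSTM):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = rnn(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        # dropout on embedding, recurrent output and FC input, per the tuning tip
        self.embed_dropout = nn.Dropout(dropout)
        self.rnn_dropout = nn.Dropout(dropout)
        self.fc_dropout = nn.Dropout(dropout)
        self.w = nn.Parameter(0.1 * torch.randn(2 * hidden_dim))  # attention query (assumed)
        self.fc = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):                            # x: (batch, seq_len) token ids
        H, _ = self.rnn(self.embed_dropout(self.embedding(x)))  # (batch, seq_len, 2*hidden)
        H = self.rnn_dropout(H)
        alpha = F.softmax(torch.tanh(H) @ self.w, dim=1)         # (batch, seq_len) weights
        r = (H * alpha.unsqueeze(-1)).sum(dim=1)                 # weighted sum over time
        return self.fc(self.fc_dropout(torch.tanh(r)))
```

`pretrained` is expected to be a FloatTensor of shape (vocab_size, embed_dim). Inputs to TextCNN must be at least as long as the largest kernel (5 tokens here), so short sentences need padding. Passing `rnn=nn.GRU` gives the BiGRU variant, since both modules return `(output, hidden)`.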

Results

  • Baseline from the paper [1] (accuracy, %)
Model               SST-1    SST-2
CNN-rand            45.0     82.7
CNN-static          45.5     86.8
CNN-non-static      48.0     87.2
CNN-multichannel    47.4     88.1
  • Re-implementation (accuracy, %)
Model               SST-1    SST-2
CNN-rand            34.841   74.500
CNN-static          45.056   84.125
CNN-non-static      46.974   85.886
CNN-multichannel    45.129   85.993
Attention + BiLSTM  47.015   85.632
Attention + BiGRU   47.854   85.102

Requirements

Please install the following library requirements first.

pandas==0.24.2
torch==1.1.0
fire==0.1.3
numpy==1.16.2
gensim==3.7.3

Structure

│  .gitignore
│  config.py            # Global Configuration
│  datasets.py          # Create Dataloader
│  main.py
│  preprocess.py
│  README.md
│  requirements.txt
│  utils.py
│  
├─checkpoints           # Save checkpoint and best model
│      
├─data                  # pretrained word vectors and datasets
│  │  glove.6B.300d.txt
│  │  GoogleNews-vectors-negative300.bin
│  └─stanfordSentimentTreebank # datasets folder
│          
├─models
│      TextAttnBiLSTM.py
│      TextCNN.py
│      __init__.py
│      
└─output_data           # Preprocessed data and vocabulary, etc.

Usage

  • Set global configuration parameters in config.py

  • Preprocess the datasets

$python preprocess.py
  • Train
$python main.py run

Parameters defined in config.py, models/TextCNN.py, or models/TextAttnBiLSTM.py can also be set on the command line:

$python main.py run [--option=VALUE]

For example,

$python main.py run --status='train' --use_model="TextAttnBiLSTM"
  • Test
$python main.py run --status='test' --best_model="checkpoints/BEST_checkpoint_SST-2_TextCNN.pth"
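The `run [--option=VALUE]` syntax matches the common python-fire pattern, in which `fire.Fire()` turns module-level functions such as `run` into subcommands and `--option=VALUE` flags into keyword arguments. A minimal sketch of that pattern, assuming a `DefaultConfig` class (hypothetical name) whose attributes are overridden by the flags; only `status`, `use_model` and `best_model` come from the commands above, and the other defaults are placeholders.

```python
# Hypothetical sketch of a config.py / main.py pair driven by python-fire.

class DefaultConfig:
    status = "train"          # "train" or "test"
    use_model = "TextCNN"     # or "TextAttnBiLSTM"
    best_model = None         # checkpoint path used when status == "test"
    batch_size = 50           # placeholder default (assumption)
    lr = 1e-3                 # placeholder default (assumption)

    def parse(self, kwargs):
        """Override class defaults with command-line options, rejecting unknown names."""
        for key, value in kwargs.items():
            if not hasattr(self, key):
                raise AttributeError(f"unknown option: --{key}")
            setattr(self, key, value)
        return self


def run(**kwargs):
    opt = DefaultConfig().parse(kwargs)
    print(f"status={opt.status} model={opt.use_model} best_model={opt.best_model}")
    # a real main.py would build the dataloaders and then train or evaluate here
    return opt


if __name__ == "__main__":
    import fire
    fire.Fire()  # exposes run(), so `python main.py run --status=test` calls run(status="test")
```

Rejecting unknown names makes a mistyped flag fail loudly instead of being silently ignored.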

Conclusion

  • The TextCNN model extracts features with n-gram-like convolution kernels, while the TextAttnBiLSTM model uses a BiLSTM to capture semantics and long-term dependencies and an attention mechanism to weight its hidden states for classification.
  • TextCNN parameter tuning:
    • GloVe works better than Word2Vec
    • Use a smaller batch size
    • Add weight decay (an $l_2$ penalty), learning rate decay, early stopping, etc.
    • Do not set padding_idx=0 in the embedding layer
  • TextAttnBiLSTM parameter tuning:
    • Apply dropout to the embedding layer, the LSTM layer, and the fully-connected layer
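The training tips above map onto standard PyTorch pieces. A minimal sketch, assuming Adam (the repo's optimizer is not stated), a `StepLR` schedule for learning-rate decay, and a small hypothetical `EarlyStopping` helper that watches dev accuracy; all numeric defaults are placeholders, not the repo's settings. Note that `weight_decay` in `torch.optim.Adam` adds an $l_2$ penalty to the gradients, which differs from the max-norm constraint on weight vectors used in [1].

```python
import torch


def make_optimizer(model, lr=1e-3, weight_decay=1e-4, step_size=5, gamma=0.5):
    """Adam with L2 weight decay plus step-wise learning-rate decay (values are assumptions)."""
    params = [p for p in model.parameters() if p.requires_grad]  # skip frozen (static) embeddings
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    return optimizer, scheduler


class EarlyStopping:
    """Signal a stop once dev accuracy has not improved for `patience` epochs."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, dev_acc):
        """Record one epoch's dev accuracy; return True when training should stop."""
        if dev_acc > self.best:
            self.best = dev_acc
            self.bad_epochs = 0  # improvement: a best checkpoint would be saved here
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a training loop, call `optimizer.step()` for each batch, `scheduler.step()` once per epoch, and stop as soon as `stopper.step(dev_acc)` returns True, keeping the checkpoint from the best epoch.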

Acknowledgements

References

[1] Convolutional Neural Networks for Sentence Classification

[2] A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification

[3] Attention-Based Bidirection LSTM for Text Classification