pytorch-pQRNN

Implementation of pQRNN in PyTorch

Primary language: Python. License: MIT.



Installation

# install with pypi
pip install pytorch-pqrnn
# or install locally with poetry
poetry install

Environment

Because of an upstream issue, pytorch-qrnn is no longer compatible with newer PyTorch releases, and it is not actively maintained. If you want to use a QRNN layer in this model, you have to install pytorch-qrnn with torch <= 1.4 first.
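
The version constraint above can be checked before choosing rnn_type. The following is a minimal sketch and not part of this package: parse_major_minor and pick_rnn_type are hypothetical helpers, and the import name torchqrnn for pytorch-qrnn is an assumption.

```python
def parse_major_minor(version):
    """Turn a version string such as "1.7.1+cu110" into (1, 7)."""
    core = version.split("+")[0]          # drop a local build suffix
    major, minor = core.split(".")[:2]
    return int(major), int(minor)


def pick_rnn_type(torch_version, qrnn_installed, fallback="GRU"):
    """Return "QRNN" only when torch <= 1.4 and pytorch-qrnn is installed."""
    if qrnn_installed and parse_major_minor(torch_version) <= (1, 4):
        return "QRNN"
    return fallback


# Typical use (assumes torch is installed; "torchqrnn" is the assumed
# import name of pytorch-qrnn):
#   import importlib.util, torch
#   installed = importlib.util.find_spec("torchqrnn") is not None
#   rnn_type = pick_rnn_type(torch.__version__, installed)
```

The returned string can be passed as rnn_type when constructing the model, so the same script runs on both old and new torch installations.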

Usage

from pytorch_pqrnn.dataset import DummyDataset, create_dataloaders
from pytorch_pqrnn.model import PQRNN

model = PQRNN(
  b=128,
  d=96,
  lr=1e-3,
  num_layers=2,
  dropout=0.5,
  output_size=5,
  rnn_type="GRU",
  multilabel=False,
  nhead=2, # used when rnn_type == "Transformer"
)

# Or load the model from your checkpoint
# model = PQRNN.load_from_checkpoint(checkpoint_path="example.ckpt")

# Text data has to be pre-processed with DummyDataset.
# `df` is assumed to be a pandas DataFrame with "text" and "label" columns.
dataset = DummyDataset(
    df[["text", "label"]].to_dict("records"),
    has_label=True,
    feature_size=128 * 2,
    add_eos_tag=True,
    add_bos_tag=True,
    max_seq_len=512,
    label2index={"pos": 1, "neg": 0},
)

# Explicit train/val loop
# Add model.eval() when necessary
dataloader = create_dataloaders(dataset)
for batch in dataloader:
  # labels could be an empty tensor if has_label was False when the dataset was created.
  # To change what is included in a batch, edit the collate_fn function
  # in dataset.py.
  projections, lengths, labels = batch
  logits = model.forward(projections)

  # do your magic
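
The snippet above uses df without defining it. As a minimal sketch, assuming df is a pandas DataFrame with text and label columns, to_dict("records") yields one dict per row, and the same structure can be built without pandas. The toy texts and the make_records helper are hypothetical and only illustrate the expected shape of the input.

```python
def make_records(texts, labels, label2index):
    """Build the list-of-dicts structure that df.to_dict("records") would give."""
    unknown = sorted(set(labels) - set(label2index))
    if unknown:
        # Fail early instead of handing unmapped labels to the dataset.
        raise ValueError(f"labels missing from label2index: {unknown}")
    return [{"text": t, "label": l} for t, l in zip(texts, labels)]


label2index = {"pos": 1, "neg": 0}
records = make_records(
    ["great food and friendly staff", "cold fries, never again"],
    ["pos", "neg"],
    label2index,
)
print(records[0])
```

records can then stand in for df[["text", "label"]].to_dict("records") as the first argument of DummyDataset.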

CLI Usage

Usage: run.py [OPTIONS]

Options:
  --task [yelp2|yelp5|toxic]      [default: yelp5]
  --b INTEGER                     [default: 128]
  --d INTEGER                     [default: 96]
  --num_layers INTEGER            [default: 2]
  --batch_size INTEGER            [default: 512]
  --dropout FLOAT                 [default: 0.5]
  --lr FLOAT                      [default: 0.001]
  --nhead INTEGER                 [default: 4]
  --rnn_type [LSTM|GRU|QRNN|Transformer]
                                  [default: GRU]
  --data_path TEXT
  --help                          Show this message and exit.

Datasets

  • yelp2 (polarity): downloaded automatically with huggingface/datasets
  • yelp5: the json file should be downloaded into data_path
  • toxic: the dataset should be downloaded and unzipped into data_path

Example: Yelp Polarity

python -W ignore run.py --task yelp2 --b 128 --d 64 --num_layers 4 --data_path data/
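
To run several configurations, the command above can be assembled programmatically. This is a minimal sketch using only the standard library and the flags listed under CLI Usage; build_command is a hypothetical helper, and it assumes run.py is in the current working directory.

```python
import shlex
import sys


def build_command(task, data_path, **options):
    """Assemble an argv list for run.py from keyword options."""
    cmd = [sys.executable, "-W", "ignore", "run.py", "--task", task]
    for name, value in options.items():
        cmd += [f"--{name}", str(value)]
    cmd += ["--data_path", data_path]
    return cmd


cmd = build_command("yelp2", "data/", b=128, d=64, num_layers=4)
print(shlex.join(cmd[1:]))  # arguments after the interpreter path
# To launch it: subprocess.run(cmd, check=True)
```

Passing the list form to subprocess.run avoids shell quoting problems when data_path contains spaces.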

Benchmarks (not optimized)

| Model | Model Size | Yelp Polarity (error rate) | Yelp-5 (accuracy) | Civil Comments (mean auroc) | Command |
|---|---|---|---|---|---|
| PQRNN (this repo) [1] | 78K | 6.3 | 70.4 | TODO | --b 128 --d 64 --num_layers 4 --rnn_type QRNN |
| PRNN (this repo) | 90K | 5.5 | 70.7 | 95.57 | --b 128 --d 64 --num_layers 1 --rnn_type GRU |
| PTransformer (this repo) | 618K | 10.8 | 68 | 92.4 | --b 128 --d 64 --num_layers 1 --rnn_type Transformer --nhead 8 |
| PRADO [2] | 175K | | 65.9 | | |
| BERT | 335M | 1.81 | 70.58 | 98.856 [3] | |

  1. Not supported with torch >= 1.7
  2. Paper
  3. Best Kaggle Submission
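
The table mixes an error rate (Yelp Polarity) with accuracy (Yelp-5), so the columns cannot be compared directly. A minimal sketch of the conversion, assuming the error rates are percentages and using only the values reported above:

```python
# Yelp Polarity error rates (percent) copied from the benchmark table.
yelp_polarity_error = {
    "PQRNN": 6.3,
    "PRNN": 5.5,
    "PTransformer": 10.8,
    "BERT": 1.81,
}

# accuracy (%) = 100 - error rate (%)
yelp_polarity_accuracy = {
    model: round(100.0 - err, 2) for model, err in yelp_polarity_error.items()
}

for model, acc in yelp_polarity_accuracy.items():
    print(f"{model}: {acc}")
```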

Credits

Citation

@software{chenghao_mou_2021_4661601,
  author       = {Chenghao MOU},
  title        = {ChenghaoMou/pytorch-pQRNN: Add DOI},
  month        = apr,
  year         = 2021,
  publisher    = {Zenodo},
  version      = {0.0.3},
  doi          = {10.5281/zenodo.4661601},
  url          = {https://doi.org/10.5281/zenodo.4661601}
}