# install from PyPI
pip install pytorch-pqrnn
# or install locally with poetry
poetry install
Because of this issue, pytorch-qrnn is no longer compatible with recent pytorch releases, and it is not actively maintained. If you want to use a QRNN layer in this model, you have to install pytorch-qrnn with torch <= 1.4 first.
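The version constraint above can be checked before a QRNN layer is requested. The following is a minimal sketch, not part of this package: the helper names `qrnn_supported` and `pick_rnn_type` are hypothetical, the `<= 1.4` bound is taken from the note above, and falling back to GRU is an assumption that simply mirrors the default `rnn_type`.

```python
def qrnn_supported(torch_version: str) -> bool:
    """Return True if the given torch version string is 1.4.x or older."""
    # Drop local build suffixes such as "+cu101" before parsing.
    core = torch_version.split("+")[0]
    major, minor = (int(part) for part in core.split(".")[:2])
    return (major, minor) <= (1, 4)


def pick_rnn_type(requested: str, torch_version: str) -> str:
    """Fall back to GRU (an assumed choice) when QRNN is requested on an unsupported torch."""
    if requested == "QRNN" and not qrnn_supported(torch_version):
        return "GRU"
    return requested
```

In practice the second argument would be `torch.__version__`, for example `pick_rnn_type("QRNN", torch.__version__)`.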
# DummyDataset is assumed to live in dataset.py next to create_dataloaders
from pytorch_pqrnn.dataset import DummyDataset, create_dataloaders
from pytorch_pqrnn.model import PQRNN
model = PQRNN(
    b=128,
    d=96,
    lr=1e-3,
    num_layers=2,
    dropout=0.5,
    output_size=5,
    rnn_type="GRU",
    multilabel=False,
    nhead=2,  # used when rnn_type == "Transformer"
)
# Or load the model from your checkpoint
# model = PQRNN.load_from_checkpoint(checkpoint_path="example.ckpt")
# Text data has to be pre-processed with DummyDataset.
# df is assumed to be a pandas DataFrame with "text" and "label" columns.
dataset = DummyDataset(
    df[["text", "label"]].to_dict("records"),
    has_label=True,
    feature_size=128 * 2,
    add_eos_tag=True,
    add_bos_tag=True,
    max_seq_len=512,
    label2index={"pos": 1, "neg": 0},
)
# Explicit train/val loop
# Call model.eval() when necessary
dataloader = create_dataloaders(dataset)
for batch in dataloader:
    # labels can be an empty tensor if has_label was False when the dataset was created.
    # To change what is included in a batch, edit the collate_fn function
    # in dataset.py
    projections, lengths, labels = batch
    logits = model(projections)
    # do your magic
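One common next step inside the loop is turning logits back into label names. The sketch below is an illustration under assumptions, not the package's API: it assumes single-label classification (`multilabel=False`), logits of shape (batch, output_size) where `output_size` matches the number of labels, and the same `label2index` mapping that was passed to the dataset. The helper name `decode_predictions` is hypothetical, and it works on plain nested lists so it can be run without torch.

```python
from typing import Dict, List, Sequence


def decode_predictions(
    logits: Sequence[Sequence[float]], label2index: Dict[str, int]
) -> List[str]:
    """Map each row of logits to the label name with the highest score."""
    # Invert the label -> index mapping used when building the dataset.
    index2label = {index: label for label, index in label2index.items()}
    predictions = []
    for row in logits:
        # Argmax over the class dimension of a single example.
        best_index = max(range(len(row)), key=lambda i: row[i])
        predictions.append(index2label[best_index])
    return predictions
```

With a torch tensor, the rows can be obtained with `logits.tolist()`, for example `decode_predictions(logits.tolist(), {"pos": 1, "neg": 0})`.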
Usage: run.py [OPTIONS]

Options:
  --task [yelp2|yelp5|toxic]      [default: yelp5]
  --b INTEGER                     [default: 128]
  --d INTEGER                     [default: 96]
  --num_layers INTEGER            [default: 2]
  --batch_size INTEGER            [default: 512]
  --dropout FLOAT                 [default: 0.5]
  --lr FLOAT                      [default: 0.001]
  --nhead INTEGER                 [default: 4]
  --rnn_type [LSTM|GRU|QRNN|Transformer]
                                  [default: GRU]
  --data_path TEXT
  --help                          Show this message and exit.
Datasets
- yelp2 (polarity): downloaded automatically with huggingface/datasets
- yelp5: the json file should be downloaded into data_path
- toxic: the dataset should be downloaded and unzipped into data_path
python -W ignore run.py --task yelp2 --b 128 --d 64 --num_layers 4 --data_path data/
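When several configurations need to be run in a row, the command above can be assembled programmatically. This is a small sketch assuming the option names listed in the usage text; the helper `build_run_command` is hypothetical and not part of run.py.

```python
import shlex
from typing import Dict, List, Union

Option = Union[str, int, float]


def build_run_command(options: Dict[str, Option]) -> List[str]:
    """Build the argv list for run.py from a mapping of option names to values."""
    argv = ["python", "-W", "ignore", "run.py"]
    for name, value in options.items():
        # Each entry becomes "--name value", matching the usage text.
        argv.extend([f"--{name}", str(value)])
    return argv


command = build_run_command(
    {"task": "yelp2", "b": 128, "d": 64, "num_layers": 4, "data_path": "data/"}
)
# Print the shell-quoted form of the command for inspection.
print(shlex.join(command))
```

The resulting list can be passed to `subprocess.run(command, check=True)` to launch the run.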
Model | Model Size | Yelp Polarity (error rate) | Yelp-5 (accuracy) | Civil Comments (mean AUROC) | Command |
---|---|---|---|---|---|
PQRNN (this repo) [0] | | | | | --b 128 --d 64 --num_layers 4 --rnn_type QRNN |
PRNN (this repo) | 90K | 5.5 | 70.7 | 95.57 | --b 128 --d 64 --num_layers 1 --rnn_type GRU |
PTransformer (this repo) | 618K | 10.8 | 68 | 92.4 | --b 128 --d 64 --num_layers 1 --rnn_type Transformer --nhead 8 |
PRADO [1] | 175K | | 65.9 | | |
BERT | 335M | 1.81 | 70.58 | 98.856 [2] | |

- [0] Not supported with torch >= 1.7
- [1] Paper
- [2] Best Kaggle Submission
Powered by pytorch-lightning and grid.ai
@software{chenghao_mou_2021_4661601,
author = {Chenghao MOU},
title = {ChenghaoMou/pytorch-pQRNN: Add DOI},
month = apr,
year = 2021,
publisher = {Zenodo},
version = {0.0.3},
doi = {10.5281/zenodo.4661601},
url = {https://doi.org/10.5281/zenodo.4661601}
}