This is the official PyTorch implementation of the paper *Pretrained Generalized Autoregressive Model with Adaptive Probabilistic Label Clusters for Extreme Multi-label Text Classification*.
- Linux
- Python ≥ 3.6

  We recommend using Anaconda to create a conda environment:

  ```bash
  conda create --name aplc_xlnet python=3.6
  conda activate aplc_xlnet
  ```

- PyTorch ≥ 1.4.0

  ```bash
  conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
  ```

- Other requirements:

  ```bash
  pip install -r requirements.txt
  ```
Download our preprocessed datasets from Google Drive and save them to `data/`
- Create `train.csv` and `dev.csv`. Reference our preprocessed datasets for the format of the CSV files
- Create `labels.txt`. Labels should be sorted in descending order according to their frequency
- Count the number of positive labels of each sample, take the largest value over all samples, and assign it to the hyperparameter `--pos_label`
- Add the dataset name into the dictionaries `processors` and `output_modes` in the source file `utils_multi_label.py`
- Create the bash file and set the hyperparameters in `code/run/`
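Two of the steps above (sorting `labels.txt` by frequency and finding the value for `--pos_label`) can be scripted. Below is a minimal sketch, assuming each CSV row stores its labels as a space-separated string in a `labels` column — check the preprocessed datasets for the actual column name and separator before using it:

```python
import csv
from collections import Counter


def analyze_labels(csv_path, label_field="labels", sep=" "):
    """Scan a training CSV and return (label frequencies, max positive
    labels per sample). The second value is what --pos_label expects.

    Assumes each row keeps its labels as a `sep`-separated string in
    `label_field`; verify this against the preprocessed CSV format.
    """
    freq = Counter()
    max_pos = 0
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            labels = row[label_field].split(sep)
            freq.update(labels)
            max_pos = max(max_pos, len(labels))
    return freq, max_pos


def write_labels_txt(freq, out_path="labels.txt"):
    """Write labels one per line, sorted by descending frequency,
    as the labels.txt step above requires."""
    with open(out_path, "w", encoding="utf-8") as f:
        for label, _ in freq.most_common():
            f.write(label + "\n")
```

For example, `freq, max_pos = analyze_labels("data/mydata/train.csv")` followed by `write_labels_txt(freq, "data/mydata/labels.txt")` produces the sorted label file and the value to pass as `--pos_label`.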
Preprocessed data for AttentionXML

Download our preprocessed datasets from Google Drive
- For dataset EURlex, the raw text is from the website
- For dataset Wiki500k, the raw text is from Google Drive
- For dataset Wiki10, AmazonCat and Amazon670k, the raw texts are from The Extreme Classification Repository
Run the commands

- For dataset EURlex: `bash ./run/eurlex.bash`
- For dataset Wiki10: `bash ./run/wiki10.bash`
- For dataset AmazonCat: `bash ./run/amazoncat.bash`
- For dataset Wiki500k: `bash ./run/wiki500k.bash`
- For dataset Amazon670k: `bash ./run/amazon670k.bash`
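Each bash file above wraps the training entry point with dataset-specific hyperparameters. For a new dataset, copy one of the existing files in `code/run/` (e.g. `./run/eurlex.bash`) and adjust it. The sketch below only illustrates the general shape of such a file — the script name `main.py` and every flag except `--pos_label` are assumptions, not the repo's actual interface:

```shell
#!/usr/bin/env bash
# Hypothetical run file for a dataset registered as "mydata".
# main.py and all flags except --pos_label are placeholders;
# mirror an existing file in code/run/ for the real interface.
python main.py \
    --task_name mydata \
    --data_dir ./data/mydata \
    --pos_label 25 \
    --output_dir ./models/mydata
```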
- Download our pretrained models from Google Drive and save them to `models/`
- Run the commands
- For dataset EURlex: `bash ./run/eurlex_eval.bash`
- For dataset Wiki10: `bash ./run/wiki10_eval.bash`
- For dataset AmazonCat: `bash ./run/amazoncat_eval.bash`
- For dataset Wiki500k: `bash ./run/wiki500k_eval.bash`
- For dataset Amazon670k: `bash ./run/amazon670k_eval.bash`
- For dataset EURlex: