
Multi-relation Message Passing for Multi-label Text Classification (ICASSP 2022)

Primary language: Python. License: MIT.

MrMP

This repository contains a PyTorch implementation of our ICASSP 2022 paper Multi-relation Message Passing for Multi-label Text Classification.


Requirements

- python~=3.8
- torch~=1.10
- numpy~=1.21.2
- tqdm~=4.62.3
- scipy~=1.7.3
- pandas~=1.3.5
- scikit-learn~=1.0.2
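The `~=` entries use the compatible-release operator from PEP 440. For example, `numpy~=1.21.2` means `>=1.21.2` and `==1.21.*`, and `python~=3.8` means `>=3.8` and `==3.*`. The script below is a minimal sketch for checking an existing environment against the list above before installing anything. It assumes the PyPI distribution names match the names listed, and it handles only the numeric release segment of a version, ignoring pre-release and local tags such as `+cu113`. The helpers `release`, `compatible` and `check_environment` are hypothetical and not part of the repository.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import List, Optional, Tuple

# Compatible-release specs copied from the Requirements list (PyPI name -> version after "~=").
REQUIREMENTS = {
    "torch": "1.10",
    "numpy": "1.21.2",
    "tqdm": "4.62.3",
    "scipy": "1.7.3",
    "pandas": "1.3.5",
    "scikit-learn": "1.0.2",
}


def release(v: str) -> Tuple[int, ...]:
    """Leading numeric release segment, e.g. '1.10.0+cu113' -> (1, 10, 0).

    Simplification: pre-release, post-release and local tags are ignored.
    """
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(x) for x in m.group(0).split(".")) if m else ()


def compatible(installed: str, spec: str) -> bool:
    """PEP 440 '~=spec': installed >= spec and installed matches spec minus its last component."""
    have, want = release(installed), release(spec)
    n = max(len(have), len(want))
    # Pad with zeros so that 1.10 and 1.10.0 compare as equal.
    h = have + (0,) * (n - len(have))
    w = want + (0,) * (n - len(want))
    prefix = want[:-1]
    return h >= w and h[: len(prefix)] == prefix


def check_environment() -> bool:
    """Print one line per requirement and return True if every one is satisfied."""
    python = ".".join(str(x) for x in sys.version_info[:3])
    rows: List[Tuple[str, Optional[str], str]] = [("python", python, "3.8")]
    for dist, spec in REQUIREMENTS.items():
        try:
            rows.append((dist, version(dist), spec))
        except PackageNotFoundError:
            rows.append((dist, None, spec))
    all_ok = True
    for dist, installed, spec in rows:
        ok = installed is not None and compatible(installed, spec)
        all_ok = all_ok and ok
        status = "ok" if ok else "MISMATCH"
        print(f"{dist:<13} ~={spec:<7} installed: {installed or 'missing':<12} {status}")
    return all_ok


if __name__ == "__main__":
    check_environment()
```

Run it with the same interpreter you plan to use for `main.py`; any line marked `MISMATCH` is a package that `pip install -r requirements.txt` would need to add or change.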

Usage

  1. Install the required packages and their dependencies if your environment does not already have them:

    pip install -r requirements.txt
    
  2. Download and/or prepare the data:

    • To use the benchmark datasets from the paper, download them here.
    • To use your own dataset, convert it into the following format:
    data = {'train': {'src': List[List[int]], 'tgt': List[List[int]]},
            'valid': {'src': List[List[int]], 'tgt': List[List[int]]},
            'test' : {'src': List[List[int]], 'tgt': List[List[int]]}, 
            'dict' : {'src': Dict[int, str], 'tgt': Dict[int, str]}
            }
    
  3. Optionally, pass parameter settings to main.py as needed. Note that the hyperparameter defaults are already set to the tuned values reported in the paper. For example:

    dataset=bibtex
    name=mrmp
    python3.8 -u main.py -dataset $dataset -name $name -mrmp_on $true
    
  4. Run:

    python main.py -configuration config.json
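The custom-data format in step 2 can be produced from pre-tokenized text with a sketch like the one below. It assumes that `src` holds one list of token ids per document, that `tgt` holds one list of label ids per document, and that `dict` maps those ids back to token strings and label names. The reserved `<pad>`/`<unk>` ids, the train-only token vocabulary, and the helper names `build_vocab` and `build_dataset` are illustrative assumptions, not the repository's conventions. How the resulting dictionary should be saved, and where main.py loads it from, is not stated here, so check main.py before serializing it.

```python
from typing import Dict, List, Sequence, Tuple

# Illustrative reserved tokens; the repository's own special-token conventions may differ.
PAD, UNK = "<pad>", "<unk>"

# One split = (tokenized documents, label names for each document).
Split = Tuple[List[List[str]], List[List[str]]]


def build_vocab(seqs: Sequence[Sequence[str]], reserved: Sequence[str] = ()) -> Dict[str, int]:
    """Map each distinct string to a contiguous id: reserved entries first, then first-seen order."""
    vocab: Dict[str, int] = {tok: i for i, tok in enumerate(reserved)}
    for seq in seqs:
        for tok in seq:
            vocab.setdefault(tok, len(vocab))
    return vocab


def build_dataset(splits: Dict[str, Split]) -> dict:
    """Convert raw splits into the {'train', 'valid', 'test', 'dict'} layout from step 2."""
    # Token vocabulary from the training split only; unseen tokens fall back to <unk>.
    src_vocab = build_vocab(splits["train"][0], reserved=(PAD, UNK))
    # Label vocabulary from every split, so no label is left without an id.
    tgt_vocab = build_vocab([labs for _, labels in splits.values() for labs in labels])
    data: dict = {}
    for name in ("train", "valid", "test"):
        docs, labels = splits[name]
        data[name] = {
            "src": [[src_vocab.get(tok, src_vocab[UNK]) for tok in doc] for doc in docs],
            "tgt": [[tgt_vocab[lab] for lab in labs] for labs in labels],
        }
    # Inverse mappings: id -> token and id -> label name.
    data["dict"] = {
        "src": {i: tok for tok, i in src_vocab.items()},
        "tgt": {i: lab for lab, i in tgt_vocab.items()},
    }
    return data


# Toy example with made-up documents and labels.
toy = {
    "train": ([["deep", "graph", "learning"], ["text", "classification"]], [["ml", "graphs"], ["nlp"]]),
    "valid": ([["graph", "text"]], [["graphs", "nlp"]]),
    "test": ([["unseen", "words"]], [["ml"]]),
}
data = build_dataset(toy)
print(data["train"]["src"])  # [[2, 3, 4], [5, 6]]
print(data["test"]["src"])   # [[1, 1]]
print(data["dict"]["tgt"])   # {0: 'ml', 1: 'graphs', 2: 'nlp'}
```

Building the token vocabulary from the training split alone keeps validation and test words from leaking into it; the label vocabulary is built from all splits so that every label in `tgt` has an entry in `dict['tgt']`.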

Citation

@inproceedings{MrMP_Ozmen22,
	author    = {Ozmen, M. and Zhang, H. and Wang, P. and Coates, M.},
	title     = {Multi-relation Message Passing for Multi-label Text Classification},
	booktitle = {Proc. IEEE Int. Conf. Acoustics, Speech and Signal Processing (ICASSP)},
	month     = "May",
	year      = "2022"
}

The implementation is mainly adapted from [Transformer]. For questions or comments, please open an issue or contact Muberra.