LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). See the latest documentation for detailed introductions and API instructions.
- [Mar 10 2023]: Added QM9 and PAWS-X examples.
- [Jul 22 2022]: Added support for Nash-MTL (ICML 2022).
- [Jul 21 2022]: Added support for Learning to Branch (ICML 2020). Many thanks to @yuezhixiong (#14).
- [Mar 29 2022]: Our paper is now available on arXiv.
- Features
- Overall Framework
- Supported Algorithms
- Supported Benchmark Datasets
- Installation
- Quick Start
- Citation
- Contributor
- Contact Us
- Acknowledgements
- License
- Unified: LibMTL provides a unified code base for implementing MTL algorithms and a consistent evaluation procedure, covering data processing, metric objectives, and hyper-parameters, on several representative MTL benchmark datasets. This allows quantitative, fair, and consistent comparisons between different MTL algorithms.
- Comprehensive: LibMTL supports many state-of-the-art MTL methods, including 8 architectures and 13 optimization strategies. Meanwhile, LibMTL provides a fair comparison on several benchmark datasets covering different fields.
- Extensible: LibMTL follows modular design principles, which allow users to flexibly and conveniently add customized components or make personalized modifications. Therefore, users can easily and quickly develop novel optimization strategies and architectures, or apply existing MTL algorithms to new application scenarios, with the support of LibMTL.
Each module is introduced in Docs.
LibMTL currently supports the following algorithms:
Optimization Strategies | Venues | Comments |
---|---|---|
Equal Weighting (EW) | - | Implemented by us |
Gradient Normalization (GradNorm) | ICML 2018 | Implemented by us |
Uncertainty Weights (UW) | CVPR 2018 | Implemented by us |
MGDA | NeurIPS 2018 | Referenced from official PyTorch implementation |
Dynamic Weight Average (DWA) | CVPR 2019 | Referenced from official PyTorch implementation |
Geometric Loss Strategy (GLS) | CVPR 2019 workshop | Implemented by us |
Projecting Conflicting Gradient (PCGrad) | NeurIPS 2020 | Implemented by us |
Gradient Sign Dropout (GradDrop) | NeurIPS 2020 | Implemented by us |
Impartial Multi-Task Learning (IMTL) | ICLR 2021 | Implemented by us |
Gradient Vaccine (GradVac) | ICLR 2021 | Implemented by us |
Conflict-Averse Gradient Descent (CAGrad) | NeurIPS 2021 | Referenced from official PyTorch implementation |
Nash-MTL | ICML 2022 | Referenced from official PyTorch implementation |
Random Loss Weighting (RLW) | TMLR 2022 | Implemented by us |
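To give a feel for what a gradient-manipulation strategy does, here is a minimal pure-Python sketch of the PCGrad projection rule: for each task gradient, any component that conflicts (negative inner product) with another task's gradient is projected away, visiting the other tasks in random order, and the projected gradients are then summed. It works on plain lists of floats rather than model parameters, and `dot` and `pcgrad` are hypothetical names; this is an illustration of the technique, not LibMTL's PCGrad implementation.

```python
import random
from typing import List, Optional, Sequence

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def pcgrad(grads: Sequence[Sequence[float]], rng: Optional[random.Random] = None) -> Vector:
    """Combine per-task gradients with PCGrad-style conflict projection.

    grads: one flattened gradient vector per task (all the same length).
    Returns the sum of the projected task gradients.
    """
    rng = rng or random.Random()
    projected = []
    for i, g_i in enumerate(grads):
        g_pc = list(g_i)
        # The other tasks' *original* gradients, visited in random order.
        others = [g for j, g in enumerate(grads) if j != i]
        rng.shuffle(others)
        for g_j in others:
            d = dot(g_pc, g_j)
            norm_sq = dot(g_j, g_j)
            if d < 0 and norm_sq > 0:
                # Conflict: remove the component of g_pc along g_j.
                coef = d / norm_sq
                g_pc = [x - coef * y for x, y in zip(g_pc, g_j)]
        projected.append(g_pc)
    # Element-wise sum over tasks.
    return [sum(col) for col in zip(*projected)]


if __name__ == "__main__":
    # Two conflicting 2-D task gradients (their inner product is -1).
    print(pcgrad([[1.0, 0.0], [-1.0, 1.0]]))
```

With only two tasks the visiting order does not matter; with three or more, passing a seeded `random.Random` makes the combined gradient reproducible.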
Architectures | Venues | Comments |
---|---|---|
Hard Parameter Sharing (HPS) | ICML 1993 | Implemented by us |
Cross-stitch Networks (Cross_stitch) | CVPR 2016 | Implemented by us |
Multi-gate Mixture-of-Experts (MMoE) | KDD 2018 | Implemented by us |
Multi-Task Attention Network (MTAN) | CVPR 2019 | Referenced from official PyTorch implementation |
Customized Gate Control (CGC), Progressive Layered Extraction (PLE) | ACM RecSys 2020 Best Paper | Implemented by us |
Learning to Branch (LTB) | ICML 2020 | Implemented by us |
DSelect-k | NeurIPS 2021 | Referenced from official TensorFlow implementation |
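As an illustration of how the expert-based architectures in the table share computation across tasks, here is a minimal pure-Python sketch of an MMoE forward pass: all tasks share a set of experts, each task has its own softmax gate over the experts (computed from the input), and the gated mixture of expert outputs is fed to a task-specific tower. Experts and towers are plain callables and gates are plain weight lists; `softmax`, `matvec`, and `mmoe_forward` are hypothetical names, and this is not LibMTL's MMoE class.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def softmax(z: Sequence[float]) -> Vector:
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(W: Matrix, x: Sequence[float]) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def mmoe_forward(
    x: Sequence[float],
    experts: Sequence[Callable[[Sequence[float]], Vector]],
    gate_weights: Sequence[Matrix],  # one (num_experts x input_dim) matrix per task
    towers: Sequence[Callable[[Vector], Vector]],  # one task-specific head per task
) -> List[Vector]:
    # Experts are evaluated once and shared by every task.
    expert_outs = [expert(x) for expert in experts]
    dim = len(expert_outs[0])
    outputs = []
    for W_gate, tower in zip(gate_weights, towers):
        gate = softmax(matvec(W_gate, x))  # task-specific weights over experts
        mixed = [sum(g * out[d] for g, out in zip(gate, expert_outs)) for d in range(dim)]
        outputs.append(tower(mixed))
    return outputs


if __name__ == "__main__":
    experts = [lambda x: [2.0 * x[0]], lambda x: [4.0 * x[0]]]
    towers = [lambda h: h, lambda h: h]
    # Task 0: all-zero gate weights give a uniform gate.
    # Task 1: gate weights strongly favour expert 0.
    print(mmoe_forward([1.0], experts, [[[0.0], [0.0]], [[100.0], [0.0]]], towers))
```

In the demo, the first task averages the two experts while the second relies almost entirely on the first expert, which is the kind of task-specific expert selection that distinguishes MMoE from plain hard parameter sharing.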
Datasets | Problems | Number of Tasks | Tasks | Multi/Single-input (M/S) |
---|---|---|---|---|
NYUv2 | Scene Understanding | 3 | Semantic Segmentation + Depth Estimation + Surface Normal Prediction | S |
Office-31 | Image Recognition | 3 | Classification | M |
Office-Home | Image Recognition | 4 | Classification | M |
QM9 | Molecular Property Prediction | 11 (default) | Regression | S |
PAWS-X | Paraphrase Identification | 4 (default) | Classification | M |
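The last column distinguishes two problem settings. Roughly, in a single-input problem (S) all tasks share the same input samples (in NYUv2, each image carries labels for all three tasks), whereas in a multi-input problem (M) each task has its own input data (in Office-31, for example, each of the three domains is a task with its own images). The sketch below contrasts how per-task losses could be computed in each setting with a shared encoder; `single_input_losses` and `multi_input_losses` are hypothetical helpers, not LibMTL's training loop.

```python
from typing import Any, Callable, Dict, Tuple

Encoder = Callable[[Any], Any]
Head = Callable[[Any], Any]
LossFn = Callable[[Any, Any], float]


def single_input_losses(
    batch: Dict[str, Any], encoder: Encoder, heads: Dict[str, Head], loss_fns: Dict[str, LossFn]
) -> Dict[str, float]:
    # Single-input (S): one input shared by all tasks, so the encoder runs once
    # and every task head reads the same representation.
    rep = encoder(batch["x"])
    return {task: loss_fns[task](head(rep), batch["y"][task]) for task, head in heads.items()}


def multi_input_losses(
    batches: Dict[str, Tuple[Any, Any]], encoder: Encoder, heads: Dict[str, Head], loss_fns: Dict[str, LossFn]
) -> Dict[str, float]:
    # Multi-input (M): each task brings its own (input, label) pair, so the
    # shared encoder runs once per task.
    losses = {}
    for task, (x, y) in batches.items():
        losses[task] = loss_fns[task](heads[task](encoder(x)), y)
    return losses
```

In both settings the encoder parameters are shared; what differs is whether the tasks see the same samples or separate datasets.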
- Create a virtual environment

```shell
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.0 torchvision==0.9.0 numpy==1.20
```

- Clone the repository

```shell
git clone https://github.com/median-research-group/LibMTL.git
```

- Install LibMTL

```shell
cd LibMTL
pip install -e .
```
We use the NYUv2 dataset as an example to show how to use LibMTL.
The NYUv2 dataset we use is pre-processed by mtan. You can download this dataset here.
The complete training code for the NYUv2 dataset is provided in `examples/nyu`. The file `train_nyu.py` is the main script for training on the NYUv2 dataset.
You can list the command-line arguments by running the following command.

```shell
python train_nyu.py -h
```

For instance, the following command trains an MTL model with EW and HPS on the NYUv2 dataset.

```shell
python train_nyu.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step
```
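To compare several combinations from the tables above, a small driver script can generate one `train_nyu.py` command per (weighting, architecture) pair. This is a sketch: it reuses only the flags shown in the command above, and it assumes that other abbreviations from the tables (e.g., UW, MTAN) are accepted by `--weighting` and `--arch` in the same way as EW and HPS; check `python train_nyu.py -h` for the accepted values. `build_commands` and `run_all` are hypothetical helpers, not part of LibMTL.

```python
import shlex
import subprocess
from itertools import product
from typing import List, Sequence


def build_commands(
    dataset_path: str,
    weightings: Sequence[str],
    archs: Sequence[str],
    gpu_id: int = 0,
    scheduler: str = "step",
) -> List[List[str]]:
    """Return one train_nyu.py argument list per (weighting, architecture) pair."""
    commands = []
    for weighting, arch in product(weightings, archs):
        commands.append([
            "python", "train_nyu.py",
            "--weighting", weighting,
            "--arch", arch,
            "--dataset_path", dataset_path,
            "--gpu_id", str(gpu_id),
            "--scheduler", scheduler,
        ])
    return commands


def run_all(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; launch them one after another unless dry_run is set."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Run from examples/nyu so that train_nyu.py is found.
    run_all(build_commands("/path/to/nyuv2", ["EW", "UW"], ["HPS", "MTAN"]), dry_run=True)
```

Keeping `dry_run=True` by default lets you inspect the generated commands before committing GPU time; passing argument lists (rather than a single shell string) to `subprocess.run` avoids shell-quoting issues with dataset paths.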
More details are given in the Docs.
If you find LibMTL useful for your research or development, please cite the following:

```bibtex
@article{LibMTL,
  title={{LibMTL}: A Python Library for Multi-Task Learning},
  author={Baijiong Lin and Yu Zhang},
  journal={arXiv preprint arXiv:2203.14338},
  year={2022}
}
```
LibMTL is developed and maintained by Baijiong Lin.
If you have any questions or suggestions, please feel free to contact us by raising an issue or sending an email to bj.lin.email@gmail.com.
We would like to thank the authors who released the following public repositories (listed alphabetically): CAGrad, dselect_k_moe, MultiObjectiveOptimization, mtan, nash-mtl, pytorch_geometric, and xtreme.
LibMTL is released under the MIT license.