We add PyTorch APIs in APIs/ for more convenient multi-task model construction.
- Users can now directly adopt the VCNs in layer_node.py, including Conv2dNode and BN2dNode, when building a customized multi-task model in PyTorch (the model should inherit from the class in mtl_model.py).
- An example can be found in mobilenetv2.py and Example_mobilenetv2.ipynb.
- The detailed documentation is available online here.
This is the website for our paper "AutoMTL: A Programming Framework for Automating Efficient Multi-Task Learning". The arXiv version can be found here.
Multi-task learning (MTL) jointly learns a set of tasks. It is a promising approach to reducing training and inference time and storage costs while improving prediction accuracy and generalization performance for many computer vision tasks. However, a major barrier preventing the widespread adoption of MTL is the lack of systematic support for developing compact multi-task models given a set of tasks. In this paper, we aim to remove the barrier by developing AutoMTL, the first programming framework that automates MTL model development. AutoMTL takes as inputs an arbitrary backbone convolutional neural network and a set of tasks to learn, and automatically produces a multi-task model that simultaneously achieves high accuracy and a small memory footprint. As a programming framework, AutoMTL can facilitate the development of MTL-enabled computer vision applications and further improve task performance.
Please cite our work if you find it helpful to your research.
@misc{zhang2021automtl,
title={AutoMTL: A Programming Framework for Automating Efficient Multi-Task Learning},
author={Lijun Zhang and Xiao Liu and Hui Guan},
year={2021},
eprint={2110.13076},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
You can build your conda environment from the provided environment.yml. Feel free to change the env name in the file.
conda env create -f environment.yml
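Then activate the environment (use whatever env name is set in environment.yml):
conda activate <env_name>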
We conducted experiments on three popular datasets in multi-task learning (MTL): CityScapes [1], NYUv2 [2], and Tiny-Taskonomy [3]. You can download them here. For Tiny-Taskonomy, you will need to contact the authors directly; see their official website.
├── data
│   ├── dataloader
│   │   ├── *_dataloader.py
│   ├── heads
│   │   ├── pixel2pixel.py
│   ├── metrics
│   │   ├── pixel2pixel_loss/metrics.py
├── framework
│   ├── layer_containers.py
│   ├── base_node.py
│   ├── layer_node.py
│   ├── mtl_model.py
│   ├── trainer.py
├── models
│   ├── *.prototxt
└── utils
    └── pytorch_to_caffe.py
Our code consists of three parts: code for data, code for AutoMTL, and others.
- For Data
  - Dataloaders (*_dataloader.py): For each dataset, we offer a corresponding PyTorch dataloader with a specific task variable.
  - Heads (pixel2pixel.py): The ASPP head [4] is implemented for the pixel-to-pixel vision tasks.
  - Metrics (pixel2pixel_loss.py and pixel2pixel_metrics.py): Each task has its own criterion and metric.
- AutoMTL
  - Multi-Task Model Generator (mtl_model.py): Transfers the given backbone model (in prototxt format) and the task-specific head dictionary into a multi-task supermodel.
  - Trainer Tools (trainer.py): Materializes a three-stage training pipeline to search out a good multi-task model for the given tasks.
- Others
  - Input Backbone (*.prototxt): Typical vision backbone models, including Deeplab-ResNet34 [4], MobileNetV2, and MNasNet.
  - Transfer to Prototxt (pytorch_to_caffe.py): If you define your own customized backbone model with the PyTorch API, we also provide a tool to convert it to a prototxt file (a conversion sketch follows the note below).
Note: Please refer to Example.ipynb for more details.
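For the conversion mentioned above, here is a rough sketch. It assumes utils/pytorch_to_caffe.py exposes the trans_net/save_prototxt interface of the common PytorchToCaffe tool; please verify the exact functions in the file itself, and note that resnet18 is only a stand-in for your customized model.
import torch
from torchvision.models import resnet18
from utils import pytorch_to_caffe

net = resnet18().eval()                   # stand-in for your customized backbone
dummy_input = torch.ones(1, 3, 224, 224)  # example input shape; adjust to your data
pytorch_to_caffe.trans_net(net, dummy_input, 'resnet18')    # trace the network
pytorch_to_caffe.save_prototxt('models/resnet18.prototxt')  # prototxt consumed by AutoMTL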
Each task will have its own dataloader for both training and validation, task-specific criterion (loss), evaluation metric, and model head. Here we take CityScapes as an example.
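The snippets below assume imports along these lines; the module paths follow the code layout above but are assumptions, so please verify the exact names in the repository.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# repo-local modules; exact paths/names should be checked against the source tree
from data.dataloader.cityscapes_dataloader import CityScapes
from data.heads.pixel2pixel import ASPPHeadNode
from data.metrics.pixel2pixel_loss import CityScapesCriterions
from data.metrics.pixel2pixel_metrics import CityScapesMetrics
from framework.mtl_model import MTLModel
from framework.trainer import Trainer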
tasks = ['segment_semantic', 'depth_zbuffer']
task_cls_num = {'segment_semantic': 19, 'depth_zbuffer': 1} # the number of classes in each task
You can also define your own dataloader, criterion, and evaluation metrics. Please refer to the files in data/ to make sure your customized classes have the same output formats as ours so they fit our framework (a hypothetical criterion sketch appears after the metric setup below).
trainDataloaderDict = {task: [] for task in tasks}
valDataloaderDict = {}
for task in tasks:
    dataset = CityScapes(dataroot, 'train', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task].append(DataLoader(dataset, <batch_size>, shuffle=True)) # for the pre-train stage
    dataset1 = CityScapes(dataroot, 'train1', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task].append(DataLoader(dataset1, 16, shuffle=True)) # for network param training
    dataset2 = CityScapes(dataroot, 'train2', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task].append(DataLoader(dataset2, 16, shuffle=True)) # for policy param training

    dataset = CityScapes(dataroot, 'test', task)
    valDataloaderDict[task] = DataLoader(dataset, <batch_size>, shuffle=True)
criterionDict = {}
for task in tasks:
    criterionDict[task] = CityScapesCriterions(task)

metricDict = {}
for task in tasks:
    metricDict[task] = CityScapesMetrics(task)
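As noted above, you can also plug in your own criterion here. A hypothetical sketch (the class name and call signature are illustrative; mirror the actual interface of the classes in data/metrics/):
class MyDepthCriterion(nn.Module):
    # hypothetical custom loss for the depth task
    def __init__(self, task):
        super().__init__()
        self.task = task
        self.loss = nn.L1Loss()

    def forward(self, pred, target):
        return self.loss(pred, target)  # must return a scalar loss tensor

criterionDict['depth_zbuffer'] = MyDepthCriterion('depth_zbuffer')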
headsDict = nn.ModuleDict() # must be nn.ModuleDict() instead of a plain python dictionary
for task in tasks:
    headsDict[task] = ASPPHeadNode(<feature_dim>, task_cls_num[task])
prototxt = 'models/deeplab_resnet34_adashare.prototxt' # can be any CNN model
mtlmodel = MTLModel(prototxt, headsDict)
Note: We currently support Conv2d, BatchNorm2d, Linear, ReLU, Dropout, MaxPool2d and AvgPool2d (including global pooling), and elementwise operators (including product, add, and max).
trainer = Trainer(mtlmodel, trainDataloaderDict, valDataloaderDict, criterionDict, metricDict)
trainer.pre_train(iters=<total_iter>, lr=<init_lr>, savePath=<save_path>)
loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1, 'policy':0.0005} # the weights for each task and the policy regularization term from the paper
trainer.alter_train_with_reg(iters=<total_iter>, policy_network_iters=<alter_iters>, policy_lr=<policy_lr>, network_lr=<network_lr>,
loss_lambda=loss_lambda, savePath=<save_path>)
Note: When training the policy and the model weights together, we train them alternately for the numbers of iterations specified by policy_network_iters.
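The sampling snippet below assumes name_list and policy_list hold, per task, the names and logit vectors of the learned policy parameters. One way to collect them, assuming the supermodel registers its policy logits as named parameters whose names contain 'policy' (verify against framework/mtl_model.py):
policy_list = {task: [] for task in tasks}
name_list = {task: [] for task in tasks}
for name, param in mtlmodel.named_parameters():
    if 'policy' in name:          # assumption: policy logits are named parameters
        for task in tasks:
            if task in name:      # assumption: parameter names embed the task name
                policy_list[task].append(param.data.cpu().numpy())
                name_list[task].append(name)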
from collections import OrderedDict
import numpy as np
from scipy.special import softmax

sample_policy_dict = OrderedDict()
for task in tasks:
    for name, policy in zip(name_list[task], policy_list[task]):
        distribution = softmax(policy, axis=-1)
        distribution /= sum(distribution)  # re-normalize against numerical drift
        choice = np.random.choice((0, 1, 2), p=distribution)
        if choice == 0:
            sample_policy_dict[name] = torch.tensor([1.0, 0.0, 0.0]).cuda()
        elif choice == 1:
            sample_policy_dict[name] = torch.tensor([0.0, 1.0, 0.0]).cuda()
        elif choice == 2:
            sample_policy_dict[name] = torch.tensor([0.0, 0.0, 1.0]).cuda()
Note: The policy-training stage only learns a good policy distribution. Before post-training, we need to sample a concrete policy from that distribution.
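One way to persist the sampled policy so it can be passed to reload below; the checkpoint format expected by post_train is an assumption here, so check trainer.py:
# hypothetical save path; wrap the policy in a state dict so it loads like a checkpoint
torch.save({'state_dict': sample_policy_dict}, '<save_path>/sample_policy.model')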
trainer.post_train(iters=<total_iter>, lr=<init_lr>,
                   loss_lambda=loss_lambda, savePath=<save_path>, reload=<sampled_policy>)
You can download fully-trained models for each dataset here.
mtlmodel.load_state_dict(torch.load(<model_name>))
trainer.validate('mtl', hard=True)
Note: The "hard" must be set to True when conducting inference since we don't want to have soft policy this time.
mtlmodel.load_state_dict(torch.load(<model_name>))
output = mtlmodel(x, task=<task_name>, hard=True)
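For completeness, a minimal inference sketch, where x stands for a preprocessed input batch on the same device as the model:
mtlmodel = mtlmodel.cuda().eval()  # switch to inference mode
with torch.no_grad():              # no gradients needed at inference
    for task in tasks:
        output = mtlmodel(x, task=task, hard=True)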
[1] Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt. The Cityscapes dataset for semantic urban scene understanding. CVPR, 3213-3223, 2016.
[2] Silberman, Nathan and Hoiem, Derek and Kohli, Pushmeet and Fergus, Rob. Indoor segmentation and support inference from RGBD images. ECCV, 746-760, 2012.
[3] Zamir, Amir R and Sax, Alexander and Shen, William and Guibas, Leonidas J and Malik, Jitendra and Savarese, Silvio. Taskonomy: Disentangling task transfer learning. CVPR, 3712-3722, 2018.
[4] Chen, Liang-Chieh and Papandreou, George and Kokkinos, Iasonas and Murphy, Kevin and Yuille, Alan L. DeepLab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs. PAMI, 834-848, 2017.