Multi-task splitter

A simple deep learning solution for multitask in PyTorch.

You can copy-paste the standalone file splitter.py, or clone the repo and run pip install . from its root.
The main class, NormalizedMultiTaskSplitter, is used as follows:

from torchmultitask.splitter import NormalizedMultiTaskSplitter

# Set the weight of each task; there is no need to account for the different scales of the loss functions.
# Typically 1.0 is a one-size-fits-all value that avoids fine-tuning the loss weights (weights between 0.1 and 10 can be tried to prioritize one loss).
task_weight_dict = {"task_1": 1.0, "task_2": 1.0}
splitter = NormalizedMultiTaskSplitter(task_weight_dict)

for x in dataloader:
    x = base_model(x)
    x_dict = splitter(x)  # identity in the forward pass; normalizes (and optionally projects) the gradients in the backward pass
    loss1 = loss_model1(x_dict["task_1"])
    loss2 = loss_model2(x_dict["task_2"])
    loss = loss1 + loss2
    loss.backward()
    ...  # can be used with any pytorch optimizer

How does it work

1) Gradient normalization: The first trick is to normalize the gradients, similarly to Meta's multitask normalizer [3]. An important difference is that it is implemented here with a custom autograd.Function that overrides the backward pass. This makes it possible to split gradients anywhere in the computational graph, unlike Meta's implementation. In short, each gradient is normalized using:

def backward(grads, m, t, beta, epsilon, loss_coeff):
    # pseudo-code summary of GradientNormalizer(...) in gradient_normalizer.py
    # the NormalizedMultiTaskSplitter(...) in splitter.py does it on all tasks simultaneously to go faster
    # m is the running statistic, t the step count (starting at 1), beta the decay rate
    m = beta * m + (1 - beta) * grads.square().mean(0).sum()  # moving average of the squared gradient norm, mean over the batch dimension
    v = m / (1 - beta ** t)  # bias-corrected moving average
    return grads / v.sqrt().clip(min=epsilon) * loss_coeff  # return the scaled gradient
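
To make the "identity forward, rescaled backward" idea concrete, here is a minimal sketch of how such an autograd.Function could look. It is not the code from gradient_normalizer.py: the class name NormalizeGrad, the state dictionary and the hyperparameter values (0.9, 1e-8, 1.0) are assumptions chosen for illustration.

```python
import torch


class NormalizeGrad(torch.autograd.Function):
    """Hypothetical sketch: identity in the forward pass, rescaled gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, state, beta, epsilon, loss_coeff):
        # `state` is a plain dict holding the running statistics (assumed layout).
        ctx.state = state
        ctx.beta, ctx.epsilon, ctx.loss_coeff = beta, epsilon, loss_coeff
        return x.view_as(x)  # same values as x, returned as a new view

    @staticmethod
    def backward(ctx, grads):
        s = ctx.state
        s["t"] += 1
        # moving average of the squared gradient norm, mean over the batch dimension
        s["m"] = ctx.beta * s["m"] + (1 - ctx.beta) * grads.square().mean(0).sum()
        v = s["m"] / (1 - ctx.beta ** s["t"])  # bias correction
        scaled = grads / v.sqrt().clip(min=ctx.epsilon) * ctx.loss_coeff
        # one return value per forward input; non-tensor inputs get None
        return scaled, None, None, None, None


x = torch.randn(8, 4, requires_grad=True)
state = {"m": 0.0, "t": 0}
y = NormalizeGrad.apply(x, state, 0.9, 1e-8, 1.0)  # illustrative hyperparameters
(1000.0 * y).sum().backward()  # deliberately large upstream gradient
print(x.grad.norm(dim=1))  # per-sample gradient norms after rescaling
```

Because the rescaling lives in the backward pass of a node in the graph, it can be inserted wherever the shared features branch into task-specific heads, whatever the scale of the loss that sits downstream.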

2) Gradient projection: Inspired by pcgrad [1,4] and this nice paper [2], we also tried to project the gradients onto each other. This is disabled by default because it slows down training a bit (by ~20%, which is still faster than the pytorch implementation [3]), but in some cases it might improve performance.
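
For reference, the pairwise projection rule described in the pcgrad paper [1] can be sketched as follows. This only illustrates the rule for two gradients; the function name project_conflicting and the eps guard are assumptions, and the way splitter.py applies the projection across tasks is not shown here.

```python
import torch


def project_conflicting(g_i, g_j, eps=1e-12):
    """If g_i conflicts with g_j (negative inner product), remove the component of g_i along g_j."""
    dot = torch.dot(g_i.flatten(), g_j.flatten())
    if dot < 0:  # the two tasks pull in opposing directions
        g_i = g_i - dot / g_j.square().sum().clamp(min=eps) * g_j
    return g_i


g_1 = torch.tensor([1.0, 0.0])
g_2 = torch.tensor([-1.0, 1.0])
g_1_projected = project_conflicting(g_1, g_2)
print(torch.dot(g_1_projected, g_2))  # inner product after projection
```

When the inner product is non-negative, the gradient is returned unchanged, so the projection only acts on conflicting tasks.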

Results on Multi-MNIST:

Two digits are drawn on each image, and the two tasks are to classify the first and the second digit. This is similar to the toy multi-task problem considered in [2]. The dataloader is a debugged version of the one from the Pytorch-PCGrad repo [4], and we compare our results to their implementation of pcgrad.


Demo on Multi-MNIST: Two digits (left and right) overlap in each image. The image is processed by a LeNet ConvNet to produce features. The feature vector is fed into two MLPs that compute the classification cross-entropy losses for the left and right digits, respectively.
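
The architecture described above can be sketched roughly as follows. The layer sizes, the 1x28x28 input shape and the class name SharedEncoderTwoHeads are assumptions made for illustration, not the exact model used in the demo.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SharedEncoderTwoHeads(nn.Module):
    """Hypothetical LeNet-style shared encoder followed by one MLP head per digit."""

    def __init__(self, feature_dim=50, n_classes=10):
        super().__init__()
        # layer sizes are illustrative and assume 1x28x28 inputs
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),   # -> 10x12x12
            nn.Conv2d(10, 20, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),  # -> 20x4x4
            nn.Flatten(),
            nn.Linear(20 * 4 * 4, feature_dim), nn.ReLU(),
        )
        self.head_left = nn.Sequential(
            nn.Linear(feature_dim, 50), nn.ReLU(), nn.Linear(50, n_classes))
        self.head_right = nn.Sequential(
            nn.Linear(feature_dim, 50), nn.ReLU(), nn.Linear(50, n_classes))

    def forward(self, x):
        features = self.encoder(x)  # shared features; the splitter would be applied here
        return self.head_left(features), self.head_right(features)


model = SharedEncoderTwoHeads()
images = torch.randn(4, 1, 28, 28)          # random stand-in for a Multi-MNIST batch
labels_left = torch.randint(0, 10, (4,))
labels_right = torch.randint(0, 10, (4,))
logits_left, logits_right = model(images)
loss = F.cross_entropy(logits_left, labels_left) + F.cross_entropy(logits_right, labels_right)
loss.backward()
```

In this sketch the two losses are simply summed; with the splitter, the shared feature vector would first go through splitter(...) as in the usage snippet at the top, and each head would read its own entry of the returned dictionary.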

Figure 1. Left: Classification accuracy for the left digit (averaged over 3 independent runs); the x-axis is the epoch number. Right: Classification accuracy for the right digit.

Figure 2. Same as Figure 1, but without re-weighting the loss functions: the losses are already balanced, so the normalization has a weaker effect.

References

Cite this repo:

@misc{bellec2023,
  author = {Bellec, G},
  title = {Multi-task splitter: A simple deep learning solution for multitask in PyTorch},
  year = {2023},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/guillaumeBellec/multitask}}
}

Related references:

[1] Gradient Surgery for Multi-Task Learning (pcgrad)
Tianhe Yu, Saurabh Kumar, Abhishek Gupta, Sergey Levine, Karol Hausman, Chelsea Finn

[2] Multi-Task Learning as Multi-Objective Optimization
Ozan Sener, Vladlen Koltun

[3] High Fidelity Neural Audio Compression
Alexandre Défossez, Jade Copet, Gabriel Synnaeve, Yossi Adi

[4] Unofficial pytorch pcgrad repository
github.com/WeiChengTseng/Pytorch-PCGrad.git