Building Multi-task Learning Models

A practical guide in PyTorch.

A couple of days back, we discussed four critical model training paradigms behind many real-world ML models.

Here’s the visual from that post for a quick recap:

Consider multi-task learning (MTL).

Do you know how such models are trained?

I think this is a great topic to cover because most ML models are typically trained on just one task.

As a result, many struggle to intuitively understand how a model can be trained on multiple tasks simultaneously.

So let’s discuss this today!

To reiterate, in MTL, the network has a few shared layers and task-specific segments.

During backpropagation, gradients are accumulated from all branches, as depicted in the animation below:

Let’s take a simple example to understand its implementation.

Say we want our model to take a real value (x) as input and generate two outputs:

  • sin(x)

  • cos(x)

This can be formulated as an MTL problem.

First, we define our model class using PyTorch.
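
Here's a minimal sketch of what that class could look like (the class name SinCosNet and the layer sizes are my own illustrative choices, not prescriptive):

```python
import torch
import torch.nn as nn

class SinCosNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared layers: learn a common representation for both tasks
        self.model = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        # Task-specific output layers
        self.sin_branch = nn.Linear(64, 1)  # predicts sin(x)
        self.cos_branch = nn.Linear(64, 1)  # predicts cos(x)
```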

As demonstrated above:

  • We have some fully connected layers in self.model → These are the shared layers.

  • Furthermore, we have the task-specific output layers to predict sin(x) and cos(x).

This network architecture can be visually depicted as follows:

Next, let’s define the forward pass in the class above (see the sketch after the steps below):

  • First, we pass the input through the shared layers (self.model).

  • The output of the shared layers is passed through the sin and cos branches.

  • We return the output from both branches.
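
Putting these steps together, the forward pass could look like the sketch below; in a real implementation, this method sits inside the SinCosNet class sketched earlier:

```python
# Sketch of the forward pass; this method belongs inside the
# SinCosNet class defined above.
def forward(self, x):
    # Pass the input through the shared layers
    shared_out = self.model(x)
    # Feed the shared representation to each task-specific branch
    sin_out = self.sin_branch(shared_out)
    cos_out = self.cos_branch(shared_out)
    # Return the output of both branches
    return sin_out, cos_out
```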

We are almost done.

The final part of this implementation is to train the model.

Let’s use mean squared error as the loss function.

The training loop is implemented below:

  • We pass the input data through the model.

  • It returns two outputs, one from each segment of the network.

  • We compute the branch-specific loss values (loss1 and loss2) against the true values.

  • We add the two loss values to get the total loss for the network.

  • Finally, we run the backward pass.
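
Here's a minimal sketch of such a loop, assuming the SinCosNet class from earlier (with its forward pass); the synthetic data, optimizer, learning rate, and epoch count are illustrative choices:

```python
# Synthetic training data: x roughly in [-pi, pi], targets sin(x) and cos(x)
x = torch.linspace(-3.14, 3.14, 1024).unsqueeze(1)
y_sin, y_cos = torch.sin(x), torch.cos(x)

model = SinCosNet()
criterion = nn.MSELoss()                                    # mean squared error
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(200):
    # Forward pass: one output per branch
    sin_pred, cos_pred = model(x)

    # Branch-specific losses against the true values
    loss1 = criterion(sin_pred, y_sin)
    loss2 = criterion(cos_pred, y_cos)

    # Total loss for the network
    loss = loss1 + loss2

    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```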

Done!

With this, we have trained our MTL model.

Also, we get a decreasing loss, which indicates that the model is learning.

And that’s how we train an MTL model.

That was simple, wasn’t it?

You can extend the same idea to build any MTL model of your choice.

Do remember that building an MTL model on unrelated tasks will not produce good results.

Thus, “task-relatedness” is a critical component of all MTL models because of the shared layers.

Also, it is NOT necessary for every task to contribute equally to the network’s overall loss.

We may assign weights to each task as well, as depicted below:
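
For instance, with fixed weights (the values here are just placeholders), the total loss from the training loop above becomes:

```python
# Fixed task weights (illustrative values)
w1, w2 = 0.7, 0.3

# Weighted sum instead of a plain sum of the two task losses
loss = w1 * loss1 + w2 * loss2
```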

The weights could be based on task importance.

Or…

At times, I also use dynamic task weights, which could be inversely proportional to the validation accuracy achieved on each task.

My rationale behind this technique is that in an MTL setting, some tasks can be easy while others can be difficult.

If the model achieves high accuracy on one task during training, we can safely reduce its loss contribution so that the model focuses more on the harder tasks.

This makes intuitive sense as well.
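
Here's a minimal sketch of that idea, assuming we track a per-task validation score after each epoch; the helper below and the epsilon/normalization choices are my own, not a standard recipe. It plugs into the training loop above in place of the plain sum:

```python
def dynamic_task_weights(val_scores, eps=1e-8):
    # Weights inversely proportional to each task's validation score
    inv = [1.0 / (s + eps) for s in val_scores]
    total = sum(inv)
    return [w / total for w in inv]  # normalized so the weights sum to 1

# e.g., task 1 already scores well on validation while task 2 lags behind
w1, w2 = dynamic_task_weights([0.95, 0.60])
loss = w1 * loss1 + w2 * loss2  # the harder task gets the larger weight
```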

You can download the notebook for today’s post here: Multi-task learning notebook.

👉 Over to you: What could be some other techniques to aggregate loss values of different tasks?

Extended piece #1

Even after rigorously testing an ML model locally (on validation and test sets), instantly replacing the previous model with the new one can be a terrible idea.

A more reliable strategy is to test the model in production (yes, on real-world incoming data).

While this might sound risky, ML teams do it all the time, and it isn’t that complicated.

Extended piece #2

Businesses have more data than ever before.

Traditional single-node model training just doesn’t work because one cannot wait months to train a model.

Distributed (or multi-GPU) training is one of the most essential ways to address this.

In fact, if you look at job descriptions for Applied ML or ML engineer roles on LinkedIn, most of them demand skills like the ability to train models on large datasets:

We covered some core technicalities behind multi-GPU training, how it works under the hood, and implementation details here: A Beginner-friendly Guide to Multi-GPU Model Training.

Are you preparing for ML/DS interviews or want to upskill at your current job?

Every week, I publish in-depth ML dives. The topics align with the practical skills that typical ML/DS roles demand.

Join below to unlock all full articles:

👉 If you love reading this newsletter, share it with friends!

👉 Tell the world what makes this newsletter special for you by leaving a review here :)
