Why Don't We Invoke model.forward() in PyTorch?

Instead, we always invoke model(). Here's why.

In PyTorch, we implement a model's forward pass in its forward() method.

But have you ever noticed that, when we actually run the forward pass, we rarely call this forward() method directly?

Instead, we invoke the model object itself as if it were a function, as in model(x).

We can also verify that model is not a function but an instance of our model class.
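Here is a minimal sketch of the pattern described above, assuming PyTorch is installed. The class name TinyModel, its layer sizes, and the input shape are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn


# Hypothetical model: the forward pass lives in forward()
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)  # 4 input features -> 2 outputs

    def forward(self, x):
        return torch.relu(self.fc(x))


model = TinyModel()
x = torch.randn(3, 4)  # batch of 3 samples, 4 features each

out = model(x)  # the usual way: call the model object itself
print(out.shape)  # torch.Size([3, 2])

# model is an instance of our class (and therefore of nn.Module), not a function
print(type(model).__name__)          # TinyModel
print(isinstance(model, nn.Module))  # True
```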

So how can an object be invoked like a function, and what are we missing here?

Let’s understand!

A simple example

Suppose we want to evaluate a quadratic function of some input x.

One way is to define a class with a method that accepts the input and returns the value of the quadratic.
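The specific quadratic is not reproduced here, so this sketch assumes the general form f(x) = ax² + bx + c. The class name Quadratic and the coefficients are hypothetical; the evaluate() name follows the text below.

```python
class Quadratic:
    """Evaluates f(x) = a*x**2 + b*x + c (general form assumed for illustration)."""

    def __init__(self, a, b, c):
        self.a, self.b, self.c = a, b, c

    def evaluate(self, x):
        # Explicit method: callers must write obj.evaluate(x)
        return self.a * x**2 + self.b * x + self.c


f = Quadratic(1, 2, 3)  # hypothetical coefficients
print(f.evaluate(2))    # 1*4 + 2*2 + 3 = 11
```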

Of course, there is nothing wrong with this approach.

But Python offers a more elegant way of doing this.

Instead of explicitly invoking a method, we can define the __call__() magic method.

This magic method defines how an instance of the class behaves when it is called like a function (like this: obj()).

Let’s rename the evaluate() method to __call__().

As a result, we can now call the object directly instead of explicitly invoking a method.
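A minimal sketch of the renamed version, again assuming the general form ax² + bx + c with hypothetical coefficients. The only change from an evaluate()-style class is the method name.

```python
class Quadratic:
    """Same hypothetical quadratic, but evaluated via __call__."""

    def __init__(self, a, b, c):
        self.a, self.b, self.c = a, b, c

    def __call__(self, x):
        # Runs whenever the instance is called like a function: f(x)
        return self.a * x**2 + self.b * x + self.c


f = Quadratic(1, 2, 3)
print(f(2))           # 11, no method name needed
print(f.__call__(2))  # 11, the explicit equivalent of f(2)
```

Python looks up __call__ on the object's class, so every instance of Quadratic becomes callable without any extra registration.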

This can have many advantages. For instance:

  • It lets us build objects that carry state (like the coefficients of a quadratic) yet can be used as naturally as plain functions.

  • It lets us use an object anywhere a callable is expected, for instance, using a class as a decorator.
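As an illustration of the decorator case, here is a sketch of a hypothetical CountCalls class. Decorating a function replaces it with a CountCalls instance, and __call__ is what keeps that instance usable as a function.

```python
import functools


class CountCalls:
    """Hypothetical decorator class: counts how often the wrapped function runs."""

    def __init__(self, func):
        functools.update_wrapper(self, func)  # copy __name__, __doc__, etc.
        self.func = func
        self.count = 0

    def __call__(self, *args, **kwargs):
        # The decorated name now refers to this instance, so calls land here
        self.count += 1
        return self.func(*args, **kwargs)


@CountCalls
def greet(name):
    return f"Hello, {name}!"


print(greet("Ada"))   # Hello, Ada!
print(greet("Alan"))  # Hello, Alan!
print(greet.count)    # 2
```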

What is callable?

In Python, a callable is any object that can be called using parentheses and may return a value.

For instance, every function is callable, and so is any object whose class defines __call__().
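Python's built-in callable() reports whether an object can be called. The names square, Doubler, and Plain below are hypothetical examples.

```python
def square(x):
    return x * x


class Doubler:
    def __call__(self, x):
        return 2 * x


class Plain:
    pass


print(callable(square))     # True: functions are callable
print(callable(Doubler()))  # True: its class defines __call__
print(callable(Plain()))    # False: no __call__ on the class
print(callable(Plain))      # True: classes are callable (calling one creates an instance)
print(callable(42))         # False
```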

Coming back to PyTorch

The same mechanism is at work when we build deep learning models with PyTorch.

Consider the PyTorch model class from earlier once more.

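As you may have already guessed, the model object can be invoked because every PyTorch model inherits from torch.nn.Module, which defines the __call__() method.

Within that __call__() method, PyTorch invokes our user-defined forward() method, along with any hooks registered on the module. This is the real reason to prefer model(x): calling model.forward(x) directly silently skips those hooks.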

In simplified form:

  • torch.nn.Module defines the __call__() method, and our model inherits it.

  • The __call__() method invokes the user-defined forward() method.
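Here is a toy, pure-Python sketch of this idea. MiniModule is a hypothetical stand-in, not PyTorch's actual source: the real nn.Module.__call__ handles several kinds of hooks and other bookkeeping. Its register_forward_hook only mimics the shape of PyTorch's method of the same name.

```python
class MiniModule:
    """Toy stand-in for torch.nn.Module (hypothetical, heavily simplified)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must implement forward()")

    def __call__(self, *args, **kwargs):
        # Defined once by the "framework"; delegates to the user's forward()
        output = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            if result is not None:  # a hook may replace the output
                output = result
        return output


class Scale(MiniModule):
    """User code only implements forward(), just as with nn.Module."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, x):
        return self.factor * x


m = Scale(3)
print(m(5), m.forward(5))  # 15 15: identical while no hooks exist

m.register_forward_hook(lambda module, args, output: output + 1)
print(m(5))          # 16: __call__ runs forward() and then the hook
print(m.forward(5))  # 15: calling forward() directly skips the hook
```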

This is why Python lets us invoke the model object like a function: model(x).

In fact, we can verify that, as long as no hooks are registered, both ways of running the forward pass produce the same output.
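A minimal check, assuming PyTorch is installed. A built-in nn.Linear layer stands in for a full model, since it is also an nn.Module; the sizes and seed are arbitrary. The second half registers a forward hook to show where the two calls diverge: nn.Module.__call__ runs hooks, while calling forward() directly does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)  # any nn.Module subclass would do
x = torch.randn(3, 4)

via_call = layer(x)             # goes through nn.Module.__call__
via_forward = layer.forward(x)  # bypasses __call__
print(torch.equal(via_call, via_forward))  # True

# Register a forward hook that only records that it ran
calls = []
handle = layer.register_forward_hook(
    lambda module, inputs, output: calls.append("hook ran")
)
layer(x)          # __call__ runs forward() and then the hook
layer.forward(x)  # forward() alone: the hook is skipped
print(calls)      # ['hook ran']
handle.remove()   # detach the hook again
```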

Cool Pythonistic stuff, isn’t it?

These things revolve around good and elegant object-oriented programming practices.

We covered such advanced OOP stuff in detail in a recent deep dive here if you wish to level up your OOP skills: Object-Oriented Programming with Python for Data Scientists.

Also, if you want to get really good at Python OOP, learn about Python Descriptors.

I find them to be massively helpful in reducing work and code redundancy while also making the entire implementation much more elegant.

We covered it in this newsletter here: Define Elegant and Concise Python Classes with Descriptors.

👉 Over to you: What are some other cool Python OOP tricks?

Thanks for reading!

Are you preparing for ML/DS interviews or want to upskill at your current job?

Every week, I publish in-depth ML dives. The topics align with the practical skills that typical ML/DS roles demand.

Join below to unlock all full articles:

