Skorch: The Power of PyTorch Combined with The Elegance of Sklearn

The best of both worlds.

PyTorch has always been my go-to library for building any deep learning model.

However, one thing I particularly dislike about PyTorch is manually writing its long training loops, which go as follows:

  • For every epoch:

    • For every batch:

      • Run the forward pass

      • Calculate the loss

      • Run backpropagation to compute the gradients

      • Update the model’s weights with the optimizer

    • Compute epoch accuracy

    • Print the accuracy, loss, etc.
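For comparison, here is a minimal sketch of that loop in plain PyTorch. The toy dataset (256 random samples whose label depends on the first feature), the small nn.Sequential model, and the hyperparameters are all assumptions chosen for illustration:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy data: 256 samples, 20 features; label is 1 when the first feature is positive.
X = torch.randn(256, 20)
y = (X[:, 0] > 0).long()

model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch_size = 32

for epoch in range(10):                          # for every epoch
    perm = torch.randperm(len(X))                # shuffle the samples each epoch
    total_loss, correct = 0.0, 0
    for i in range(0, len(X), batch_size):       # for every batch
        idx = perm[i:i + batch_size]
        xb, yb = X[idx], y[idx]
        logits = model(xb)                       # run the forward pass
        loss = criterion(logits, yb)             # calculate the loss
        optimizer.zero_grad()                    # clear gradients left over from the previous step
        loss.backward()                          # backpropagation computes the gradients
        optimizer.step()                         # update the weights
        total_loss += loss.item() * len(xb)
        correct += (logits.argmax(dim=1) == yb).sum().item()
    accuracy = correct / len(X)                  # epoch (training) accuracy
    print(f"epoch {epoch + 1}: loss={total_loss / len(X):.4f}, accuracy={accuracy:.3f}")
```

Every line of bookkeeping here (shuffling, batching, zeroing gradients, accumulating metrics, printing) is something you write and maintain yourself.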

That’s too much work and code, isn’t it?

Skorch immensely simplifies training neural networks with PyTorch.

Skorch (Sklearn + PyTorch) is an open-source library that wraps PyTorch models in a fully Scikit-learn-compatible interface.

This means we can train PyTorch models the same way we train Scikit-learn models, using methods such as fit(), predict(), score(), etc.

Isn’t that cool?

Let’s see how to use it!

First, we define our PyTorch neural network as we usually would (no change here):
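A minimal sketch of such a network, using the name MyClassifier that appears below. The architecture (20 input features, one hidden layer of 32 units, 2 output classes) is a hypothetical choice for illustration, not the author’s; the forward pass returns raw logits, which pairs naturally with nn.CrossEntropyLoss:

```python
import torch
from torch import nn


class MyClassifier(nn.Module):
    """Hypothetical two-layer MLP; the layer sizes are illustrative assumptions."""

    def __init__(self, num_features=20, num_hidden=32, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_features, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_classes),
        )

    def forward(self, X):
        # Raw logits; nn.CrossEntropyLoss applies log-softmax internally.
        return self.net(X)


logits = MyClassifier()(torch.randn(4, 20))  # a batch of 4 samples -> shape (4, 2)
```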

Make sure you have installed Skorch: pip install skorch.

As we are creating a classifier, we import and create an object of Skorch’s NeuralNetClassifier class.

There’s a class for regression models as well: NeuralNetRegressor.

  • The first argument is the PyTorch model class (MyClassifier).

  • Next, we specify training hyperparameters such as the learning rate (lr), batch size (batch_size), and number of epochs (max_epochs).

  • We also pass the optimizer and the loss function (criterion) as parameters.

Done!

Now, we can directly invoke the fit() method to train the model as follows:

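Skorch automatically prints the training metrics for every epoch: by default, the training loss, validation loss, validation accuracy, and epoch duration (NeuralNetClassifier holds out 20% of the training data for validation unless told otherwise).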

What’s more, we can call the predict() and score() methods to generate predictions and compute the accuracy, respectively.

Isn’t that simple, cool, and elegant?

PyTorch Lightning is another library that builds on PyTorch to reduce boilerplate code. It comes with built-in, plug-and-play support for mixed-precision training, multi-GPU and TPU training, logging, profiling, and more.

👉 Over to you: Are you aware of any other utility libraries to simplify model training? Let me know :)

To receive the full articles and support the Daily Dose of Data Science, consider subscribing.
