Print PyTorch Model Summary using torchinfo

When working with complex PyTorch models, it's important to understand the model's structure, such as the number of parameters and the input and output shapes of each layer. This information can help with debugging issues and optimizing the model. One way to obtain a comprehensive summary of a PyTorch model is the torchinfo package. This tutorial shows how to print a PyTorch model summary using torchinfo.

Prepare environment

  • Install the following package using pip:
pip install torchinfo

Code

In the following code, we define a simple model consisting of a linear layer with input size 100 and output size 200, followed by a ReLU activation function and a second linear layer with input size 200 and output size 10. Next, we create random input data with a batch size of 1, meaning that we pass a single input sample to the model. Finally, we call the summary function, passing the model, the input data, and the names of the columns to display in the output.

import torch
import torch.nn as nn
from torchinfo import summary

model = nn.Sequential(
    nn.Linear(100, 200),
    nn.ReLU(),
    nn.Linear(200, 10)
)

batch_size = 1
input_data = torch.randn(batch_size, 100)
summary(model, input_data=input_data, col_names=['input_size', 'output_size', 'num_params'])

Here's an example of the output that you might see when running the provided code:

===================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
===================================================================================================================
Sequential                               [1, 100]                  [1, 10]                   --
├─Linear: 1-1                            [1, 100]                  [1, 200]                  20,200
├─ReLU: 1-2                              [1, 200]                  [1, 200]                  --
├─Linear: 1-3                            [1, 200]                  [1, 10]                   2,010
===================================================================================================================
Total params: 22,210
Trainable params: 22,210
Non-trainable params: 0
Total mult-adds (M): 0.02
===================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.09
Estimated Total Size (MB): 0.09
===================================================================================================================
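To see where the numbers in the table come from, the following plain-Python sketch walks through the three layers, tracking the tensor shape and counting parameters by hand. The `trace` helper is hypothetical and not part of torchinfo. It assumes float32 parameters (4 bytes each) and that the summary reports MB as 10^6 bytes, which is consistent with the 0.09 shown above.

```python
# Hypothetical sketch (not part of torchinfo): reproduce the shapes and
# parameter counts from the summary table using only plain Python.

# Each entry describes one layer of the nn.Sequential model:
# ("Linear", in_features, out_features) or ("ReLU",).
layers = [("Linear", 100, 200), ("ReLU",), ("Linear", 200, 10)]


def trace(layers, input_shape):
    """Return (name, input_shape, output_shape, num_params) for each layer."""
    rows = []
    shape = list(input_shape)
    for layer in layers:
        in_shape = list(shape)
        if layer[0] == "Linear":
            _, in_features, out_features = layer
            assert shape[-1] == in_features, "last dim must match in_features"
            shape[-1] = out_features
            # weight matrix (out_features x in_features) plus one bias per output
            params = in_features * out_features + out_features
        else:
            # ReLU is element-wise: shape is unchanged and it has no parameters
            params = 0
        rows.append((layer[0], in_shape, list(shape), params))
    return rows


rows = trace(layers, (1, 100))
for name, in_shape, out_shape, params in rows:
    print(f"{name:<8}{str(in_shape):<12}{str(out_shape):<12}{params:,}")

total_params = sum(row[3] for row in rows)
# Assumption: float32 parameters (4 bytes each), 1 MB = 1e6 bytes
params_mb = total_params * 4 / 1e6
print(f"Total params: {total_params:,}")
print(f"Params size (MB): {params_mb:.2f}")
```

The first linear layer contributes 100 × 200 weights plus 200 biases (20,200), and the second contributes 200 × 10 weights plus 10 biases (2,010), giving the 22,210 total reported by torchinfo.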
