When working with complex PyTorch models, it's important to understand the model's structure, such as the number of parameters and the input and output shapes of each layer. This information helps when debugging issues and optimizing the model. One way to obtain a comprehensive summary of a PyTorch model is to use the torchinfo package. This tutorial shows how to print a PyTorch model summary using torchinfo.
Prepare environment
- Install the following package using pip:
pip install torchinfo
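To confirm that the installation succeeded before running the tutorial code, a quick check along the following lines can help. This is a minimal sketch that uses only the Python standard library; the helper name installed_version is made up for illustration and is not part of torchinfo.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string, or None if the package is missing.

    Hypothetical helper for illustration; not part of torchinfo.
    """
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("torchinfo")
if version is None:
    print("torchinfo is not installed; run: pip install torchinfo")
else:
    print(f"torchinfo {version} is installed")
```

The same helper works for any distribution name, so it can also be used to check that torch itself is available.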
Code
In the following code, we define a simple model consisting of a linear layer with input size 100 and output size 200, followed by a ReLU activation function and a second linear layer with input size 200 and output size 10. Next, we set the batch size and create random input data. The batch size is 1, which means we are working with a single input sample. Finally, we call the summary function, passing the model, the input data, and the names of the columns to display in the output.
import torch
import torch.nn as nn
from torchinfo import summary

model = nn.Sequential(
    nn.Linear(100, 200),
    nn.ReLU(),
    nn.Linear(200, 10)
)

batch_size = 1
input_data = torch.randn(batch_size, 100)

summary(model, input_data=input_data, col_names=['input_size', 'output_size', 'num_params'])
Here's an example of the output that you might see when running the provided code:
===================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
===================================================================================================================
Sequential                               [1, 100]                  [1, 10]                   --
├─Linear: 1-1                            [1, 100]                  [1, 200]                  20,200
├─ReLU: 1-2                              [1, 200]                  [1, 200]                  --
├─Linear: 1-3                            [1, 200]                  [1, 10]                   2,010
===================================================================================================================
Total params: 22,210
Trainable params: 22,210
Non-trainable params: 0
Total mult-adds (M): 0.02
===================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.09
Estimated Total Size (MB): 0.09
===================================================================================================================
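The parameter counts in the summary can be cross-checked by hand. A linear layer with a bias holds in_features × out_features weights plus out_features bias terms, and ReLU has no parameters. The sketch below redoes that arithmetic in plain Python. The helper name linear_param_count is made up for illustration, and the size estimate assumes 4-byte float32 parameters and 1 MB = 10^6 bytes, which is consistent with the 0.09 MB figure shown above.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Count the parameters of a fully connected layer (hypothetical helper)."""
    weights = in_features * out_features
    biases = out_features if bias else 0
    return weights + biases


# The two linear layers from the model above; ReLU contributes no parameters.
layer_1 = linear_param_count(100, 200)   # 100 * 200 + 200
layer_3 = linear_param_count(200, 10)    # 200 * 10 + 10
total_params = layer_1 + layer_3

# Assumptions: float32 parameters (4 bytes each) and 1 MB = 10**6 bytes.
BYTES_PER_PARAM = 4
params_size_mb = total_params * BYTES_PER_PARAM / 1e6

print(layer_1, layer_3, total_params)  # 20200 2010 22210
print(round(params_size_mb, 2))        # 0.09
```

In PyTorch itself, the same total can be obtained with sum(p.numel() for p in model.parameters()).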