Convert PyTorch Model to ONNX Format

ONNX (Open Neural Network Exchange) is a standardized format for representing machine learning models. With ONNX Runtime, you can run inference on models stored in this format regardless of the framework they were originally built in. This tutorial explains how to convert a PyTorch model to ONNX format.

Prepare the environment

Ensure the torch package is installed before you begin:

pip install torch

You also need the onnxscript package, which the exporter used below (dynamo=True) relies on for model conversion:

pip install onnxscript

Install the onnxruntime package to perform inference:

pip install onnxruntime

Model training

We will train a simple linear regression model to learn the relationship y = 2 * x + 1. The model consists of a single linear layer with one input and one output neuron. After training, we save the learned parameters (the model's state_dict) to model.pt.
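Because every training point lies exactly on the line y = 2 * x + 1, the best possible fit is known in advance. As a quick sanity check of the target the training loop should approach, the sketch below computes the closed-form ordinary least-squares slope and intercept in plain Python; the helper name fit_line is illustrative and not part of PyTorch.

```python
# Closed-form ordinary least squares for one feature:
# slope = sum((x - mean_x) * (y - mean_y)) / sum((x - mean_x) ** 2)
xs = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0, 9.0]


def fit_line(xs, ys):
    """Return (slope, intercept) of the least-squares line through the points."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Sums of centered products and centered squares
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept


slope, intercept = fit_line(xs, ys)
print(slope, intercept)  # 2.0 1.0
```

Gradient descent only approximates these values, which is why the trained model later predicts a number close to, but not exactly, the ideal output.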

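import torch

# Training data sampled from y = 2 * x + 1, shaped (samples, features)
xs = torch.tensor([[-2.0], [-1.0], [0.0], [1.0], [2.0], [3.0], [4.0]])
ys = torch.tensor([[-3.0], [-1.0], [1.0], [3.0], [5.0], [7.0], [9.0]])

# A single linear layer: one input feature, one output
model = torch.nn.Sequential(
    torch.nn.Linear(1, 1)
)

# Mean squared error loss, minimized with stochastic gradient descent
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

epochs = 400
for epoch in range(epochs):
    optimizer.zero_grad()        # clear gradients from the previous step

    ys_pred = model(xs)          # forward pass on the whole dataset

    loss = loss_fn(ys_pred, ys)
    loss.backward()              # compute gradients of the loss
    optimizer.step()             # update the weight and bias

# Save only the learned parameters (weight and bias)
torch.save(model.state_dict(), 'model.pt')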
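To make the training loop less of a black box, here is a minimal plain-Python sketch of the same update rule: full-batch gradient descent on the mean squared error of y_hat = w * x + b. It assumes w and b start at zero, whereas torch.nn.Linear initializes them randomly, and the function name train is illustrative.

```python
# Same data as the PyTorch example
xs = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0, 9.0]


def train(w=0.0, b=0.0, lr=0.01, epochs=400):
    """Gradient descent on mean((w*x + b - y)**2); assumes zero initial parameters."""
    n = len(xs)
    for _ in range(epochs):
        # Residuals of the current predictions
        errors = [(w * x + b) - y for x, y in zip(xs, ys)]
        # Partial derivatives of the mean squared error with respect to w and b
        grad_w = 2.0 / n * sum(e * x for e, x in zip(errors, xs))
        grad_b = 2.0 / n * sum(errors)
        # Plain SGD step (no momentum), matching torch.optim.SGD's defaults
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


w, b = train()
print(f"w = {w:.4f}, b = {b:.4f}")
```

With lr=0.01 and 400 epochs, both parameters end up very close to 2 and 1, which matches the behavior of the PyTorch loop.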

Model conversion

Once the model is trained, it can be exported to ONNX format, which makes it framework-independent and usable outside PyTorch. First, we recreate the model structure, load the saved parameters, and create an example input tensor that the exporter uses to trace the model's computation. We then convert the model to ONNX format with a dynamic batch size and verify that the exported model produces the same outputs as the original PyTorch model.

The external_data parameter determines how the model's weights are stored: when set to False, all weights and parameters are saved within the single .onnx file; when set to True, the weights are written to a separate external data file next to the .onnx file, which is necessary for large models that exceed the 2 GB protobuf size limit.

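import torch

# Recreate the same architecture that was trained
model = torch.nn.Sequential(
    torch.nn.Linear(1, 1)
)

# Load the trained parameters (weights_only=True restricts loading to tensors)
model.load_state_dict(torch.load('model.pt', weights_only=True))
model.eval()

# Example input used to capture the model. A batch size of 2 is used because
# torch.export may specialize dimensions of size 0 or 1 to constants, which
# conflicts with marking the batch dimension as dynamic.
input_tensor = torch.randn(2, 1)

torch.onnx.export(
    model,
    (input_tensor,),             # positional arguments passed to model.forward
    'model.onnx',
    verbose=False,
    dynamo=True,                 # torch.export-based exporter (requires onnxscript)
    external_data=False,         # keep the weights inside model.onnx
    verify=True,                 # check the exported model with ONNX Runtime
    input_names=['input'],
    output_names=['output'],
    dynamic_shapes=({0: torch.export.Dim('batch')},)  # dim 0 of the first input is the batch size
)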
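Whether external_data is needed comes down to size. A rough check is to compare the bytes occupied by the weights with the 2 GB protobuf limit. The helpers below (estimate_weight_bytes, needs_external_data) are hypothetical, count only the weights, assume 4 bytes per float32 parameter by default, and ignore graph metadata, so treat the result as an estimate rather than an exact rule.

```python
# Protobuf, which ONNX uses for serialization, cannot store messages of 2 GB or more.
PROTOBUF_LIMIT_BYTES = 2 ** 31


def estimate_weight_bytes(param_counts, bytes_per_param=4):
    """Hypothetical helper: approximate size of the weights alone (float32 by default)."""
    return sum(param_counts) * bytes_per_param


def needs_external_data(param_counts, bytes_per_param=4):
    """Hypothetical helper: True if the weights alone would reach the protobuf limit."""
    return estimate_weight_bytes(param_counts, bytes_per_param) >= PROTOBUF_LIMIT_BYTES


# nn.Linear(1, 1): one weight and one bias value
print(estimate_weight_bytes([1, 1]))          # 8
print(needs_external_data([1, 1]))            # False
# A hypothetical float32 model with 7 billion parameters
print(needs_external_data([7_000_000_000]))   # True
```

For the tiny regression model, external_data=False is clearly safe; for half-precision weights, pass bytes_per_param=2.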

Inference

After exporting, the ONNX model can be loaded with ONNX Runtime and used to predict y for a value of x that was not part of the training data. We load the model, run it on the input, and print the predicted value of y.

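import numpy as np
import onnxruntime as rt

# Load the exported model
session = rt.InferenceSession('model.onnx')

# One sample with one feature: shape (batch=1, features=1), float32 as the model expects
x = np.array([[15.0]], dtype=np.float32)

# Look up the input name ('input', as set during export) and run the model
input_name = session.get_inputs()[0].name
y = session.run(None, {input_name: x})  # None returns all outputs as a list
print(y[0])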

In this example, the model predicts a value of approximately 31.00213 for x equal to 15.0; the exact figure varies slightly between runs because the model's initial weights are random. The result can be checked against the equation:

y = 2 * x + 1 = 2 * 15 + 1 = 31
