Convert TensorFlow 2 Model to ONNX Format

ONNX (Open Neural Network Exchange) is an open format for representing machine learning models. ONNX Runtime can run inference on models stored in ONNX format.

This tutorial demonstrates how to convert a TensorFlow 2 model to ONNX format.

Prepare environment

Before starting, make sure the tensorflow package is installed:

pip install tensorflow

You also need the tf2onnx package for model conversion:

pip install tf2onnx

Finally, install the onnxruntime package to run inference:

pip install onnxruntime
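To confirm that all three packages are visible to your Python interpreter, you can run a small check like the one below. It is a sketch that uses only the standard library's importlib.metadata; it looks up the distribution names exactly as passed to pip above, so a variant build installed under another name (for example tensorflow-cpu) would be reported as missing.

```python
from importlib import metadata

# Distribution names as used with pip install above
PACKAGES = ["tensorflow", "tf2onnx", "onnxruntime"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```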

Model training

We have created a model to solve a simple linear regression problem. The model has a single layer and predicts the value of y for a given value of x. The relationship between x and y is y = 2 * x + 1. The trained model is saved in the SavedModel format.
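Because the training data lies exactly on y = 2 * x + 1, an ordinary least-squares fit recovers the target slope and intercept directly. The NumPy sketch below is not part of the original tutorial; it just gives a reference for the values the Dense layer's weight and bias should approach during training.

```python
import numpy as np

# Same data points as in train.py
xs = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0, 9.0])

# Every point satisfies y = 2 * x + 1 exactly
assert np.allclose(ys, 2 * xs + 1)

# A degree-1 least-squares fit returns [slope, intercept]
slope, intercept = np.polyfit(xs, ys, 1)
print(f"slope={slope:.4f}, intercept={intercept:.4f}")  # slope=2.0000, intercept=1.0000
```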

train.py

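from tensorflow import keras
import numpy as np

# Training data sampled from y = 2 * x + 1
xs = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0, 9.0], dtype=float)

# A single Dense unit computes y = w * x + b
model = keras.Sequential([
    keras.layers.Dense(units=1, input_shape=[1])
])

# Stochastic gradient descent minimizing the mean squared error
model.compile(optimizer='sgd', loss='mean_squared_error')

model.fit(xs, ys, epochs=400)

# Writes a SavedModel directory named 'model' (Keras 2, i.e. TensorFlow 2.15 and earlier)
model.save('model')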
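Under the hood, model.fit minimizes the mean squared error with gradient descent. The NumPy sketch below approximates that loop under stated assumptions: Keras's default SGD learning rate of 0.01 and default batch size of 32, which means all seven points form a single batch per epoch. Unlike Keras, it starts from w = 0, b = 0 instead of a randomly initialized weight, so its numbers will not match a real training run exactly.

```python
import numpy as np

xs = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0, 9.0])


def train_linear(xs, ys, epochs=400, learning_rate=0.01):
    """Full-batch gradient descent on MSE for the model y_pred = w * x + b."""
    w, b = 0.0, 0.0  # assumption: zero init instead of Keras's random kernel init
    for _ in range(epochs):
        error = w * xs + b - ys             # residuals for the whole batch
        grad_w = 2.0 * np.mean(error * xs)  # d(MSE)/dw
        grad_b = 2.0 * np.mean(error)       # d(MSE)/db
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


w, b = train_linear(xs, ys)
print(f"w={w:.4f}, b={b:.4f}, prediction at x=15: {w * 15 + b:.4f}")
```

After 400 epochs the parameters end up close to, but not exactly at, w = 2 and b = 1, which is why a trained model predicts a value near (rather than exactly) 31 for x = 15.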

Model conversion

First, we load the trained model from the SavedModel directory. Then the tf2onnx package converts the TensorFlow 2 model to ONNX format and saves it to a .onnx file.

convert.py

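import tensorflow as tf
import tf2onnx

# Load the trained model from the SavedModel directory written by train.py
model = tf.keras.models.load_model('model')

# Convert the Keras model to ONNX and write it to model.onnx
tf2onnx.convert.from_keras(model, output_path='model.onnx')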

Inference

Now the model in ONNX format can be used with ONNX Runtime to predict y for a value of x that the model did not see during training. We load the model, run inference, and print the predicted y.

test.py

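import onnxruntime as rt
import numpy as np

# Load the ONNX model; the CPU execution provider is always available
session = rt.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])

# Input of shape (batch, features) = (1, 1); the exported model expects float32
x = np.array([[15.0]], dtype=np.float32)

# Feed the input by name; None fetches all outputs
input_name = session.get_inputs()[0].name
y = session.run(None, {input_name: x})
print(y[0])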

In our case, the model returned y = 31.00175 for x = 15.0. The result can be verified as follows:

y = 2 * x + 1 = 2 * 15 + 1 = 31
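The size of the deviation is easy to quantify. The snippet below plugs in the prediction reported above (31.00175, copied from this tutorial's run; your own value will likely differ slightly because training starts from random weights) and computes its absolute and relative error against the exact value.

```python
x = 15.0
predicted = 31.00175  # value reported by test.py in this tutorial's run
expected = 2 * x + 1  # exact value from y = 2 * x + 1

abs_error = abs(predicted - expected)
rel_error = abs_error / expected
print(f"abs_error={abs_error:.5f}, rel_error={rel_error:.2e}")  # abs_error=0.00175, rel_error=5.65e-05
```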
