Multiple Linear Regression using TensorFlow 2

Multiple linear regression (MLR) is a statistical method that uses two or more independent variables to predict the value of a dependent variable. MLR is similar to simple linear regression, but it uses multiple independent variables instead of one.

Regression | Independent variables | Dependent variables
--- | --- | ---
Simple linear regression | 1 | 1
Multiple linear regression | >= 2 | 1
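In general, a multiple linear regression model with n independent variables predicts y as a weighted sum of the inputs plus a constant term. Written with weights w_1 to w_n and a bias b (notation chosen here for illustration), it looks as follows; a Keras Dense layer with a single unit computes exactly this expression, storing the weights in its kernel and the constant in its bias.

```latex
\hat{y} = w_1 x_1 + w_2 x_2 + \dots + w_n x_n + b
```

For the example in this tutorial, n = 3, and the relationship y = 2 * x1 + 3 * x2 + x3 corresponds to w_1 = 2, w_2 = 3, w_3 = 1 and b = 0.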

Let’s say we have three independent variables x1, x2 and x3, and a dependent variable y:

x1 | x2 | x3 | y
--- | --- | --- | ---
-2 | 1 | -5 | -6
-1 | 2 | -4 | 0
0 | 3 | -3 | 6
1 | 4 | -2 | 12
2 | 5 | -1 | 18
3 | 6 | 0 | 24
4 | 7 | 1 | 30

The relationship between these variables is given by the formula y = 2 * x1 + 3 * x2 + x3.
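As a quick sanity check, the formula can be applied to every row of the table with NumPy. This is a minimal sketch that only uses the values from the table; it confirms that each y value equals 2 * x1 + 3 * x2 + x3.

```python
import numpy as np

# Columns of the table above
x1s = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
x2s = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
x3s = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0])
ys = np.array([-6.0, 0.0, 6.0, 12.0, 18.0, 24.0, 30.0])

# Apply y = 2 * x1 + 3 * x2 + x3 element-wise to every row
computed = 2 * x1s + 3 * x2s + x3s

print(np.array_equal(computed, ys))   # True
```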

This tutorial shows how to create and train a model that predicts the value of y for given values of x1, x2 and x3, using TensorFlow 2.

Install TensorFlow from the command line using the pip package manager:

pip install tensorflow

To train the model, we declare the arrays x1s, x2s, x3s and ys. The model expects its inputs in a single array, so we use NumPy's stack function to join x1s, x2s and x3s along a new axis (axis 1). The result is a 7 x 3 array with one row per sample and one column per independent variable.
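The effect of the axis argument is easiest to see by printing shapes. This small sketch uses the same three arrays; with axis=1 each input becomes a column, while the default axis=0 would stack them as rows, which is not the layout the model expects. For 1-D inputs, np.column_stack gives the same result as axis=1.

```python
import numpy as np

x1s = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
x2s = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
x3s = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0])

# axis=1 puts each input array in its own column: shape (7, 3)
x123s = np.stack((x1s, x2s, x3s), axis=1)
print(x123s.shape)   # (7, 3)
print(x123s[0])      # first sample: [-2.  1. -5.]

# The default axis=0 stacks the arrays as rows instead: shape (3, 7)
print(np.stack((x1s, x2s, x3s)).shape)   # (3, 7)

# For 1-D inputs, column_stack produces the same (7, 3) array
print(np.array_equal(x123s, np.column_stack((x1s, x2s, x3s))))   # True
```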

The model has a single Dense layer with three inputs and one output. It is compiled with the mean squared error (MSE) loss function and the stochastic gradient descent (SGD) optimizer, and trained for 400 epochs.
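To make the training step more concrete, here is a NumPy-only sketch of what a one-unit Dense layer, MSE loss and plain gradient descent do together. It assumes Keras's defaults: a learning rate of 0.01 for the 'sgd' optimizer and a batch size of 32 in fit, which with only seven samples means each epoch is a single full-batch step. Unlike Keras, the sketch starts from zero weights rather than random ones, and the function name train_linear_gd is hypothetical, chosen for this illustration.

```python
import numpy as np

def train_linear_gd(X, y, lr=0.01, epochs=400):
    """Fit y ~ X @ w + b by full-batch gradient descent on the mean squared error.

    Hypothetical helper for illustration; starts from zero weights.
    """
    n, d = X.shape
    w = np.zeros(d)   # one weight per input (the Dense layer's kernel)
    b = 0.0           # the Dense layer's bias
    for _ in range(epochs):
        error = X @ w + b - y                # prediction minus target, shape (n,)
        grad_w = (2.0 / n) * (X.T @ error)   # gradient of the MSE with respect to w
        grad_b = (2.0 / n) * error.sum()     # gradient of the MSE with respect to b
        w -= lr * grad_w
        b -= lr * grad_b
    mse = np.mean((X @ w + b - y) ** 2)
    return w, b, mse

x1s = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
x2s = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
x3s = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0])
X = np.stack((x1s, x2s, x3s), axis=1)
ys = np.array([-6.0, 0.0, 6.0, 12.0, 18.0, 24.0, 30.0])

w, b, mse = train_linear_gd(X, ys)
print("weights:", w, "bias:", b, "MSE:", mse)
print("prediction for (16, 19, 12):", np.array([16.0, 19.0, 12.0]) @ w + b)
```

On this data the sketch fits the seven samples essentially exactly, but the weights it finds are not (2, 3, 1). In every row x2 = x1 + 3 and x3 = x1 - 3, so the inputs are linearly dependent and many weight and bias combinations fit the training data equally well; starting from zero, gradient descent settles on the one with the smallest combined norm (roughly w = (2, 2.95, 1.05), b = 0.32). Keras starts from random weights, so its model can land on yet another combination, and a prediction for a point outside that pattern, such as (16, 19, 12), need not match the formula exactly.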

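from tensorflow import keras
import numpy as np

# Training data: three independent variables and the dependent variable
x1s = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
x2s = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=float)
x3s = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0], dtype=float)
# Join the inputs column-wise into a (7, 3) array: one row per sample
x123s = np.stack((x1s, x2s, x3s), axis=1)
ys = np.array([-6.0, 0.0, 6.0, 12.0, 18.0, 24.0, 30.0], dtype=float)

# One Dense layer with three inputs and one output: y = w1*x1 + w2*x2 + w3*x3 + b
model = keras.Sequential([
    keras.layers.Dense(units=1, input_shape=[3])
])

# Mean squared error loss, stochastic gradient descent optimizer
model.compile(optimizer='sgd', loss='mean_squared_error')

model.fit(x123s, ys, epochs=400)

# Predict y for a new sample; the input is a 2-D array of shape (1, 3)
x1 = 16.0
x2 = 19.0
x3 = 12.0
x123 = np.array([[x1, x2, x3]])
y = model.predict(x123)
print(y[0])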

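After training, we predict the value of y for the given values of x1, x2 and x3. In our run, the model returned y = 100.96393; the exact value varies between runs because the layer's weights are initialized randomly. We can compare it with the value given by the formula: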

y = 2 * x1 + 3 * x2 + x3 = 2 * 16 + 3 * 19 + 12 = 101
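The comparison can also be written out as a short calculation of the prediction error. The value 100.96393 is the prediction reported above for one training run; a fresh run will give a slightly different number, so treat the error figures as specific to that run.

```python
# Exact value from the formula y = 2 * x1 + 3 * x2 + x3
x1, x2, x3 = 16.0, 19.0, 12.0
y_true = 2 * x1 + 3 * x2 + x3

# Prediction reported for one training run (varies between runs)
y_pred = 100.96393

abs_error = abs(y_true - y_pred)
rel_error = abs_error / abs(y_true)

print(y_true)                # 101.0
print(round(abs_error, 5))   # 0.03607
print(f"{rel_error:.4%}")    # 0.0357%
```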
