Simple Linear Regression using MLP and Rubix ML

Simple linear regression is a statistical method that is used to analyze the relationship between two continuous variables:

  • x: the independent variable, also known as the explanatory variable or predictor.
  • y: the dependent variable, also known as the response or outcome.
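In standard textbook notation (the symbols below are introduced here and are not used elsewhere in this article), simple linear regression models y as a straight line in x plus an error term, and the least-squares estimates of the slope and intercept have a closed form:

```latex
\[
y = \beta_0 + \beta_1 x + \varepsilon, \qquad
\hat{\beta}_1 = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n} (x_i - \bar{x})^2}, \qquad
\hat{\beta}_0 = \bar{y} - \hat{\beta}_1 \bar{x}
\]
```

Here \bar{x} and \bar{y} are the sample means of the n observed (x, y) pairs.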

Assume that we have the following pairs of values:

x: -2, -1, 0, 1, 2, 3, 4
y: -3, -1, 1, 3, 5, 7, 9

The relationship between the x and y variables can be described by the formula y = 2 * x + 1. It can also be shown graphically:

[Figure: Simple Linear Regression]
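As a quick check that y = 2 * x + 1 really describes the data, the slope and intercept can be computed directly with the closed-form least-squares formulas, without any neural network. This is a minimal plain-PHP sketch with no Rubix ML dependency; the helper name fitSimpleLinearRegression is hypothetical.

```php
<?php

// Hypothetical helper: closed-form ordinary least squares for one predictor.
// slope     = sum((x - meanX) * (y - meanY)) / sum((x - meanX)^2)
// intercept = meanY - slope * meanX
// Assumes at least two distinct x values (otherwise the variance is zero).
function fitSimpleLinearRegression(array $xs, array $ys): array
{
    $n = count($xs);
    $meanX = array_sum($xs) / $n;
    $meanY = array_sum($ys) / $n;

    $covXY = 0.0; // sum of cross-deviations
    $varX = 0.0;  // sum of squared x deviations
    for ($i = 0; $i < $n; $i++) {
        $dx = $xs[$i] - $meanX;
        $covXY += $dx * ($ys[$i] - $meanY);
        $varX += $dx * $dx;
    }

    $slope = $covXY / $varX;
    $intercept = $meanY - $slope * $meanX;

    return [$slope, $intercept];
}

$xs = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0];
$ys = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0, 9.0];

[$slope, $intercept] = fitSimpleLinearRegression($xs, $ys);
printf("y = %.1f * x + %.1f\n", $slope, $intercept); // prints: y = 2.0 * x + 1.0
```

On this noise-free data the fit is exact, which is why a network trained on the same points should also approach y = 2 * x + 1.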

This tutorial shows how to create and train a model that predicts the value of y for a given value of x. We will use a multilayer perceptron (MLP) and the Rubix ML library.

Add the Rubix ML library to the composer.json file:

"require": {
    "rubix/ml": "^1.0"
}

Then run the following command to install it:

composer install

The MLP has one hidden layer with a single neuron. Stochastic gradient descent (SGD) is used as the optimizer to update the network parameters, and mean squared error (MSE) is used as the loss function to measure the error between the target output and the actual output of the network.
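A single hidden neuron is enough here because the hidden Dense layer has no activation function, so the network is a composition of two linear maps and can represent exactly the straight lines that simple linear regression fits. Writing the hidden neuron's weight and bias as w_1, b_1 and the output neuron's as w_2, b_2 (symbols introduced here for illustration, assuming the output layer is a single linear neuron, as is usual for a regression MLP):

```latex
\[
\hat{y} = w_2\,(w_1 x + b_1) + b_2 = (w_2 w_1)\,x + (w_2 b_1 + b_2)
\]
```

Any parameters with w_2 w_1 = 2 and w_2 b_1 + b_2 = 1 reproduce y = 2 * x + 1 exactly.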

The network is trained on the arrays of x and y values for 400 epochs, with a batch size of 1.
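To make the training procedure concrete, here is a from-scratch sketch in plain PHP (7.4 or later), independent of Rubix ML, of the same kind of network: one hidden neuron without activation, one linear output neuron, per-sample SGD (batch size 1) on the squared error, and 400 epochs. It is an illustration under assumptions, not Rubix ML's implementation: the learning rate of 0.01, the fixed starting weights and the 0.5 factor in the loss are choices made here so the run is deterministic.

```php
<?php

// Illustrative sketch only, NOT Rubix ML internals.
$xs = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0];
$ys = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0, 9.0];

// Assumed deterministic starting weights (Rubix ML initializes randomly).
$w1 = 1.0; $b1 = 0.0; // hidden neuron
$w2 = 1.0; $b2 = 0.0; // output neuron
$rate = 0.01;         // assumed learning rate
$epochs = 400;

for ($epoch = 0; $epoch < $epochs; $epoch++) {
    foreach ($xs as $i => $x) {
        // Forward pass: hidden neuron, then output neuron (both linear).
        $h = $w1 * $x + $b1;
        $pred = $w2 * $h + $b2;

        // Derivative of 0.5 * (pred - y)^2 with respect to pred.
        $err = $pred - $ys[$i];

        // Backpropagation: all gradients computed before any update.
        $gW2 = $err * $h;
        $gB2 = $err;
        $gW1 = $err * $w2 * $x;
        $gB1 = $err * $w2;

        // Plain SGD step for a batch of one sample.
        $w2 -= $rate * $gW2;
        $b2 -= $rate * $gB2;
        $w1 -= $rate * $gW1;
        $b1 -= $rate * $gB1;
    }
}

// The arrow function captures the trained weights by value.
$predict = fn (float $x): float => $w2 * ($w1 * $x + $b1) + $b2;
printf("prediction for x = 15: %.4f\n", $predict(15.0));
```

Because the points lie exactly on a line, the per-sample gradients vanish at any solution with w_2 w_1 = 2 and w_2 b_1 + b_2 = 1, so this loop should finish very close to y = 2 * x + 1.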


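<?php

use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\NeuralNet\CostFunctions\LeastSquares;
use Rubix\ML\NeuralNet\Layers\Dense;
use Rubix\ML\NeuralNet\Optimizers\Stochastic;
use Rubix\ML\Regressors\MLPRegressor;

require_once __DIR__.'/vendor/autoload.php';

// Training data: each sample is an array of features (here a single feature x),
// and each label is the corresponding target value y.
$xs = [[-2.0], [-1.0], [0.0], [1.0], [2.0], [3.0], [4.0]];
$ys = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0, 9.0];

// One hidden layer with a single neuron, SGD optimizer, squared-error loss.
// Named arguments require PHP 8.0 or later.
$model = new MLPRegressor(
    [new Dense(1)],
    batchSize: 1,
    optimizer: new Stochastic(),
    epochs: 400,
    costFn: new LeastSquares()
);

$model->train(new Labeled($xs, $ys));

// Predict y for an unseen x; the dataset expects a list of samples.
$x = 15.0;
$y = $model->predict(new Unlabeled([[$x]]));
echo $y[0], PHP_EOL;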

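After training, we use the model to predict y for a value of x that was not in the training data. In our case, for x = 15.0 the trained model returned y = 31.073727 (the exact value can vary slightly between runs, because the network weights are initialized randomly). We can verify this against the exact formula: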

y = 2 * x + 1 = 2 * 15 + 1 = 31
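The gap between the model's output of 31.073727 and the exact value of 31 can be quantified as an absolute and a relative error:

```latex
\[
|\hat{y} - y| = |31.073727 - 31| = 0.073727, \qquad
\frac{|\hat{y} - y|}{|y|} = \frac{0.073727}{31} \approx 0.0024 \approx 0.24\%
\]
```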
