Calculate Mean Squared Error using C++ SIMD

The Mean Squared Error (MSE) is one of the most commonly used metrics in machine learning and statistical analysis. It measures the average squared difference between predicted values and actual values. A basic scalar loop is fine for small datasets, but SIMD instructions process several elements per instruction and can substantially reduce run time on large arrays.
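Written out as a formula, with a and b as the two input arrays and n as their length (0-based indices, as in C++), the definition is:

```latex
\mathrm{MSE}(a, b) = \frac{1}{n} \sum_{i=0}^{n-1} \left( a_i - b_i \right)^2
```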

The traditional implementation:

#include <cstddef>
#include <iostream>
#include <vector>

float mse(const float *a, const float *b, const size_t n) {
    float sum = 0;
    for (size_t i = 0; i < n; ++i) {
        float diff = a[i] - b[i];
        sum += diff * diff;
    }

    return sum / (float) n;
}

int main() {
    std::vector<float> a = {
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
        11, 12, 13, 14, 15, 16, 17,
    };
    std::vector<float> b = {
        0.5, 1, 2.5, 3, 4.5, 5, 6.5, 7, 8.5,
        9, 10.5, 11, 12.5, 13, 14.5, 15, 16.5, 17,
    };

    float value = mse(a.data(), b.data(), a.size());
    std::cout << value;

    return 0;
}

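In this code, the mse function computes the squared difference between corresponding elements of the two arrays, sums them up, and divides by the total number of elements to get the mean squared error. For the sample data, 9 of the 18 pairs differ by 0.5 and the rest are equal, so the sum is 9 × 0.25 = 2.25 and the result is 2.25 / 18 = 0.125. Output: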

0.125

This method works well for small datasets, but processing one element per loop iteration can become a performance bottleneck on large arrays.

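Here's the optimized implementation using AVX2. The 256-bit floating-point intrinsics used here were introduced with AVX, so the code also runs on AVX-only CPUs. Compile with a flag such as -mavx2 on GCC or Clang, or /arch:AVX2 on MSVC: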

#include <cstddef>
#include <immintrin.h>

float mse(const float *a, const float *b, const size_t n) {
    __m256 vsum = _mm256_setzero_ps();

    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m256 va = _mm256_loadu_ps(&a[i]);
        __m256 vb = _mm256_loadu_ps(&b[i]);
        __m256 vdiff = _mm256_sub_ps(va, vb);
        vdiff = _mm256_mul_ps(vdiff, vdiff);
        vsum = _mm256_add_ps(vsum, vdiff);
    }

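    // Horizontal sum: split vsum into its lower and upper 128-bit halves.
    __m128 bottom = _mm256_castps256_ps128(vsum);
    __m128 top = _mm256_extractf128_ps(vsum, 1);

    // Add the halves, then add adjacent pairs twice so lane 0 holds the total.
    bottom = _mm_add_ps(bottom, top);
    bottom = _mm_hadd_ps(bottom, bottom);
    bottom = _mm_hadd_ps(bottom, bottom);

    float sum = _mm_cvtss_f32(bottom);
    // Scalar tail for the last n % 8 elements.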
    for (; i < n; ++i) {
        float diff = a[i] - b[i];
        sum += diff * diff;
    }

    return sum / (float) n;
}

AVX2 code explanation:

  • _mm256_loadu_ps loads 8 single-precision floats from memory; the "u" variant does not require 32-byte alignment.
  • _mm256_sub_ps computes the differences between corresponding elements.
  • _mm256_mul_ps multiplies each difference by itself to get the squared error for each pair of values.
  • _mm256_add_ps accumulates the results in the vsum register, one running sum per lane.

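To convert the SIMD result back to a scalar, we first split vsum into its lower and upper 128-bit halves and add them together, leaving four partial sums. Two _mm_hadd_ps calls then add adjacent pairs until every lane holds the total, and _mm_cvtss_f32 extracts lane 0 as a float.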
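The reduction can be traced lane by lane. In the sketch below, s_0 to s_7 are illustrative names (not from the code) for the eight per-lane sums held in vsum, and t_0 to t_3 are the four sums left after the two halves are added:

```latex
\begin{aligned}
\text{vsum} &= [\, s_0,\ s_1,\ s_2,\ s_3,\ s_4,\ s_5,\ s_6,\ s_7 \,] \\
\text{bottom} + \text{top} &= [\, s_0 + s_4,\ s_1 + s_5,\ s_2 + s_6,\ s_3 + s_7 \,] = [\, t_0,\ t_1,\ t_2,\ t_3 \,] \\
\text{after first hadd} &= [\, t_0 + t_1,\ t_2 + t_3,\ t_0 + t_1,\ t_2 + t_3 \,] \\
\text{after second hadd} &= [\, T,\ T,\ T,\ T \,], \qquad T = t_0 + t_1 + t_2 + t_3 = \sum_{k=0}^{7} s_k
\end{aligned}
```

Because both arguments of each _mm_hadd_ps call are the same register, the upper two lanes simply repeat the lower two, and only lane 0 is read at the end.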

The remaining n % 8 elements, which do not fill a complete 8-wide vector, are handled with a scalar loop at the end.
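To make the data flow easier to follow without intrinsics, here is a minimal scalar sketch that models what the vector version does: eight independent running sums, a pairwise reduction, and a scalar tail. The name mse_lane_model and the constant kLanes are made up for this illustration. It is a teaching model rather than an optimized routine, and it should not be expected to run faster than the plain loop.

```cpp
#include <cstddef>

// Scalar model of the 8-lane loop (hypothetical helper, for illustration only).
// Lane k accumulates the squared differences of elements k, k + 8, k + 16, ...
float mse_lane_model(const float *a, const float *b, const std::size_t n) {
    constexpr std::size_t kLanes = 8;  // a 256-bit register holds 8 floats
    float lanes[kLanes] = {};          // plays the role of vsum

    std::size_t i = 0;
    for (; i + kLanes <= n; i += kLanes) {
        for (std::size_t k = 0; k < kLanes; ++k) {
            const float diff = a[i + k] - b[i + k];
            lanes[k] += diff * diff;
        }
    }

    // Reduce in the same pairing order as the intrinsics version:
    // fold the upper four lanes onto the lower four, then sum adjacent pairs.
    float folded[4];
    for (std::size_t k = 0; k < 4; ++k) {
        folded[k] = lanes[k] + lanes[k + 4];
    }
    float sum = (folded[0] + folded[1]) + (folded[2] + folded[3]);

    // Scalar tail: the last n % 8 elements that did not fill a full block.
    for (; i < n; ++i) {
        const float diff = a[i] - b[i];
        sum += diff * diff;
    }

    return sum / static_cast<float>(n);
}
```

With the 18-element sample arrays from the first listing, the main loop covers indices 0 to 15 in two blocks and the tail loop handles indices 16 and 17. Note that the lane-wise version adds the terms in a different order than the plain loop, so on arbitrary data the two results can differ in the last few bits of the float.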
