The Mean Squared Error (MSE) is one of the most commonly used metrics in machine learning and statistical analysis. It measures the average squared difference between predicted and actual values. A basic scalar implementation is fine for small datasets, but on large arrays the calculation can be sped up considerably with SIMD instructions, which process several values per instruction.
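
For reference, the definition can be written out as follows. The symbols $a_i$, $b_i$ and $n$ are chosen here to match the arguments of the `mse` function used in this article, with zero-based indexing as in the code:

$$
\mathrm{MSE}(a, b) = \frac{1}{n} \sum_{i=0}^{n-1} (a_i - b_i)^2
$$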

The traditional implementation:

```
#include <cstddef>
#include <iostream>
#include <vector>

// Scalar reference: mean of the squared element-wise differences.
float mse(const float *a, const float *b, const std::size_t n) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        float diff = a[i] - b[i];
        sum += diff * diff;
    }
    return sum / static_cast<float>(n);
}

int main() {
    std::vector<float> a = {
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
        11, 12, 13, 14, 15, 16, 17,
    };
    std::vector<float> b = {
        0.5, 1, 2.5, 3, 4.5, 5, 6.5, 7, 8.5,
        9, 10.5, 11, 12.5, 13, 14.5, 15, 16.5, 17,
    };
    float value = mse(a.data(), b.data(), a.size());
    std::cout << value << '\n';
    return 0;
}
```

In this code, the `mse` function computes the squared difference between corresponding elements of the two arrays, sums them up, and divides by the total number of elements to get the mean squared error. Output:

`0.125`
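
The output can be checked by hand. In the sample data, 9 of the 18 pairs differ by 0.5 and the other 9 are equal, so:

$$
\mathrm{MSE} = \frac{9 \cdot 0.5^2 + 9 \cdot 0^2}{18} = \frac{2.25}{18} = 0.125
$$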

This method works well for small datasets, but it handles only one pair of elements per loop iteration, so it can become a performance bottleneck on larger data.

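Here's the optimized implementation using 256-bit AVX intrinsics, which process 8 floats at a time. Strictly speaking, the floating-point intrinsics used here belong to the original AVX instruction set (plus SSE3 for `_mm_hadd_ps`), so the code runs on any AVX2-capable CPU but does not require AVX2 itself: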

```
#include <cstddef>
#include <immintrin.h>

float mse(const float *a, const float *b, const std::size_t n) {
    __m256 vsum = _mm256_setzero_ps();  // 8 running partial sums
    std::size_t i = 0;
    // Main loop: 8 floats per iteration.
    for (; i + 8 <= n; i += 8) {
        __m256 va = _mm256_loadu_ps(&a[i]);  // unaligned loads
        __m256 vb = _mm256_loadu_ps(&b[i]);
        __m256 vdiff = _mm256_sub_ps(va, vb);
        vdiff = _mm256_mul_ps(vdiff, vdiff);
        vsum = _mm256_add_ps(vsum, vdiff);
    }
    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    __m128 bottom = _mm256_castps256_ps128(vsum);
    __m128 top = _mm256_extractf128_ps(vsum, 1);
    bottom = _mm_add_ps(bottom, top);
    bottom = _mm_hadd_ps(bottom, bottom);
    bottom = _mm_hadd_ps(bottom, bottom);
    float sum = _mm_cvtss_f32(bottom);
    // Scalar tail: the last n % 8 elements.
    for (; i < n; ++i) {
        float diff = a[i] - b[i];
        sum += diff * diff;
    }
    return sum / static_cast<float>(n);
}
```

Code explanation:

- `_mm256_loadu_ps` loads 8 floating-point values from each array, with no alignment requirement.
- `_mm256_sub_ps` computes the differences between corresponding elements.
- `_mm256_mul_ps` multiplies each difference by itself, giving the squared error for each pair of values.
- `_mm256_add_ps` accumulates the results in the `vsum` register.

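To convert the SIMD result back to a scalar, we first split `vsum` into two 128-bit halves and add them together, which leaves four partial sums. Two `_mm_hadd_ps` calls then add adjacent lanes until every lane holds the total, and `_mm_cvtss_f32` extracts it as a `float`.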
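
To see the reduction in isolation, here is a minimal sketch that applies the same sequence of intrinsics to 8 floats in memory. The helper name `hsum8` is hypothetical and not part of the code above. The `target("avx")` attribute is a GCC/Clang extension that is assumed here so the function compiles without `-mavx`; with other compilers you would drop it and enable AVX through the compiler's own flag.

```cpp
#include <immintrin.h>

// Hypothetical helper (not from the article): sums 8 consecutive floats
// using the same reduction steps as the mse function.
// The attribute below is GCC/Clang-specific and is an assumption of this
// sketch; it enables AVX for this one function only.
#if defined(__GNUC__) || defined(__clang__)
__attribute__((target("avx")))
#endif
float hsum8(const float *p) {
    __m256 v = _mm256_loadu_ps(p);             // lanes 0..7
    __m128 lo = _mm256_castps256_ps128(v);     // lanes 0..3
    __m128 hi = _mm256_extractf128_ps(v, 1);   // lanes 4..7
    __m128 s = _mm_add_ps(lo, hi);             // 4 partial sums
    s = _mm_hadd_ps(s, s);                     // pairwise: 2 distinct sums
    s = _mm_hadd_ps(s, s);                     // every lane holds the total
    return _mm_cvtss_f32(s);                   // extract lane 0
}
```

For the input `{1, 2, 3, 4, 5, 6, 7, 8}` the intermediate values are `{6, 8, 10, 12}` after the first add, `{14, 22, 14, 22}` after the first `_mm_hadd_ps`, and `36` in every lane after the second.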

Remaining elements are handled with a scalar loop at the end.
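
How the work divides between the two loops follows directly from the loop condition `i + 8 <= n`. As a small illustration, the hypothetical helper below (not part of the code above) returns how many elements the 8-wide loop covers; the rest, `n % 8` elements, fall to the scalar loop.

```cpp
#include <cstddef>

// Hypothetical helper: number of leading elements processed by the
// vector loop when each iteration consumes `width` elements.
// The scalar tail then handles the remaining n % width elements.
constexpr std::size_t vector_part(std::size_t n, std::size_t width = 8) {
    return n - n % width;
}
```

For the 18-element sample arrays, `vector_part(18)` is 16, so two vector iterations cover indices 0 to 15 and the scalar loop handles indices 16 and 17. For inputs shorter than 8 elements the vector loop does not run at all.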
