The Mean Squared Error (MSE) is one of the most commonly used metrics in machine learning and statistical analysis. It measures the average squared difference between the predicted values and the actual values. While a basic implementation of MSE works fine for small datasets, computing it with SIMD instructions, which process several values at once, can substantially improve performance on large inputs.
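For reference, the definition can be written out as a formula. The symbols here are chosen to match the code in this post: a and b are the two arrays, n is their length, and indexing starts at 0.

```latex
\mathrm{MSE}(a, b) = \frac{1}{n} \sum_{i=0}^{n-1} \left( a_i - b_i \right)^2
```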
The traditional implementation:
#include <cstddef>
#include <iostream>
#include <vector>

// Scalar MSE: mean of the squared element-wise differences of a and b.
float mse(const float *a, const float *b, const size_t n) {
    float sum = 0;
    for (size_t i = 0; i < n; ++i) {
        float diff = a[i] - b[i];
        sum += diff * diff;
    }
    return sum / (float) n;
}

int main() {
    std::vector<float> a = {
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
        11, 12, 13, 14, 15, 16, 17,
    };
    std::vector<float> b = {
        0.5, 1, 2.5, 3, 4.5, 5, 6.5, 7, 8.5,
        9, 10.5, 11, 12.5, 13, 14.5, 15, 16.5, 17,
    };
    // Both vectors hold 18 elements, so a.size() is valid for both.
    float value = mse(a.data(), b.data(), a.size());
    std::cout << value;
    return 0;
}
In this code, the mse function computes the squared difference between corresponding elements of the two arrays, sums these values, and divides by the total number of elements to get the mean squared error. Output:
0.125
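As a quick sanity check of that output, the arithmetic can be done by hand. The two arrays differ by 0.5 at the nine even indices (0, 2, ..., 16) and are equal at the other nine, so with n = 18:

```latex
\mathrm{MSE} = \frac{9 \cdot (0.5)^2 + 9 \cdot 0^2}{18} = \frac{2.25}{18} = 0.125
```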
This method works well for small datasets, but the loop handles one pair of elements at a time, which becomes a performance bottleneck for larger data.
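Here's the optimized implementation using AVX2. The 256-bit floating-point intrinsics used below were introduced with AVX, and _mm_hadd_ps comes from SSE3, so the code also runs on CPUs that support AVX but not AVX2. With GCC or Clang, enable the instruction set when compiling, for example with -mavx2: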
#include <cstddef>
#include <immintrin.h>

float mse(const float *a, const float *b, const size_t n) {
    // Eight running partial sums, one per 32-bit lane.
    __m256 vsum = _mm256_setzero_ps();
    size_t i = 0;
    // Main loop: process 8 floats per iteration.
    for (; i + 8 <= n; i += 8) {
        __m256 va = _mm256_loadu_ps(&a[i]);
        __m256 vb = _mm256_loadu_ps(&b[i]);
        __m256 vdiff = _mm256_sub_ps(va, vb);
        vdiff = _mm256_mul_ps(vdiff, vdiff);
        vsum = _mm256_add_ps(vsum, vdiff);
    }
    // Horizontal reduction: fold 8 lanes to 4, then to 1.
    __m128 bottom = _mm256_castps256_ps128(vsum);
    __m128 top = _mm256_extractf128_ps(vsum, 1);
    bottom = _mm_add_ps(bottom, top);
    bottom = _mm_hadd_ps(bottom, bottom);
    bottom = _mm_hadd_ps(bottom, bottom);
    float sum = _mm_cvtss_f32(bottom);
    // Scalar tail: the last n % 8 elements.
    for (; i < n; ++i) {
        float diff = a[i] - b[i];
        sum += diff * diff;
    }
    return sum / (float) n;
}
AVX2 code explanation:
_mm256_loadu_ps loads 8 floating-point values from an array.
_mm256_sub_ps computes differences between corresponding elements.
_mm256_mul_ps calculates the squared error for each pair of values.
_mm256_add_ps is used to accumulate results in the vsum register.
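To convert the SIMD result back to a scalar, we first split the 256-bit vsum into two 128-bit halves and add them together, which leaves four partial sums. Two calls to _mm_hadd_ps then add adjacent lanes until every lane holds the full total, and _mm_cvtss_f32 extracts the lowest lane as a single float.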
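As a rough aid for following the reduction, the sketch below models it in plain C++ with no intrinsics, so it runs on any machine. The Vec4 and Vec8 aliases and the helper names hadd and horizontal_sum are made up for this illustration; only the intrinsic names mentioned in the comments come from the code above.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-ins for __m128 and __m256 (illustration only).
using Vec4 = std::array<float, 4>;
using Vec8 = std::array<float, 8>;

// Models _mm_hadd_ps(a, b): sums adjacent pairs of a, then adjacent pairs of b.
Vec4 hadd(const Vec4 &a, const Vec4 &b) {
    return {a[0] + a[1], a[2] + a[3], b[0] + b[1], b[2] + b[3]};
}

// Models the reduction of an 8-lane register down to one float.
float horizontal_sum(const Vec8 &v) {
    Vec4 bottom = {v[0], v[1], v[2], v[3]};  // like _mm256_castps256_ps128
    Vec4 top = {v[4], v[5], v[6], v[7]};     // like _mm256_extractf128_ps(v, 1)
    for (std::size_t k = 0; k < 4; ++k) {
        bottom[k] += top[k];                 // like _mm_add_ps: 8 lanes -> 4
    }
    bottom = hadd(bottom, bottom);           // lanes 0 and 1 now hold pair sums
    bottom = hadd(bottom, bottom);           // every lane now holds the total
    return bottom[0];                        // like _mm_cvtss_f32
}
```

For a register holding the values 1 through 8, horizontal_sum returns 36: the halves add up to 6, 8, 10, 12, the first hadd gives 14, 22, 14, 22, and the second gives 36 in every lane.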
Remaining elements are handled with a scalar loop at the end.
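To make the split between the two loops concrete, here is a small sketch that computes how many elements each loop handles for a given n. The Split struct and the split_for_width function are hypothetical helpers for this illustration and are not part of the implementation above.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper type: how the work divides between the two loops.
struct Split {
    std::size_t vector_elems;  // handled by the 8-wide SIMD loop
    std::size_t tail_elems;    // handled by the scalar loop afterwards
};

// The SIMD loop runs while i + width <= n, so it covers the largest
// multiple of width that does not exceed n; the rest is the tail.
Split split_for_width(std::size_t n, std::size_t width = 8) {
    assert(width > 0);
    std::size_t tail = n % width;
    return {n - tail, tail};
}
```

For the 18-element arrays used earlier, split_for_width(18) gives 16 vector elements (two iterations of the SIMD loop) and 2 tail elements. For n below 8 the SIMD loop does not run at all and the scalar loop does all the work.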