The Root Mean Squared Error (RMSE) is a common metric used to measure the difference between two arrays, such as observed and predicted values. It involves summing the squared differences, dividing by the number of elements, and taking the square root of the result. By leveraging SIMD instructions, we can compute RMSE more efficiently by processing multiple elements simultaneously.
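Written as a formula, the description above reads as follows. The symbols are chosen here to match the code (`a` and `b` for the two arrays, `n` for their length, zero-based index `i`); the notation is an assumption of this sketch rather than a convention taken from a particular reference.

```latex
\mathrm{RMSE}(a, b) = \sqrt{\frac{1}{n} \sum_{i=0}^{n-1} \left(a_i - b_i\right)^2}
```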
Here's the basic scalar implementation:
#include <iostream>
#include <vector>
#include <cmath>

float rmse(const float *a, const float *b, const size_t n) {
    float sum = 0;
    for (size_t i = 0; i < n; ++i) {
        float diff = a[i] - b[i];
        sum += diff * diff;
    }
    return std::sqrt(sum / (float) n);
}

int main() {
    std::vector<float> a = {
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
        11, 12, 13, 14, 15, 16, 17,
    };
    std::vector<float> b = {
        0.5, 1, 2.5, 3, 4.5, 5, 6.5, 7, 8.5,
        9, 10.5, 11, 12.5, 13, 14.5, 15, 16.5, 17,
    };
    float value = rmse(a.data(), b.data(), a.size());
    std::cout << value;
    return 0;
}
This approach computes the squared differences one element at a time and accumulates them in a single running sum. Finally, it divides that sum by the number of elements and takes the square root. Output:
0.353553
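The printed value can be checked by hand. In the sample data, the two arrays have 18 elements and differ by exactly 0.5 at the nine even indices (0, 2, ..., 16); every other difference is zero. A short worked calculation under those observations:

```latex
\sum_{i=0}^{17} (a_i - b_i)^2 = 9 \times 0.5^2 = 2.25,
\qquad
\mathrm{RMSE} = \sqrt{\frac{2.25}{18}} = \sqrt{0.125} \approx 0.353553
```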
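Here's the optimized implementation using AVX2. It is a drop-in replacement for the scalar rmse function above and must be compiled with AVX2 enabled (for example, -mavx2 with GCC or Clang):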
#include <immintrin.h>
#include <cmath>

float rmse(const float *a, const float *b, const size_t n) {
    // Eight running partial sums, one per 32-bit lane.
    __m256 vsum = _mm256_setzero_ps();
    size_t i = 0;

    // Main loop: process 8 elements per iteration.
    for (; i + 8 <= n; i += 8) {
        __m256 va = _mm256_loadu_ps(&a[i]);
        __m256 vb = _mm256_loadu_ps(&b[i]);
        __m256 vdiff = _mm256_sub_ps(va, vb);
        vdiff = _mm256_mul_ps(vdiff, vdiff);
        vsum = _mm256_add_ps(vsum, vdiff);
    }

    // Horizontal reduction: 8 lanes -> 4 lanes -> 1 scalar.
    __m128 bottom = _mm256_castps256_ps128(vsum);
    __m128 top = _mm256_extractf128_ps(vsum, 1);
    bottom = _mm_add_ps(bottom, top);
    bottom = _mm_hadd_ps(bottom, bottom);
    bottom = _mm_hadd_ps(bottom, bottom);
    float sum = _mm_cvtss_f32(bottom);

    // Scalar tail: the remaining n % 8 elements.
    for (; i < n; ++i) {
        float diff = a[i] - b[i];
        sum += diff * diff;
    }
    return std::sqrt(sum / (float) n);
}
AVX2 code explanation:
_mm256_loadu_ps loads 8 floating-point values from each input array. It is the unaligned load, so the pointers need no particular alignment.
_mm256_sub_ps computes the differences between the corresponding elements.
_mm256_mul_ps multiplies the difference vector by itself, which squares each element to give the squared error.
_mm256_add_ps accumulates the squared differences into the vsum register.
After the loop, the eight partial sums in vsum are split into two 128-bit halves with _mm256_castps256_ps128 (lower half) and _mm256_extractf128_ps (upper half). The halves are added together with _mm_add_ps, and the resulting four values are horizontally summed with two _mm_hadd_ps calls to reduce them to a single scalar value, which _mm_cvtss_f32 extracts.
The remaining elements (fewer than 8) are handled with a scalar loop, and the square root of the mean of the squared differences is computed to produce the RMSE.
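To make the lane-wise accumulation and the horizontal reduction easier to follow, here is a minimal portable sketch that emulates the same steps with plain arrays instead of intrinsics. The function name `rmse_lanes` and the constant `kLanes` are hypothetical names introduced for this illustration; the code is a teaching aid that mirrors the order of operations, not a replacement for the AVX2 version.

```cpp
#include <array>
#include <cmath>
#include <cstddef>

// Emulates the AVX2 routine in portable C++ (hypothetical helper).
// An array of 8 floats stands in for the 8 lanes of a __m256 register.
float rmse_lanes(const float *a, const float *b, std::size_t n) {
    constexpr std::size_t kLanes = 8;
    std::array<float, kLanes> vsum{};  // plays the role of _mm256_setzero_ps()

    // Main loop: each lane accumulates its own partial sum.
    std::size_t i = 0;
    for (; i + kLanes <= n; i += kLanes) {
        for (std::size_t lane = 0; lane < kLanes; ++lane) {
            float diff = a[i + lane] - b[i + lane];
            vsum[lane] += diff * diff;
        }
    }

    // Like _mm_add_ps(bottom, top): fold the upper 4 lanes onto the lower 4.
    std::array<float, 4> q{};
    for (std::size_t lane = 0; lane < 4; ++lane) {
        q[lane] = vsum[lane] + vsum[lane + 4];
    }

    // Like the first _mm_hadd_ps: add adjacent pairs.
    float p0 = q[0] + q[1];
    float p1 = q[2] + q[3];

    // Like the second _mm_hadd_ps followed by _mm_cvtss_f32.
    float sum = p0 + p1;

    // Scalar tail: the remaining n % 8 elements.
    for (; i < n; ++i) {
        float diff = a[i] - b[i];
        sum += diff * diff;
    }
    return std::sqrt(sum / static_cast<float>(n));
}
```

Calling `rmse_lanes(a.data(), b.data(), a.size())` on the sample vectors from the scalar example should agree with the scalar result. In general, the vectorized version adds the squared differences in a different order than the scalar loop, so for arbitrary data the two results can differ slightly in the last bits because floating-point addition is not associative.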