When working with large datasets, summing an array one element at a time can be slow. SIMD (single instruction, multiple data) instructions process several elements at once. A 256-bit AVX register holds eight 32-bit floats, so a single vector addition performs eight additions in parallel.
The straightforward implementation:
#include <iostream>
#include <vector>
float calculateSum(const float *data, const size_t n) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum += data[i];
    }
    return sum;
}
int main() {
    std::vector<float> a = {
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
    };
    auto value = calculateSum(a.data(), a.size());
    std::cout << value;
    return 0;
}
The code iterates through the array one element at a time, accumulating the running total in sum. For the 18 values above, the output is 171.
Code optimization with AVX2:
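#include <immintrin.h>
#include <cstddef>

// Build with AVX enabled (e.g. -mavx2 on GCC/Clang) and run on an AVX-capable CPU.
float calculateSum(const float *data, const size_t n) {
    __m256 vsum = _mm256_setzero_ps();  // 8 lane-wise partial sums, all zero
    size_t i = 0;
    // Main loop: add 8 floats per iteration into the 8 lanes of vsum.
    for (; i + 8 <= n; i += 8) {
        __m256 vdata = _mm256_loadu_ps(&data[i]);  // unaligned load of data[i..i+7]
        vsum = _mm256_add_ps(vsum, vdata);
    }
    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    __m128 bottom = _mm256_castps256_ps128(vsum);  // lanes 0..3
    __m128 top = _mm256_extractf128_ps(vsum, 1);   // lanes 4..7
    bottom = _mm_add_ps(bottom, top);              // 4 partial sums
    bottom = _mm_hadd_ps(bottom, bottom);          // 2 partial sums
    bottom = _mm_hadd_ps(bottom, bottom);          // total in every lane
    float sum = _mm_cvtss_f32(bottom);             // extract lane 0
    // Scalar tail for the remaining n % 8 elements.
    for (; i < n; ++i) {
        sum += data[i];
    }
    return sum;
}
There are three main parts: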
- Vectorized Sum Accumulation
_mm256_loadu_ps loads 8 consecutive floats starting at data[i]. The "u" means the address does not need to be 32-byte aligned. Each loaded chunk is added to vsum with _mm256_add_ps, so each of the 8 lanes keeps its own running partial sum.
- Horizontal Summation
After the loop, each of the 8 lanes of vsum holds a partial sum. To reduce them to a single scalar, we first split the 256-bit vector into two 128-bit halves: bottom holds lanes 0 to 3 and top holds lanes 4 to 7. Adding them with _mm_add_ps leaves 4 partial sums. Two calls to _mm_hadd_ps, which adds adjacent pairs of elements, then collapse these to 2 and finally to 1. _mm_cvtss_f32 extracts the result from the lowest lane.
- Fallback for Remaining Elements
If the array length isn't a multiple of 8, any leftover elements are summed in a final loop to ensure no values are missed.