Calculate Sum of Array Elements using C++ SIMD

When working with large datasets, summing the elements of an array one at a time can be slow. SIMD instructions operate on several elements per instruction, so the same sum can be computed faster by exploiting data-level parallelism.

The straightforward implementation:

#include <iostream>
#include <vector>

float calculateSum(const float *data, const size_t n) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum += data[i];
    }

    return sum;
}

int main() {
    std::vector<float> a = {
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
    };

    auto value = calculateSum(a.data(), a.size());
    std::cout << value;

    return 0;
}

The code iterates through each element of the array individually, accumulating the total sum in a variable. The code output is 171.
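For comparison, the same scalar loop can be written with the standard library. This is a minimal sketch, not part of the original code; the name calculateSumStd is a hypothetical one chosen for illustration.

```cpp
#include <cstddef>
#include <numeric>

// Hypothetical alternative to the hand-written loop (name is illustrative).
float calculateSumStd(const float *data, const std::size_t n) {
    // The initial value 0.0f fixes the accumulator type to float.
    // Passing the integer literal 0 would accumulate in int instead
    // and truncate every partial sum.
    return std::accumulate(data, data + n, 0.0f);
}
```

Like the explicit loop, std::accumulate adds the elements strictly in order, so it returns the same 171 for the array above.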

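Code optimization with AVX2 (the intrinsics used below need only AVX, which every AVX2-capable CPU supports; with GCC or Clang, compile with -mavx or -mavx2):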

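#include <cstddef>
#include <immintrin.h>

float calculateSum(const float *data, const size_t n) {
    // Eight independent partial sums, one per 32-bit lane.
    __m256 vsum = _mm256_setzero_ps();

    // Main loop: process 8 floats per iteration (unaligned load).
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m256 vdata = _mm256_loadu_ps(&data[i]);
        vsum = _mm256_add_ps(vsum, vdata);
    }

    // Split the 256-bit register into its low and high 128-bit halves.
    __m128 bottom = _mm256_castps256_ps128(vsum);
    __m128 top = _mm256_extractf128_ps(vsum, 1);

    // 8 partial sums -> 4 -> 2 -> 1.
    bottom = _mm_add_ps(bottom, top);
    bottom = _mm_hadd_ps(bottom, bottom);
    bottom = _mm_hadd_ps(bottom, bottom);

    // Extract the lowest lane, then add the leftover elements (n % 8).
    float sum = _mm_cvtss_f32(bottom);
    for (; i < n; ++i) {
        sum += data[i];
    }

    return sum;
}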

There are three main parts:

  • Vectorized Sum Accumulation

The _mm256_loadu_ps function loads 8 consecutive floats from the array and does not require the address to be aligned. Each loaded chunk is added to vsum using _mm256_add_ps, so vsum keeps 8 independent running sums, one per lane.

  • Horizontal Summation

After the loop, vsum holds a partial sum in each of its 8 lanes. To reduce them to a single scalar, we first split the 256-bit vector into two 128-bit registers (bottom and top) and add them with _mm_add_ps, which leaves 4 partial sums. Each _mm_hadd_ps call then adds adjacent pairs of lanes, so after two calls every lane holds the full total, and _mm_cvtss_f32 extracts the lowest lane as a float.

  • Fallback for Remaining Elements

If the array length isn't a multiple of 8, any leftover elements are summed in a final loop to ensure no values are missed.
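To make the three parts easier to trace by hand, here is a plain scalar model of the AVX routine. It is a sketch for illustration only, not a replacement for the intrinsics: the names Lanes8, Lanes4, haddModel and calculateSumModel are hypothetical, and each step mimics the intrinsic named in its comment.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-ins for __m256 and __m128 (illustration only).
using Lanes8 = std::array<float, 8>;
using Lanes4 = std::array<float, 4>;

// Models _mm_hadd_ps(a, b): adjacent pairs of a, then adjacent pairs of b.
Lanes4 haddModel(const Lanes4 &a, const Lanes4 &b) {
    return {a[0] + a[1], a[2] + a[3], b[0] + b[1], b[2] + b[3]};
}

float calculateSumModel(const float *data, const std::size_t n) {
    Lanes8 vsum{};  // models _mm256_setzero_ps

    // Part 1: vectorized accumulation, 8 lanes per iteration.
    std::size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        for (std::size_t lane = 0; lane < 8; ++lane) {
            vsum[lane] += data[i + lane];  // models _mm256_add_ps
        }
    }

    // Part 2: horizontal summation.
    // Low half plus high half, lane by lane (models _mm_add_ps).
    Lanes4 bottom{};
    for (std::size_t lane = 0; lane < 4; ++lane) {
        bottom[lane] = vsum[lane] + vsum[lane + 4];
    }
    bottom = haddModel(bottom, bottom);  // 4 partial sums -> 2
    bottom = haddModel(bottom, bottom);  // 2 partial sums -> 1

    // Part 3: scalar fallback for the remaining n % 8 elements.
    float sum = bottom[0];  // models _mm_cvtss_f32
    for (; i < n; ++i) {
        sum += data[i];
    }

    return sum;
}
```

For the 18-element array from the first example, the lanes hold [10, 12, 14, 16, 18, 20, 22, 24] after the loop, the halves combine to [28, 32, 36, 40], the two horizontal adds give 136 in every lane, and the fallback loop adds 17 and 18 for a total of 171. Because the lanes add elements in a different order than the scalar loop, results on general floating-point data can differ from the scalar version in the last few bits.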
