Count Non-Zero Elements in Array using C++ SIMD

Counting non-zero elements in an array is a common task in data analysis, signal processing, and machine learning. For large datasets, performing this operation efficiently is crucial. With SIMD, a single instruction compares several array elements at once (eight floats per 256-bit register), which typically makes the count faster than the standard scalar approach.

Here's the basic implementation:

#include <iostream>
#include <vector>

size_t countNonZero(const float *data, const size_t n) {
    size_t count = 0;
    for (size_t i = 0; i < n; ++i) {
        if (data[i] != 0.0f) {
            ++count;
        }
    }

    return count;
}

int main() {
    std::vector<float> a = {
        -2.1, 0, 4.7, 9.8, -7.2, 0, 3.3, 0, 2.1,
        15, 0, 8.2, -8.3, 0, -4.2, 6.1, 9.9, 0,
    };

    auto value = countNonZero(a.data(), a.size());
    std::cout << value;

    return 0;
}

In the code, the function iterates over each element of the array and checks whether the value is non-zero. If it is, count is incremented by one. The example array has 18 elements, six of which are zero, so the code outputs 12.
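One detail worth spelling out is what "non-zero" means for floats. The comparison x != 0.0f follows IEEE 754 rules, so negative zero compares equal to zero and is not counted, while NaN compares unequal to everything and is counted. The sketch below shows this using std::count_if as an equivalent formulation of the loop; the names countNonZeroStd and countEdgeCases are assumptions made for this illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper for this illustration: same result as the scalar loop,
// expressed with std::count_if.
std::size_t countNonZeroStd(const std::vector<float> &v) {
    return static_cast<std::size_t>(
        std::count_if(v.begin(), v.end(), [](float x) { return x != 0.0f; }));
}

// Edge cases: 0.0f and -0.0f are not counted; NaN, infinity and 1.5f are.
std::size_t countEdgeCases() {
    const float nan = std::numeric_limits<float>::quiet_NaN();
    const float inf = std::numeric_limits<float>::infinity();
    return countNonZeroStd({0.0f, -0.0f, nan, inf, 1.5f});  // 3
}
```

Any SIMD version that is meant to be a drop-in replacement should agree with the scalar loop on these inputs as well.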

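Here's the optimized implementation using AVX2. Strictly speaking, the 256-bit float intrinsics it uses require only AVX, and _mm_popcnt_u32 requires POPCNT: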

#include <cstddef>
#include <immintrin.h>

size_t countNonZero(const float *data, const size_t n) {
    size_t count = 0;
    __m256 zero = _mm256_setzero_ps();

    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m256 vdata = _mm256_loadu_ps(&data[i]);
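        // Unordered predicate so NaN counts as non-zero, like the scalar !=.
        __m256 cmp = _mm256_cmp_ps(vdata, zero, _CMP_NEQ_UQ);
        // movemask returns an int with one bit per lane in its low 8 bits.
        unsigned int mask = static_cast<unsigned int>(_mm256_movemask_ps(cmp));
        count += static_cast<size_t>(_mm_popcnt_u32(mask));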
    }

    for (; i < n; ++i) {
        if (data[i] != 0.0f) {
            ++count;
        }
    }

    return count;
}

Here's how the AVX2 version works:

  • _mm256_setzero_ps creates a vector of 8 floating-point zeros to use in the comparison.
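  • _mm256_loadu_ps loads 8 consecutive elements from the array. The "u" stands for unaligned, so the pointer does not need 32-byte alignment.
  • _mm256_cmp_ps compares each element with zero and produces a vector in which a lane is all ones if the comparison holds for that element and all zeros otherwise.
  • _mm256_movemask_ps collects the most significant bit of each of the 8 lanes into the low 8 bits of an integer, so each bit records whether the corresponding element was non-zero.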
  • _mm_popcnt_u32 counts the number of set bits in the mask, corresponding to the number of non-zero elements in this set of eight.
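To make the mask-and-popcount step concrete, here is a plain C++ model of what one loop iteration computes for a block of eight floats. It is not the intrinsic code itself: blockMask and countBlock are names invented for this sketch, and std::bitset stands in for the popcount instruction.

```cpp
#include <bitset>
#include <cstddef>

// Scalar model of compare + movemask: bit k is set when block[k] != 0.
// (Hypothetical helper; the real code gets this from _mm256_movemask_ps.)
unsigned blockMask(const float *block) {
    unsigned mask = 0;
    for (unsigned k = 0; k < 8; ++k) {
        if (block[k] != 0.0f) {
            mask |= 1u << k;
        }
    }
    return mask;
}

// Scalar model of the popcount step: number of set bits in the 8-bit mask.
std::size_t countBlock(const float *block) {
    return std::bitset<8>(blockMask(block)).count();
}
```

For the first eight elements of the example array (-2.1, 0, 4.7, 9.8, -7.2, 0, 3.3, 0), bits 0, 2, 3, 4 and 6 are set, so the mask is 0b01011101 (93) and the block contributes 5 to the count.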

If n is not a multiple of 8, the remaining elements (at most seven) are handled by the scalar fallback loop at the end.
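As a small worked example of how the work is divided, the sketch below computes how many elements the 8-wide loop covers and how many fall through to the fallback loop. The struct and function names are assumptions used for illustration only.

```cpp
#include <cstddef>

// Hypothetical helper: how an array of n floats is divided between the
// 8-wide vector loop and the scalar fallback loop.
struct WorkSplit {
    std::size_t vectorElements;  // handled 8 at a time
    std::size_t tailElements;    // handled one at a time (0 to 7)
};

WorkSplit splitWork(std::size_t n) {
    const std::size_t tail = n % 8;
    return {n - tail, tail};
}
```

For the 18-element example array, splitWork(18) reports 16 elements for the vector loop (two iterations) and 2 elements for the fallback loop.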
