
Count The Number Of Non Zero Values In A Numpy Array In Numba

Very simple: I am trying to count the number of non-zero values in a NumPy array inside a function JIT-compiled with Numba (njit()). What I tried, a[a != 0], is not allowed by Numba.

Solution 1:

You may also consider, well, counting the nonzero values:

import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

I know it seems wrong, but bear with me:

import numpy as np
import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

@nb.njit()
def count_len_nonzero(a):
    return len(np.nonzero(a)[0])

@nb.njit()
def count_sum_neq_zero(a):
    return (a != 0).sum()

np.random.seed(100)
a = np.random.randint(0, 3, 1000000000, dtype=np.uint8)
c = np.count_nonzero(a)
assert count_len_nonzero(a) == c
assert count_sum_neq_zero(a) == c
assert count_loop(a) == c

%timeit count_len_nonzero(a)
# 5.94 s ± 141 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_sum_neq_zero(a)
# 848 ms ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_loop(a)
# 189 ms ± 4.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

It is in fact faster than np.count_nonzero, which can get quite slow for some reason:

%timeit np.count_nonzero(a)
# 4.36 s ± 69.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
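One caveat about count_loop: `for i in a` walks the first axis, so the function as written is meant for 1-D input. As far as I can tell, a 2-D array would make `i` a whole row, and the `if` would then no longer compile in nopython mode. Below is a minimal sketch of an N-dimensional variant, assuming it is acceptable to flatten with ravel() first. The name count_loop_nd and the import fallback are my own additions, not part of the answer above.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba installed, run the same code as plain Python.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit()
def count_loop_nd(a):
    # ravel() yields a 1-D array (a view for contiguous input),
    # so the loop always sees scalars regardless of a's shape.
    s = 0
    for x in a.ravel():
        if x != 0:
            s += 1
    return s


grid = np.array([[0, 1, 2], [3, 0, 0]], dtype=np.uint8)
print(count_loop_nd(grid), np.count_nonzero(grid))  # prints: 3 3
```

For 1-D input this does the same work as count_loop, so the timings above should carry over, though I have not measured that.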

Solution 2:

If you need it really fast for large arrays, you can even use Numba's prange to do the count in parallel (for small arrays it will be slower because of the parallel-processing overhead).

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_

Note that when you use Numba you normally want to write out your loops, because that is what Numba is very good at optimizing.
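A quick sanity check of the parallel version on a multi-dimensional array might look like the following. This is only a sketch: the test array, its shape, the seed, and the fallback for a missing Numba install are assumptions of mine.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumption: without Numba installed, fall back to a serial pure-Python run.
    prange = range

    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    # sum_ is only ever updated with +=, the pattern prange needs for a reduction.
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_


rng = np.random.default_rng(0)
arr = rng.random((200, 300))
arr[arr < 0.3] = 0.0  # zero out roughly 30% of the entries

assert parallel_nonzero_count(arr) == np.count_nonzero(arr)
```

Because of the ravel() call, the function accepts arrays of any shape, not just 1-D ones.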

I actually timed it against the other solutions mentioned here (using my Python module simple_benchmark):

[Image not preserved: benchmark plot of run time against array size for the functions below]

Code to reproduce:

import numpy as np
from numba import njit, prange

@njit
def n_nonzero(a):
    return a[a != 0].size

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

@njit()
def methodB(a):
    return (a != 0).sum()

@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_

@njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

from simple_benchmark import benchmark

args = {}
for exp in range(2, 20):
    size = 2**exp
    arr = np.random.random(size)
    arr[arr < 0.3] = 0.0
    args[size] = arr

b = benchmark(
    funcs=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop),
    arguments=args,
    argument_name='array size',
    warmups=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop)
)

Solution 3:

You can use np.nonzero and take the length of its result:

import numpy as np
from numba import njit

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

count_non_zero(np.array([0,1,0,1]))
# 2
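To see why the `[0]` is there: np.nonzero returns a tuple with one index array per dimension, and each array in that tuple has one entry per non-zero element. A small plain-NumPy illustration (the sample array is made up):

```python
import numpy as np

m = np.array([[0, 5, 0],
              [7, 0, 9]])

idx = np.nonzero(m)      # tuple: (row indices, column indices)
print(len(idx))          # 2, one index array per dimension
print(idx[0], idx[1])    # [0 1 1] [1 0 2]
print(len(idx[0]))       # 3, the number of non-zero values
```

Building those index arrays just to take their length is extra work, which probably goes some way toward explaining why this variant is the slowest in the timings reported in the other answers.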

Solution 4:

Not sure if I have made a mistake here, but this seems 6x faster:

# Make something worth checking
import numpy as np
from numba import njit
a = np.random.randint(0, 3, 1000000000, dtype=np.uint8)

In [41]: @njit()
    ...: def methodA(a):
    ...:     return len(np.nonzero(a)[0])

# Call and check result
In [42]: methodA(a)
Out[42]: 666644445

In [43]: %timeit methodA(a)
4.65 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [44]: @njit()
    ...: def methodB(a):
    ...:     return (a != 0).sum()

# Call and check result
In [45]: methodB(a)
Out[45]: 666644445

In [46]: %timeit methodB(a)
724 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
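Outside IPython, where `%timeit` is not available, a similar measurement can be sketched with the standard-library timeit module. The helper name best_time, the repeat count, and the much smaller array are assumptions for illustration. The untimed first call matters for njit functions because the first call triggers compilation.

```python
import timeit

import numpy as np


def best_time(func, arg, repeat=5):
    """Return the best wall-clock time in seconds over `repeat` single calls."""
    func(arg)  # untimed warm-up; for an njit function this triggers compilation
    return min(timeit.repeat(lambda: func(arg), number=1, repeat=repeat))


def sum_neq_zero(a):
    # Plain-NumPy stand-in for methodB; decorate it with njit to time the Numba version.
    return (a != 0).sum()


a = np.random.randint(0, 3, 1_000_000, dtype=np.uint8)
for f in (np.count_nonzero, sum_neq_zero):
    print(f.__name__, best_time(f, a))
```

Taking the minimum over several runs is one common convention; `%timeit` reports mean and standard deviation instead, so the numbers are not directly comparable.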
