import numpy as np
import math

Case study 2: prime sieve

This chapter opened by exploring the importance of good algorithm and code design. We spent a fair bit of time redesigning and micro-optimising a function in standard python that implemented a prime sieve. For large n, for example greater than 10 million, the function prime_sieve_best was our fastest option.

def prime_sieve_best(n):
    '''
    Our fastest prime sieve in standard python
    Fastest for large n e.g. > 10m.
    '''
    candidates = bytearray(b"\x01") * (n + 1)
    candidates[0] = 0
    candidates[1] = 0
    limit = int(math.sqrt(n)) + 1    
    
    for i in range(2, limit): 
        if candidates[i]:
            candidates[i+i::i] = [0] * ((n - i) // i)
            
    return [i for i in range(n+1) if candidates[i]]      

The function prime_sieve_np again reimplements the algorithm, but this time using numpy optimised arrays and functions.

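def prime_sieve_np(n):
    '''
    Prime sieve reimplemented in NumPy.
    Returns an array of all primes <= n.
    '''
    # n + 1 flags so that index n itself is tested, matching prime_sieve_best
    candidates = np.ones(n + 1, dtype=bool)
    limit = int(np.sqrt(n)) + 1
    candidates[0:2] = False  # 0 and 1 are not prime

    for i in range(2, limit):
        if candidates[i]:
            # a single scalar clears every multiple of i from 2i upwards
            candidates[i+i::i] = False
    return np.flatnonzero(candidates)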
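
Before timing anything, it is worth checking that a NumPy sieve returns the same primes as a slow but obviously correct method. The sketch below is a sanity check added for illustration, not code from the chapter: `is_prime_trial` and `sieve_np_sketch` are hypothetical names, and the sieve is a compact copy of the NumPy approach written with n + 1 flags so that n itself is tested.

```python
import math
import numpy as np


def is_prime_trial(k):
    """Slow reference check: trial division up to sqrt(k)."""
    if k < 2:
        return False
    return all(k % d for d in range(2, math.isqrt(k) + 1))


def sieve_np_sketch(n):
    """Hypothetical compact NumPy sieve returning all primes <= n."""
    candidates = np.ones(n + 1, dtype=bool)
    candidates[0:2] = False  # 0 and 1 are not prime
    for i in range(2, math.isqrt(n) + 1):
        if candidates[i]:
            candidates[i + i::i] = False
    return np.flatnonzero(candidates)


n = 1_000
expected = [k for k in range(n + 1) if is_prime_trial(k)]
result = sieve_np_sketch(n).tolist()

# There are 168 primes up to 1000
print(result == expected, len(result))  # True 168
```

A small n keeps the trial-division reference fast. The same comparison could be pointed at any of the sieve functions in this chapter.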

You should see a reasonable speed up, for free, using numpy. Let’s compare it for an even larger n.

HUNDRED_MILLION = 100_000_000
%timeit len(prime_sieve_best(HUNDRED_MILLION))
11.4 s ± 601 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit len(prime_sieve_np(HUNDRED_MILLION))
1.21 s ± 42.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

That should provide around a factor of 10 speed up. On my machine runtime dropped from around 11.4 seconds on average to around 1.2 seconds on average.

This is also a nice example where, in my opinion, the numpy code is more readable than the standard python. This is partly because numpy broadcasting means we can set the elements in a slice cleanly, i.e.

candidates[i+i::i] = False

versus standard python

candidates[i+i::i] = [0] * ((n - i) // i)
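
To make the difference concrete, the small sketch below (an illustration, not code from the chapter) clears the multiples of 3 in both containers. The names `flags_np` and `flags_py` are made up for this example. The numpy array accepts a single scalar and broadcasts it across the slice, whereas the bytearray needs a right-hand side of exactly the right length.

```python
import numpy as np

n, i = 20, 3

# NumPy: the scalar False is broadcast to every element of the slice
flags_np = np.ones(n + 1, dtype=bool)
flags_np[i + i::i] = False

# Standard Python: the replacement must match the slice length exactly
flags_py = bytearray(b"\x01") * (n + 1)
flags_py[i + i::i] = [0] * ((n - i) // i)

# Both approaches clear the same positions
cleared_np = np.flatnonzero(~flags_np).tolist()
cleared_py = [k for k in range(n + 1) if not flags_py[k]]
print(cleared_np == cleared_py, cleared_np)  # True [6, 9, 12, 15, 18]

# A wrong length is rejected for a stepped bytearray slice
length_mismatch_rejected = False
try:
    flags_py[i + i::i] = [0] * 2
except ValueError:
    length_mismatch_rejected = True
print(length_mismatch_rejected)  # True
```

The length calculation `(n - i) // i` is the part that is easy to get wrong by one. The numpy version avoids it entirely.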