Case study 3: online statistics

A cautionary tale.

In many computational modelling procedures you will need an updated estimate of statistics as the code executes. For example, you may need to track a mean or a standard deviation of a performance measure in a multi-stage algorithm or as a simulation model of a healthcare system executes.

As we have seen, numpy provides highly efficient functions for calculating a mean or standard deviation from data held in an array. I’m always tempted to make use of these built-in procedures. They are indeed fast and incredibly easy to use. The downside is that you waste computation by iterating over the whole array again each time a new observation arrives. The other option, which requires more careful thought (due to floating point issues), is to maintain a running estimate of your statistics. In general, I’ve implemented such procedures in standard python. Let’s look at an example where we compare recalculation using a numpy function with a running (sometimes called an ‘online’) calculation of the mean and standard deviation in standard python.

We will first refactor AttendanceSummary from Statistical procedures into an OnlineSummary class that includes an update() method. It accepts a np.ndarray and recalculates the sample mean and standard deviation using numpy on the full data set. The function test_complete_recalculation iteratively calls update, passing more data each time. For simplicity’s sake we will reuse the data contained within ed_data.

Wait a minute!

This chapter is about scientific coding in numpy, but this case study is demonstrating that standard python is more efficient! Well, not quite. The overall theme of this part of the book is that code is a first-class citizen in health data science. You should always think about the design of your code in any algorithms or models you implement. This case study is demonstrating that there may be instances where a numpy solution is not the most efficient.

Imports

import numpy as np

Data

file_name = 'data/minor_illness_ed_attends.csv'
ed_data = np.loadtxt(file_name, skiprows=1, delimiter=',')
print(ed_data.shape)

(74,)

numpy solution

class OnlineSummary:
    
    def __init__(self, data=None, decimal_places=2):
        """
        Track online statistics of mean and standard deviation.

        Params:
        -------
        data: np.ndarray, optional (default = None) 
            Contains an initial data sample.
            
        decimal_places: int, optional (default=2)
            Summary decimal places.
        """
        if isinstance(data, np.ndarray):
            self.n = len(data)
            self.mean = data.mean()
            self.std = data.std(ddof=1)
        else:
            self.n = 0
            self.mean = None
            self.std = None
            
        self.dp = decimal_places
        
    def update(self, data):
        '''
        Update the mean and standard deviation using complete recalculation.
        
        Params:
        ------
        data: np.ndarray
            Vector of data
        '''
        self.n = len(data)
        
        # update the mean and std. Easy!
        self.mean = data.mean()
        self.std = data.std(ddof=1)
        
    
    def __str__(self):
        # build the summary from adjacent f-strings (no dangling backslash)
        to_print = (f'Mean:\t{self.mean:.{self.dp}f}'
                    f'\nStdev:\t{self.std:.{self.dp}f}'
                    f'\nn:\t{self.n}')
        return to_print


def test_complete_recalculation(data, start=2):
    summary = OnlineSummary(data[:start])

    for i in range(start, len(data)+1):
        summary.update(data[:i])
    return summary

summary = test_complete_recalculation(ed_data)
print(summary)

Mean:	2.92
Stdev:	0.71
n:	74

%timeit summary = test_complete_recalculation(ed_data)

1.17 ms ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

You should find the numpy implementation fairly efficient, clocking in at around 1 ms on average (1.17 ms in the run above). But can we do better in standard python by computing an online mean and standard deviation?

Online sample mean and variance

To do this we will use Welford’s algorithm for computing a running sample mean and standard deviation. This is a robust, accurate and old(ish) approach (1960s) that I first read about in Donald Knuth’s The Art of Computer Programming, Vol. 2 (just to be clear, I learnt how to do this in 2008, not 1960!). To implement it we need to refactor update. Note that we will need a fair bit more code than our simple numpy solution.

The algorithm is given in a recursive format. For our purposes here, you can just think of that as tracking the mean and standard deviation as attributes of a class that we iteratively update with a new \(x\).

The first thing you need to do is handle the first observation encountered.

\[M_1 = x_1\]
\[S_1 = 0\]

Then on each subsequent call you update \(M\) and \(S\) making use of the previous values. Note that \(M\) has a relatively simple interpretation: it’s the sample mean. However, \(S\) is not the standard deviation. It’s actually the sum of squares of differences from the current mean. We will look at how to update that first and then I’ll show you the equation for converting to the standard deviation.

\[M_n = M_{n-1} + \dfrac{x_n - M_{n-1}}{n}\]
\[S_n = S_{n-1} + \left[(x_n - M_{n-1}) \times (x_n - M_n)\right]\]

If the equations are confusing you can think of \(M_n\) as the updated_mean and \(M_{n-1}\) as the previous_mean.
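
As a quick worked example (the three observations 2, 4 and 6 are made up for illustration and are not taken from ed_data), applying the updates in turn gives:

\[M_1 = 2, \quad S_1 = 0\]
\[M_2 = 2 + \dfrac{4 - 2}{2} = 3, \quad S_2 = 0 + \left[(4 - 2) \times (4 - 3)\right] = 2\]
\[M_3 = 3 + \dfrac{6 - 3}{3} = 4, \quad S_3 = 2 + \left[(6 - 3) \times (6 - 4)\right] = 8\]

As a check, the mean of 2, 4 and 6 is 4, and the squared differences from it sum to \(4 + 0 + 4 = 8\).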

Once the update is complete it is then relatively trivial to calculate the standard deviation \(\sigma_n\). Note that we don’t necessarily need to track the standard deviation just \(S_n\). We can inexpensively calculate \(\sigma_n\) when it is needed.

\[\sigma_n = \sqrt{\dfrac{S_n}{n-1}}\]
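
Before wrapping this in a class, it may help to see the three equations as one plain function. The sketch below is only illustrative: the name welford_mean_std and the sample values are assumptions made for this example. It starts from \(M_0 = 0\) and \(S_0 = 0\), which reproduces \(M_1 = x_1\) and \(S_1 = 0\) on the first pass, and it compares the result with the standard library's statistics.stdev, which also divides by \(n - 1\).

```python
import math
import statistics


def welford_mean_std(values):
    """Illustrative helper (hypothetical name): running mean and sample std."""
    n = 0
    mean = 0.0   # M_0
    s = 0.0      # S_0, running sum of squared differences
    for x in values:
        n += 1
        previous_mean = mean
        # M_n = M_{n-1} + (x_n - M_{n-1}) / n
        mean = previous_mean + (x - previous_mean) / n
        # S_n = S_{n-1} + (x_n - M_{n-1}) * (x_n - M_n)
        s += (x - previous_mean) * (x - mean)
    # sigma_n = sqrt(S_n / (n - 1)); assumes at least two observations
    return mean, math.sqrt(s / (n - 1))


sample = [2.0, 4.0, 6.0]   # made-up observations
mean, std = welford_mean_std(sample)
print(mean, std)
print(statistics.mean(sample), statistics.stdev(sample))
```

Both lines should print the same mean and standard deviation. Note that the function needs at least two observations, because the final step divides by \(n - 1\).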

The code listing below modifies OnlineSummary to make use of Welford’s algorithm. Note that std is now a property that calculates the standard deviation on the fly using \(S_n\).

class OnlineSummary:
    
    def __init__(self, data=None, decimal_places=2):
        """
        Track online statistics of mean and standard deviation.

        Params:
        -------
        data: np.ndarray, optional (default = None) 
            Contains an initial data sample.
            
        decimal_places: int, optional (default=2)
            Summary decimal places.
        """
        
        self.n = 0
        self.mean = None
        self._sq = None
        
        if isinstance(data, np.ndarray):
            for x in data:
                self.update(x)
            
        self.dp = decimal_places
    
    @property
    def variance(self):
        return self._sq / (self.n - 1)
    
    @property
    def std(self):
        return np.sqrt(self.variance)
    
    def update(self, x):
        '''
        Running update of mean and variance implemented using Welford's 
        algorithm (1962).
        
        See Knuth. D `The Art of Computer Programming` Vol 2. 2nd ed. Page 216.
        
        Params:
        ------
        x: float
            A new observation
        '''
        self.n += 1
        
        # we need to do more work ourselves for online stats!
        
        # init values
        if self.n == 1:
            self.mean = x
            self._sq = 0
        else:
            # compute the updated mean
            updated_mean = self.mean + ((x - self.mean) / self.n)
        
            # update the sum of squares 
            self._sq += (x - self.mean) * (x - updated_mean)
            
            # update the tracked mean
            self.mean = updated_mean
    
    def __str__(self):
        # build the summary from adjacent f-strings (no dangling backslash)
        to_print = (f'Mean:\t{self.mean:.{self.dp}f}'
                    f'\nStdev:\t{self.std:.{self.dp}f}'
                    f'\nn:\t{self.n}')
        return to_print


def test_online_calculation(data, start=1):
    summary = OnlineSummary()

    for observation in data:
        summary.update(observation)
    return summary

summary = test_online_calculation(ed_data)
print(summary)

Mean:	2.92
Stdev:	0.71
n:	74

%timeit summary = test_online_calculation(ed_data)

66.1 µs ± 586 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
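
Why go to the trouble of Welford’s algorithm rather than simply tracking a running sum of \(x\) and \(x^2\)? The sketch below is one way to see the floating point issue mentioned at the start of this case study. Both function names and the data are assumptions made for illustration: the four values have a true sample variance of 30 but sit on a large offset of \(10^9\), which is where the shortcut formula tends to lose precision.

```python
def naive_variance(values):
    """Shortcut formula using running sums of x and x squared."""
    n = 0
    total = 0.0
    total_sq = 0.0
    for x in values:
        n += 1
        total += x
        total_sq += x * x
    # subtracts two huge, nearly equal numbers: prone to cancellation
    return (total_sq - (total * total) / n) / (n - 1)


def welford_variance(values):
    """Welford's running update, returning the sample variance."""
    n = 0
    mean = 0.0
    s = 0.0
    for x in values:
        n += 1
        previous_mean = mean
        mean = previous_mean + (x - previous_mean) / n
        s += (x - previous_mean) * (x - mean)
    return s / (n - 1)


# made-up data: offsets 4, 7, 13, 16 have a sample variance of 30
shifted = [1e9 + 4.0, 1e9 + 7.0, 1e9 + 13.0, 1e9 + 16.0]

print('naive:  ', naive_variance(shifted))
print('welford:', welford_variance(shifted))
```

With double precision floats the Welford version should recover 30, while the naive version is typically a long way off and can even come out negative.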

Summing up

Crikey, nothing beats a good algorithm! You should find that you are now working in microseconds (µs) as opposed to milliseconds (ms); 1 ms = 1000 µs. In the run shown above test_online_calculation executes in ~66 µs on average while test_complete_recalculation takes ~1170 µs. So we are finding a reduction in run time of ~94% (roughly 18 times faster). That gap will continue to grow as the number of samples \(n\) increases. The result is explained because each update in our second implementation takes constant time (a constant number of computational steps), while the time taken by each numpy call depends on the size of the array it is passed. That’s a lesson well worth remembering when developing code for scientific applications requiring performant code.
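
To put rough numbers on that argument, the sketch below counts work rather than timing it. It is a simplified model under stated assumptions: the function names are hypothetical, one pass over data[:i] is counted as i element reads, and any fixed per-call overhead is ignored.

```python
def elements_read_by_recalculation(n, start=2):
    """Element reads for one statistic over update(data[:i]), i = start..n."""
    return sum(range(start, n + 1))


def updates_for_online(n):
    """The online version performs one constant-time update per observation."""
    return n


for n in (74, 740, 7400):
    print(n, elements_read_by_recalculation(n), updates_for_online(n))
```

For the 74 observations in ed_data this model gives 2774 element reads per statistic for complete recalculation against 74 online updates. The first count grows roughly with \(n^2\) and the second with \(n\), which is why the gap widens as more data arrive.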