15 Jan 2018

A primer for online learning

I first encountered the concept of online learning several years ago, and it has been on my to-do list to investigate in more detail since.

In a nutshell, online models give you the ability to update (and hopefully improve) your model by setting up a feedback loop in production. For example, users may log in to your website, and you make predictions about what they may like to buy based on their demographics. You serve these recommendations, and then record whether they actually purchased what you recommended. You then use this feedback to update your model.

This differs from what I will refer to as the batch approach, where you train a model on some initial data, and then place it into an environment where it produces predictions. The model cannot update itself through any feedback loop, although by monitoring its performance you could decide to retrain it based on some criteria. This batch approach is the standard supervised learning approach.

I will be using the SGDClassifier class in scikit-learn throughout this post. The SGDClassifier fits linear models using stochastic gradient descent, and is equivalent to a logistic regression model if you specify a log loss function. You would not generally expect a stochastic model to converge as well as a regular logistic regression model, however it has advantages: it can be used as an online model (what we are interested in), and it can also be better suited to 'big data' (we won't discuss this in this post). For a great, bite-sized introduction to SGD and online learning, I can recommend Andrew Ng's Machine Learning course on Coursera.
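To make the idea of a stochastic update concrete, here is a minimal pure-Python sketch of a single SGD step on the log loss for a binary problem. It is only an illustration of the principle, not how scikit-learn implements it; the function names, the learning rate and the toy stream are all assumptions made up for this example.

```python
import math


def sigmoid(z):
    # Numerically safe logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sgd_step(weights, bias, x, y, lr=0.1):
    """One SGD update on the log loss for a single example, with y in {0, 1}."""
    z = sum(w * xi for w, xi in zip(weights, x)) + bias
    error = sigmoid(z) - y  # gradient of the log loss with respect to z
    new_weights = [w - lr * error * xi for w, xi in zip(weights, x)]
    new_bias = bias - lr * error
    return new_weights, new_bias


def predict(weights, bias, x):
    z = sum(w * xi for w, xi in zip(weights, x)) + bias
    return int(sigmoid(z) >= 0.5)


# Hypothetical toy stream of (features, label) pairs arriving one at a time.
stream = [([1.0, 2.0], 1), ([-1.0, -1.5], 0), ([2.0, 1.0], 1), ([-2.0, -0.5], 0)]

weights, bias = [0.0, 0.0], 0.0
for _ in range(50):
    for x, y in stream:
        weights, bias = sgd_step(weights, bias, x, y)
```

Each observation nudges the parameters a little, which is exactly what makes the approach usable as an online model: there is no need to revisit old data to incorporate a new example.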

# three scenarios

I will be investigating three scenarios in this post, where I will compare a batch trained logistic regression model to an online model. These are fairly simple toy examples, but will hopefully demonstrate some scenarios where an online model may give you advantages over a batch trained model.

# scenario 1: not enough data to train initial model

We will look at the handwritten digits dataset which comes with scikit-learn (a small, low-resolution set of 8x8 images in the style of MNIST, rather than MNIST itself). Other than scaling the data, I will just apply models to the data as is. This scaling is very important for SGD models, as they can converge very poorly if the features are on different scales.

We will set up the example in the following way:

Take an initial batch of 90 observations to train initial models, using the LogisticRegression class for the batch model, and the SGDClassifier with a log loss function for the online model. We will use cross validation to select reasonable hyperparameter values for both models.

Divide the remaining data into batches. We will take each batch one by one, and calculate the accuracy of the two models for comparison. The SGDClassifier model will then be updated using the partial_fit method, whereas the LogisticRegression model will not be updated. We will then compare the accuracy of the two models over all batches to gain intuition as to whether there was an advantage in using an online model.

If you are interested, you can have a look at a notebook containing this example.

As a quick aside, SGD models are usually quite sensitive to a poor choice of hyperparameters. You would not generally expect a stochastic model to converge as well as its non-stochastic counterpart, so poor choices of hyperparameter will compound this. As mentioned, a good rule of thumb is to always scale the data before fitting the model.

It is a common strategy when using SGD models in batch scenarios to decrease the learning rate as you iterate to improve convergence. However, in an online scenario where the distributions may not be entirely stationary, or if we don't have as much data as we would have liked for the initial model training, this may not be desirable as we keep calling the partial_fit method. The method does allow you to pass sample weights, which is good as it provides flexibility, but bad as it is another potential tuning parameter. The point I am trying to get across is that online models could require a non-trivial amount of tuning to perform optimally.

We see that the online and batch models start with similar performance, and the performance of the online model seems to increase for new batches of data. Interestingly, the performance of the batch model seems to degrade somewhat.

As the digits dataset has 10 outcome classes and we train both models on an initial sample of approximately 90 examples (roughly nine per class), we could hypothesise that we did not have enough data initially to build an optimal model. We see the online model improve over time as we are adding more information to the model, and improving the decision function.

Of course, we have a fairly limited amount of data: we only have 20 batches to track performance over. It would be interesting to see if the performance of the online model stabilises over time. In that case we would want to reduce the learning rate or reduce the sample weights of the new data to try and get a stable model. Also, I would be surprised if the accuracy of the batch model kept decreasing as we obtained more batches: this apparent downward trend could just be an artifact of having limited data.

Nonetheless, an interesting scenario. This example would suggest that where the initially available data set is relatively small, an online model could be deployed and would then be expected to improve over time as more data is gathered.

# scenario 2: distributions that change over time

A very interesting scenario (that is sometimes referred to as concept drift) occurs when the underlying distributions of data change over time. For instance, you wouldn't expect a model trained to recommend a summer range of clothing to necessarily be appropriate for an end of winter season sale. That is, the data you are using your model to score is no longer representative of the data on which you trained the initial model.

For batch trained models, a typical approach is to monitor a performance metric, and if it is seen to drift, a new model is trained. This can be undesirable if a modeller is required to manually retrain the algorithm at any given time.
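A minimal sketch of that kind of monitoring might track accuracy over a rolling window of recent predictions and raise a flag when it drops below a threshold. The class name, window size and threshold here are hypothetical; a real system would pick its metric and trigger criteria to suit the problem.

```python
from collections import deque


class AccuracyMonitor:
    """Rolling accuracy over the most recent predictions (illustrative only)."""

    def __init__(self, window=100, threshold=0.7):
        self.outcomes = deque(maxlen=window)
        self.threshold = threshold

    def record(self, prediction, actual):
        # Store whether the prediction was correct; old outcomes fall off.
        self.outcomes.append(prediction == actual)

    @property
    def accuracy(self):
        if not self.outcomes:
            return None
        return sum(self.outcomes) / len(self.outcomes)

    def needs_retraining(self):
        # Only raise the flag once the window is full, to avoid noisy triggers.
        window_full = len(self.outcomes) == self.outcomes.maxlen
        return window_full and self.accuracy < self.threshold


monitor = AccuracyMonitor(window=4, threshold=0.75)
for prediction, actual in [(1, 1), (0, 0), (1, 1), (0, 0)]:
    monitor.record(prediction, actual)
flag_before = monitor.needs_retraining()

# Two wrong predictions push the rolling accuracy down to 0.5.
monitor.record(1, 0)
monitor.record(0, 1)
flag_after = monitor.needs_retraining()
```

Even with a flag like this, someone (or something) still has to retrain and redeploy the model, which is the manual step an online model aims to avoid.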

A second example notebook can be found here, where I generate some data from two distributions with slowly oscillating means. If we take selections of consecutive data points, we can visualise the behaviour:

Clearly, we expect a batch trained model to perform very well if the data is in a configuration similar to that in which it was trained. The accuracy will be very poor otherwise. We want to investigate if the online model can 'keep up' with the moving distribution.

The results were actually quite surprising for this example: if we compare the moving averages of the accuracy for the two models, the online model significantly outperforms the batch trained model. I say surprising, as a little bit of reading around suggests that scenarios like this can be incredibly difficult to model. Of course the scenario is fairly simple here, but nonetheless the results are promising:

We see the accuracy of both models falls initially, but the accuracy of the online model soon recovers to become reasonably stable, giving a value of about 0.8. The accuracy of the batch trained model fluctuates with the moving centroids of the distributions: it performs very well when the configuration is similar to that when it was trained, and very poorly otherwise.

# scenario 3: stable distributions with enough data

The final scenario I will investigate is one where an online model may not be appropriate. As mentioned, you would not expect the convergence of an SGD model to be as good as that of the equivalent non-stochastic version. Therefore, for a scenario where there is enough data to build a good initial batch model, and the class distributions do not change over time, there would be little advantage to using an online model over a batch trained model. In fact, you might be concerned about the consistency of the performance of the online model, as the fitted model parameters can change at every update.

We can generate a simple stationary distribution from two classes, and fit initial models:

We will draw additional points from these distributions to compare the performance over a number of batches. This notebook contains this example.

If we plot the accuracy over a number of iterations, we see the results of the two approaches are comparable:

In this instance, I would favour the batch trained approach for its simplicity. However, if you were to use a model like this to score real customers, you would want to constantly monitor it in case patterns and behaviours start to change over time.

# other thoughts

Modelling these scenarios has certainly been educational for me, and I hope they were interesting to you.

I plan to spend some time investigating the implementation side of online models in the near future. They certainly have an extra level of complexity, but in some scenarios this complexity has the potential to pay off significantly compared with deploying models trained on a single batch of data. A 'real world' implementation could involve a number of models, both batch and online, using a bandit algorithm to mediate which should be served up to users.
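To give a flavour of what such mediation could look like, here is a minimal epsilon-greedy sketch: most of the time it serves the model with the best average reward so far, and occasionally it explores another one. Epsilon-greedy is just one simple bandit strategy picked for illustration; the class, the model names and the reward definition (for example 1.0 for a purchase, 0.0 otherwise) are all hypothetical.

```python
import random


class EpsilonGreedySelector:
    """Choose which model to serve using an epsilon-greedy bandit (sketch)."""

    def __init__(self, model_names, epsilon=0.1, seed=None):
        self.epsilon = epsilon
        self.counts = {name: 0 for name in model_names}
        self.rewards = {name: 0.0 for name in model_names}
        self._rng = random.Random(seed)

    def mean_reward(self, name):
        n = self.counts[name]
        return self.rewards[name] / n if n else 0.0

    def choose(self):
        # Explore with probability epsilon, otherwise exploit the best model.
        if self._rng.random() < self.epsilon:
            return self._rng.choice(list(self.counts))
        return max(self.counts, key=self.mean_reward)

    def update(self, name, reward):
        # Record the observed feedback for the model that was served.
        self.counts[name] += 1
        self.rewards[name] += reward


selector = EpsilonGreedySelector(["batch", "online"], epsilon=0.0)
selector.update("online", 1.0)  # e.g. the recommendation was purchased
selector.update("batch", 0.0)   # e.g. the recommendation was ignored
chosen = selector.choose()
```

With `epsilon=0.0` the selector always exploits, which makes the usage example deterministic; in practice a small positive value keeps gathering feedback on every model.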

As a final point, we have to remember that not all learning algorithms are suited to online learning. For example, a random forest would have to be entirely retrained for new data: it wouldn't be possible to update the decision function online if additional data were gathered. The SGDClassifier class in scikit-learn is limited to the handful of linear models that it can implement, but in my experience linear models are often the ones that are successfully implemented into production.
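In scikit-learn, one quick way to see which estimators support incremental updates is to check whether they expose a `partial_fit` method. The small check below is only a sketch covering the three estimators discussed in this post.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier

estimators = [SGDClassifier, LogisticRegression, RandomForestClassifier]

# Estimators that can be updated batch by batch define partial_fit.
supports_online = {est.__name__: hasattr(est, "partial_fit") for est in estimators}

for name, supported in supports_online.items():
    print(f"{name}: partial_fit available = {supported}")
```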

*TL;DR: I started investigating online learning models with a few toy scenarios. The methodology certainly looks promising to apply to some real world scenarios. If you want to follow along, my notebooks are in this repository.*