08 Dec 2015

k-NN Classification

For this example, I'm going to turn to the trusty Iris data set. There are a number of bits and pieces we need to think about before rushing into writing our algorithm. As with any learning algorithm that we wish to evaluate, it's key to think about splitting the data into a training set and a test set. Further, what about resampling?

Here, I'm not trying to reinvent the wheel with a brand new k-NN classifier that I expect everyone to use. I'm just trying to break it down for educational purposes! Max Kuhn's caret package will do all of the below, and better: it takes care of resampling, stratified train/test splitting and data preprocessing with ease. I'm going to be naughty and ignore resampling, and just do a very simple train/test split and preprocessing step. The focus here is on the k-NN algorithm, after all.

So how does a k-NN classification algorithm work, you ask? The answer is: very simply!

  1. Start with a training set where we have known outcome classes
  2. Decide on a metric of 'closeness'; here we will use the Euclidean distance
  3. Normalise the covariates so they are on the same scale. This way, covariates on different absolute scales won't dominate the distance calculation, and hence the classification
  4. For each test point, calculate its distance to every point in the training data
  5. Having chosen how many nearest neighbors (k) to consider, count up the classes of those k nearest neighbors and go with the majority.
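The five steps above fit in a dozen lines of base R. The sketch below is not the post's code (that lives on GitHub); the function name knnSketch(), the vectorised distance calculation and the 100/50 split in the usage lines are all assumptions made for illustration.

```r
# Hypothetical one-function sketch of the five k-NN steps (not the post's code).
knnSketch <- function(trainX, trainY, testX, k) {
  # Step 3: standardise with the training set's means and standard deviations
  mu <- colMeans(trainX)
  sigma <- apply(trainX, 2, sd)
  trainZ <- scale(trainX, center = mu, scale = sigma)
  testZ <- scale(testX, center = mu, scale = sigma)

  apply(testZ, 1, function(x) {
    # Steps 2 and 4: Euclidean distance from this test point to every training point
    d <- sqrt(colSums((t(trainZ) - x)^2))
    # Step 5: majority vote among the k nearest training points
    votes <- table(trainY[order(d)[1:k]])
    names(votes)[which.max(votes)]  # which.max() breaks ties by taking the first class
  })
}

# Usage on iris with petal length and width; the seed and split size are assumptions
set.seed(42)
idx <- sample(nrow(iris), 100)
pred <- knnSketch(iris[idx, 3:4], iris$Species[idx], iris[-idx, 3:4], k = 11)
table(predicted = pred, actual = iris$Species[-idx])
```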


The iris data set provides us with measurements of petal and sepal length and width for 150 flowers, 50 from each of three species of iris. The species is the outcome we want to predict from these measurements.
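A quick look at the data confirms its shape before we build anything. iris ships with base R's datasets package, so nothing needs installing.

```r
# Inspect the built-in iris data
data(iris)
dim(iris)            # 150 rows, 5 columns (4 measurements plus Species)
table(iris$Species)  # 50 flowers of each species
head(iris[, c("Petal.Length", "Petal.Width", "Species")])
```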

So that we can say something sensible about how good a job the algorithm does, let's start by randomly selecting training and test samples. The model will be built only on the training set, so the model has no knowledge of the test set.
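A simple random split can be done with sample(). The post doesn't say which proportion it used, so the 70/30 split, the seed and the names irisTrain/irisTest below are assumptions.

```r
# Random 70/30 train/test split (proportion and seed are illustrative assumptions)
set.seed(2015)
nTrain <- round(0.7 * nrow(iris))                 # 105 of the 150 rows
trainIdx <- sample(seq_len(nrow(iris)), nTrain)   # drawn without replacement
irisTrain <- iris[trainIdx, ]
irisTest  <- iris[-trainIdx, ]                    # everything not used for training
```

This split is not stratified, so the class proportions in the test set can drift from 1/3 each; caret's createDataPartition() performs a stratified version of this split.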

As k-NN classification relies on a metric of 'closeness', what happens if some of the covariates are on a larger absolute scale than others? In short, bad things happen: covariates with a larger absolute scale dominate the distance calculation, and hence the classification. To remedy this, it's always advisable to normalise each of the covariates before applying the technique. To do this, I subtract the mean and divide by the standard deviation, so each covariate ends up with mean 0 and standard deviation 1 (if a covariate is normally distributed, this leaves it standard normal). Note that the mean and standard deviation are calculated from the training set only, and the same transformation is then applied to the test set.
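Base R's scale() accepts explicit centre and scale vectors, which makes it easy to reuse the training statistics on the test set. The split, seed and choice of petal covariates below are assumptions for the sake of a runnable example.

```r
# Standardise using statistics from the training set only (split is an assumption)
set.seed(2015)
trainIdx <- sample(nrow(iris), 105)
covs <- c("Petal.Length", "Petal.Width")

trainX <- as.matrix(iris[trainIdx, covs])
testX  <- as.matrix(iris[-trainIdx, covs])

mu    <- colMeans(trainX)        # training means
sigma <- apply(trainX, 2, sd)    # training standard deviations

trainZ <- scale(trainX, center = mu, scale = sigma)
testZ  <- scale(testX,  center = mu, scale = sigma)  # same transform, test data not used

colMeans(trainZ)       # essentially 0 by construction
apply(trainZ, 2, sd)   # 1 by construction
```

The test set's means won't be exactly 0 after this step, and that's expected: computing fresh statistics on the test set would leak information about it into the model.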

KNN Algorithm

The algorithm is really, really straightforward. I've uploaded it to my GitHub here, but I will briefly talk through the key steps.

After normalising the scales of the covariates in the training and test sets, the first part of the algorithm is to build up the kClosest matrix in this nested apply loop. For each test point (one column per test point), this matrix stores the row numbers of the k closest training points, as measured by Euclidean distance.

kClosest <- apply(dataTest, 1, function(x) {
  tmp <- apply(dataTrain, 1, function(y) euclDist(x, y))  # distance to every training row
  order(tmp)[1:k]  # row numbers of the k closest training points
})
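The nested apply relies on a euclDist() helper that isn't shown in the post. The definition below is an assumption (the standard Euclidean distance), and the toy standardised data is made up so the kClosest step can be run in isolation.

```r
# Assumed definition of the euclDist() helper: plain Euclidean distance
euclDist <- function(x, y) sqrt(sum((x - y)^2))

# Made-up standardised data: 4 training points and 2 test points in 2 dimensions
dataTrain <- matrix(c(-1, -1,  -0.9, -1.1,  1, 1,  1.1, 0.9), ncol = 2, byrow = TRUE)
dataTest  <- matrix(c(-1, -0.95,  1.05, 1), ncol = 2, byrow = TRUE)
k <- 2

kClosest <- apply(dataTest, 1, function(x) {
  tmp <- apply(dataTrain, 1, function(y) euclDist(x, y))
  order(tmp)[1:k]
})
kClosest  # k rows, one column per test point
```

Because apply() collects each length-k result as a column, the output is a k by (number of test points) matrix, which is why the voting step works through its columns. With k = 1 apply() returns a plain vector instead, so that case would need drop = FALSE style handling.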

Then we go through the columns of the kClosest matrix. For each test point, the base R table() function tallies up the classes of its k nearest neighbors. The class that gets the most 'votes' wins, and the new sample is allocated to that class. In case of a tie, the first of the tied classes (in the order table() lists them, i.e. factor level order) is chosen.

outcomeClass <- apply(kClosest, 2, function(x) {
  tmpTab <- table(outcomeTrain[x])  # tally the classes of the k neighbors
  testClass <- names(which(tmpTab == max(tmpTab)))[1]  # majority; ties go to the first class
  testClass
})
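To see the voting and tie-breaking in action, here is the same step on a made-up outcomeTrain and kClosest with k = 3. Note that with three classes an odd k doesn't rule out ties: the second test point gets one vote for each species.

```r
# Made-up training labels and neighbor indices, purely for illustration
outcomeTrain <- factor(c("setosa", "versicolor", "versicolor", "virginica", "virginica"))
kClosest <- cbind(c(2, 3, 4),   # test point 1: two versicolor votes, one virginica
                  c(1, 2, 4),   # test point 2: a three-way tie
                  c(4, 5, 3))   # test point 3: two virginica votes

outcomeClass <- apply(kClosest, 2, function(x) {
  tmpTab <- table(outcomeTrain[x])
  names(which(tmpTab == max(tmpTab)))[1]   # first tied class in factor level order
})
outcomeClass  # "versicolor" "setosa" "virginica"
```

Taking the first tied class means ties always favour whichever class sorts first (alphabetically here), which is arbitrary; breaking ties at random, or by the closest neighbor among the tied classes, are common alternatives.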

k is a tuning parameter, and if we were interested in optimising the model we would repeat the fitting (with resampling) over a range of choices of k to evaluate the optimum value.
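A sketch of tuning k with resampling is below, using 5-fold cross-validation over odd values of k. The post doesn't tune k, so the fold count, the grid, the seed and the helper knnPredict() are all assumptions; in practice this would be run on the training set only, keeping the test set untouched.

```r
# Hypothetical k-NN predictor: standardise on the training fold, then majority vote
knnPredict <- function(trainX, trainY, testX, k) {
  mu <- colMeans(trainX)
  sigma <- apply(trainX, 2, sd)
  trainZ <- scale(trainX, center = mu, scale = sigma)
  testZ <- scale(testX, center = mu, scale = sigma)
  apply(testZ, 1, function(x) {
    d <- sqrt(colSums((t(trainZ) - x)^2))
    votes <- table(trainY[order(d)[1:k]])
    names(votes)[which.max(votes)]
  })
}

set.seed(1)
X <- as.matrix(iris[, c("Petal.Length", "Petal.Width")])
y <- iris$Species
folds <- sample(rep(1:5, length.out = nrow(X)))  # random fold label for each row
kGrid <- seq(1, 21, by = 2)

# Mean held-out accuracy across the 5 folds, for each candidate k
cvAccuracy <- sapply(kGrid, function(k) {
  mean(sapply(1:5, function(f) {
    held <- folds == f
    pred <- knnPredict(X[!held, ], y[!held], X[held, , drop = FALSE], k)
    mean(pred == y[held])
  }))
})
names(cvAccuracy) <- kGrid
cvAccuracy
kGrid[which.max(cvAccuracy)]  # k with the best cross-validated accuracy
```

Standardisation is recomputed inside every fold so the held-out rows never influence the scaling, mirroring the train/test rule above.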


As I built a simple model, I'm only going to do a simple test: build up a confusion matrix and calculate the total classification accuracy. This test isn't particularly robust, as I haven't considered resampling. A confusion matrix can be easily built using the table() function, and the classifier has done a fairly good job when considering the 11 nearest neighbors. I have specified petal length and petal width as the covariates in this example; I could have chosen more, as the code is written generally and isn't restricted to a particular number of input dimensions.
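The confusion matrix and overall accuracy take two lines once predictions exist. The predicted and actual labels below are made up for illustration; in the post they would be outcomeClass and the test set's species.

```r
# Made-up predictions and true classes
actual    <- factor(c("setosa", "setosa", "versicolor", "versicolor", "virginica", "virginica"))
predicted <- factor(c("setosa", "setosa", "versicolor", "virginica", "virginica", "virginica"),
                    levels = levels(actual))  # same levels keeps the table square

confMat <- table(predicted = predicted, actual = actual)
confMat
accuracy <- sum(diag(confMat)) / sum(confMat)  # correct predictions lie on the diagonal
accuracy  # 5 of 6 correct
```

Giving predicted the same factor levels as actual matters: if a class is never predicted, table() would otherwise drop that row, and diag() would then pick up the wrong cells.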

The figure below shows the test set, with the color showing whether each point was correctly classified (yellow) or not (blue). Only 3 were incorrectly classified, and, as expected, it's those tricky ones on the border between the versicolor and virginica clusters. Overall, we have about 92% accuracy, which isn't bad at all considering how simple the technique is.


Other thoughts

This was a nice warm-up exercise; I'm going to be playing with other techniques in the coming months. It would be interesting to apply this algorithm in conjunction with dimension reduction techniques, but I will come back to that another time!

TL;DR - a simple k-nn classification algorithm applied to the iris data set. It's really easy to do.