24 Jan 2015

Visualising Classification Boundaries

When performing classification modeling, it is often informative to visualise the classification boundaries produced by different models. This is a great way to build intuition for how different types of model work, and it can highlight how flexible a model is, or what a good choice of model might be.

As always, the script I used to generate these visualisations is here.


I have turned to the trusty Iris data set once more for this little problem. It's a good candidate data set: it has four predictors (petal length, petal width, sepal length and sepal width), so we have to go beyond the trivial task of plotting two predictors against one another, which is all we would need to do if there were only two.

The problem of visualising a high-dimensional predictor space is nothing new; people often turn to principal component analysis. The trouble with this, as far as I have found, is that it can be very difficult to superimpose classification boundaries on top of your data points.

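This is because you would have to do an exhaustive grid search over the predictor values and then make a prediction at every point in predictor space. The number of grid points is the product of the number of points along each predictor, so it grows exponentially with the number of features, and the task quickly becomes hugely expensive.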
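To put a rough number on that cost, the sketch below counts the grid points needed to sweep each Iris predictor's observed range at a spacing of 0.01 (the spacing used later in this post), first for all four predictors at once and then for just two. The variable names are illustrative and not taken from the original script.

```r
# Count how many grid values are needed along each predictor when its
# observed range is swept at a fixed spacing.
spacing <- 0.01
pointsPerAxis <- sapply(iris[, 1:4], function(x) {
  length(seq(from = min(x), to = max(x), by = spacing))
})

# An exhaustive grid over all four predictors is the product of the
# per-axis counts; a grid over two predictors uses only those two counts.
fullGrid <- prod(pointsPerAxis)
twoDGrid <- prod(pointsPerAxis[c("Petal.Length", "Petal.Width")])

print(pointsPerAxis)
print(fullGrid)
print(twoDGrid)
```

With the Iris ranges, the full four-dimensional grid runs to more than ten billion points, while the two-predictor grid needs fewer than 150,000, and every one of those points has to go through predict().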

So my strategy is much simpler: pick two predictive inputs of interest, hold all other predictive inputs at their mean value, and then perform a grid search over the two chosen predictors to determine the boundaries.


I have implemented this 'rough draft' in a simple little function. I plan to improve it and then include it in a package I have slowly been working on, which aims to make it easier to keep track of multiple predictive models developed using caret (coming soon, so watch this space!). At present the function only works for continuous numeric predictors. For categorical predictors, I think the best strategy would be to dummy up the predictors and then select the mode, rather than the mean.
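The categorical idea could look something like the hypothetical helper below; typicalValue() is a name made up for this sketch and is not part of the original function. Picking the most frequent level of a factor is equivalent to dummying it up and setting the most common dummy to 1 with the others at 0. It assumes non-numeric predictors are factors (or characters that can be coerced to factors).

```r
# Hypothetical helper (not in the original script): the value at which to
# hold a predictor while the two predictors of interest vary.
# Numeric columns use the mean; anything else uses the most frequent level.
typicalValue <- function(x) {
  if (is.numeric(x)) {
    return(mean(x))
  }
  x <- as.factor(x)
  counts <- table(x)
  # return a factor carrying the original level set, so a model trained
  # on this column sees the levels it expects
  factor(names(counts)[which.max(counts)], levels = levels(x))
}

# usage on a small made-up data frame
df <- data.frame(a = c(1, 2, 3), b = factor(c("x", "y", "y")))
lapply(df, typicalValue)
```

Calling typicalValue() in place of mean() wherever the other predictors are held constant would then cover both cases, provided the model accepts a factor column with the same levels it was trained on.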

The methodology is fairly simple. The function takes as input the names of the two predictors of interest, a predictive model (I have only tested it with objects of class 'train', built by caret::train()), the grid spacing for each of the two predictors of interest, and the input data set that was used to build the model. It outputs a data frame with regularly spaced grid points and the class prediction at each point.

Here is the function:

# function for creating data to plot a grid search of classification
# boundaries
plotDecision <- function(pred1, pred2, trainModel,
                         dP1 = 1, dP2 = 1, inputData) {
  # max and min of predictors of interest, and use
  # expand.grid to get values for grid search
  min1 <- min(inputData[[pred1]])
  max1 <- max(inputData[[pred1]])
  min2 <- min(inputData[[pred2]])
  max2 <- max(inputData[[pred2]])
  pred1vals <- seq(from = min1, to = max1, by = dP1)
  pred2vals <- seq(from = min2, to = max2, by = dP2)
  plotData <- expand.grid(pred1vals, pred2vals)
  names(plotData) <- c(pred1, pred2)
  # hold every other predictive factor constant at its mean value
  # while we vary over the two inputs of interest
  otherVars <- setdiff(names(inputData), c(pred1, pred2))
  for (v in otherVars) {
    plotData[[v]] <- mean(inputData[[v]])
  }
  # predict using the model for all the grid points,
  # and return the required data for plotting
  plotData$predictions <- predict(trainModel, newdata = plotData)
  plotData[, c(pred1, pred2, "predictions")]
}

Hopefully, with the comments I have included, it is self-explanatory.


OK, so here are some sample results for the Iris data set; I will include just a few. I call the function with:

# calculate decision boundaries
p <- plotDecision("Petal.Length", "Petal.Width", model1,
                  dP1 = 0.01,
                  dP2 = 0.01,
                  inputData = iris[, 1:4])

and produce plots from the output with:

# plot the grid of predictions using geom_raster
# then, superimpose the chosen predictor values
# (petal length and width in this case) on top
library(ggplot2)
# philTheme() is a custom colour palette (not part of ggplot2);
# any vector of three colours will do in its place
ggplot(p, aes(x = Petal.Length,
              y = Petal.Width)) +
  geom_raster(aes(fill = factor(predictions)),
              alpha = 0.2,
              interpolate = TRUE) +
  scale_fill_manual(values = philTheme(),
                    name = "Prediction Region") +
  geom_point(data = iris, aes(x = Petal.Length,
                              y = Petal.Width,
                              shape = Species),
             alpha = 0.6,
             size = 3)
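plotDecision() needs a fitted model to predict from. As a minimal sketch, assuming caret's formula interface, its default resampling and an arbitrary seed (not necessarily the settings used for the plots in this post), an LDA model like model1 could be built as follows. It needs the caret and MASS packages installed.

```r
# Sketch only: assumed settings (formula interface, caret's default
# bootstrap resampling, arbitrary seed), not necessarily those used here.
library(caret)

set.seed(42)
model1 <- train(Species ~ ., data = iris, method = "lda")

# predict() on a "train" object returns a factor of class labels,
# which is the form plotDecision() expects
head(predict(model1, newdata = iris[, 1:4]))
```

Swapping method = "lda" for method = "rf" or method = "nnet" (adding trace = FALSE to quiet nnet's fitting output) gives random forest and neural network models of the kind discussed below; caret needs the randomForest and nnet packages installed for those.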

The model1 input was an object of class train that I generated using caret::train(), specifying linear discriminant analysis (LDA) as the model. The classification boundaries look like this:


What about other models? Using a random forest, the boundaries look like this:


This hints at why random forests can often produce such powerful predictions: they can fit very flexible classification boundaries.

A final plot to leave you with, which I think is quite instructive, shows why blindly choosing a non-linear model over a linear one can be a bad idea. The plot below is for a neural network:


We see that the neural net, despite its flexibility, has produced boundaries similar to those of the LDA model, so we gain nothing from using the more complex and less interpretable model. In other words, when the boundaries seem to be linear, a linear model may be a good choice!
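One way to put a number on "similar boundaries" is to compare two models' grids point by point. The hypothetical boundaryAgreement() below is not part of the original script; it assumes both data frames came from plotDecision() with the same predictors, spacing and input data, so their rows line up, and it returns the fraction of grid points where the predicted classes match.

```r
# Hypothetical helper (not in the original script): fraction of grid
# points on which two plotDecision()-style outputs predict the same class.
boundaryAgreement <- function(gridA, gridB) {
  # both grids must cover exactly the same points, in the same order
  stopifnot(isTRUE(all.equal(gridA[, 1:2], gridB[, 1:2])))
  mean(as.character(gridA$predictions) == as.character(gridB$predictions))
}

# toy example: four grid points, the two "models" disagree on one of them
a <- data.frame(x = 1:4, y = 1:4,
                predictions = factor(c("s", "s", "v", "v")))
b <- data.frame(x = 1:4, y = 1:4,
                predictions = factor(c("s", "v", "v", "v")))
boundaryAgreement(a, b)  # 0.75
```

For two real models, a value close to 1 from something like boundaryAgreement(pLDA, pNNet), where pLDA and pNNet are hypothetical names for the two plotDecision() outputs, would back up the visual impression. It only measures agreement over the two-dimensional slice, with the other predictors held at their means.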

Other thoughts

Now, I will confess that I know the Iris data set quite well, so I know that petal length and width are probably the two most important predictive inputs. It's therefore not too surprising that the data points fall (mostly) quite nicely in the correct regions. I cannot guarantee that these plots will be as informative for data sets you do not know much about, but that is in itself an argument that exploratory data analysis is vital: get to know your data before doing any modeling!
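As a quick, model-free example of that kind of exploration, the sketch below computes a one-way ANOVA F statistic for each Iris predictor against Species: larger values mean the class means sit further apart relative to the spread within each class. It is only a rough screen, not a measure of how much any particular model relies on a predictor.

```r
# One-way ANOVA F statistic of each numeric predictor against the class
# label; the first row of the ANOVA table is the Species term.
fStats <- sapply(iris[, 1:4], function(x) {
  anova(lm(x ~ iris$Species))[["F value"]][1]
})

# rank predictors by how well they separate the class means
sort(fStats, decreasing = TRUE)
```

On Iris, the two petal measurements come out well ahead of the sepal ones, consistent with choosing them for the plots above.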

My plan is to eventually include this function (or a descendant of it) in my own package, to make handling and evaluating multiple models built with caret a little easier. I will modify it to deal with categorical predictors and numeric count data as well. So watch this space!

TL;DR: I've presented a method to visualise classification boundaries. The code is here, and I hope you find it informative!