ML Series: K-Nearest Neighbours Described

K-nearest neighbours (KNN) is one of the simplest machine learning models to explain. Let’s say we have lots of historical data, where each data point has a class assigned. In the chart below, you can see that we have a number of points in the red, blue and green classes.

The question, then, is: if we add a new point to the chart (which I have shown in yellow), which of the three classes does it belong to?

That’s simple: we look at its nearest neighbours (i.e., the points closest to it). If K=1, we take the single closest point and assign its class to the new point. In this case, the closest point is red, so our new data point would be classified as red.

Note: K is the number of neighbours (closest datapoints) to check
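For anyone who prefers to see it in code, here is a minimal from-scratch sketch of the classifier in Python. It assumes plain Euclidean distance and a simple majority vote; the function name `knn_classify` and the toy points are hypothetical stand-ins for the chart, not scikit-learn’s implementation.

```python
import math
from collections import Counter

def knn_classify(train_points, train_labels, new_point, k):
    """Return the most common label among the k training points nearest to new_point."""
    # Euclidean distance from the new point to every training point, paired with its label
    distances = [
        (math.dist(p, new_point), label)
        for p, label in zip(train_points, train_labels)
    ]
    # Sort by distance only (not by label) and keep the k closest
    nearest = sorted(distances, key=lambda dl: dl[0])[:k]
    # Majority vote: the label that appears most often among the k nearest wins
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]

# Hypothetical 2-D points standing in for the red, blue and green classes in the chart
points = [(1.0, 1.0), (1.5, 2.0), (5.0, 5.0), (5.5, 4.5), (9.0, 1.0), (8.5, 2.0)]
labels = ["red", "red", "blue", "blue", "green", "green"]

print(knn_classify(points, labels, (1.2, 1.4), k=1))  # red
```

Changing `k` is all it takes to move from the K=1 case to the larger-K case that follows.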

It gets a bit more complicated when the new data point sits in the middle of several classes. In the example below, with K=11, we take the 11 points closest to our new data point.

In this case, we have four red points, five green points and two blue points. We classify the new data point as green, because green makes up the largest share of its 11 nearest neighbours (a plurality, even though it isn’t an outright majority).
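The vote itself is just a tally. As a quick check of the counts in this example, here is a Python sketch using `collections.Counter`; the list of labels is typed in by hand to match the chart rather than computed from real points.

```python
from collections import Counter

# Labels of the 11 nearest neighbours in the example: 4 red, 5 green, 2 blue
nearest_labels = ["red"] * 4 + ["green"] * 5 + ["blue"] * 2

votes = Counter(nearest_labels)
predicted, count = votes.most_common(1)[0]
print(votes)             # Counter({'green': 5, 'red': 4, 'blue': 2})
print(predicted, count)  # green 5
```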

Simple enough to understand, right? Maybe you have some questions like:

  1. How do we choose what the value of K should be?
  2. What happens if there is a draw, for example if blue = 5 and red = 5?


First things first, how do we choose the value of K? Well, it’s not easy and, as with all machine learning algorithms, it takes some tinkering to get the right value. With a small K, noise, mislabelled data, outliers and overlaps between classes can lead to inaccurate classification and overfitting, because one or two unrepresentative neighbours can decide the vote. We can make the classification more robust by setting a larger K value (at the cost of more computation) or by pre-processing our data to remove perceived outliers and make things cleaner. Setting K too high, however, can lead to underfitting, as the vote starts to draw on points that are far from the new data point.
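One common way to do the tinkering is cross-validation: score each candidate K on data the model hasn’t seen and keep the best. Below is a minimal pure-Python sketch using leave-one-out cross-validation (each point is classified using all the other points). The two-cluster dataset, the range of K values and the helper names are assumptions for illustration; with scikit-learn you would more typically use `GridSearchCV` or `cross_val_score` with `KNeighborsClassifier`.

```python
import math
import random
from collections import Counter

def knn_classify(train_points, train_labels, new_point, k):
    """Majority label among the k training points nearest to new_point."""
    nearest = sorted(
        zip(train_points, train_labels),
        key=lambda pl: math.dist(pl[0], new_point),
    )[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]

def loo_accuracy(points, labels, k):
    """Leave-one-out accuracy: classify each point using all the others."""
    correct = 0
    for i in range(len(points)):
        rest_points = points[:i] + points[i + 1:]
        rest_labels = labels[:i] + labels[i + 1:]
        if knn_classify(rest_points, rest_labels, points[i], k) == labels[i]:
            correct += 1
    return correct / len(points)

# Hypothetical data: two overlapping 2-D clusters, 50 points each
rng = random.Random(42)
points = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(50)]
points += [(rng.gauss(2, 1), rng.gauss(2, 1)) for _ in range(50)]
labels = ["red"] * 50 + ["blue"] * 50

# Try odd K values only, so a two-class vote can never tie
scores = {k: loo_accuracy(points, labels, k) for k in range(1, 26, 2)}
best_k = max(scores, key=scores.get)
print(f"best K = {best_k}, leave-one-out accuracy = {scores[best_k]:.2f}")
```

Sticking to odd values of K is a cheap way to rule out draws when there are only two classes; with three or more classes, draws are still possible.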

We can retrofit some rules for ties (when multiple classes receive the same number of votes). Personally, though, I just use the out-of-the-box functionality in scikit-learn, whose documentation states: “Regarding the Nearest Neighbors algorithms, if it is found that two neighbors, neighbor k+1 and k, have identical distances but different labels, the results will depend on the ordering of the training data.” Strictly speaking, that note covers ties in distance (which points make it into the K nearest) rather than ties in the vote itself.
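If you do want your own tie rule, one option (an illustrative choice here, not a standard rule or scikit-learn behaviour) is to break a draw in favour of the tied class whose neighbours are closest in total. A sketch, with a hypothetical function name and toy data:

```python
import math
from collections import Counter, defaultdict

def knn_classify_tiebreak(train_points, train_labels, new_point, k):
    """k-NN majority vote; if classes draw, pick the tied class whose
    neighbours have the smallest total distance (an illustrative rule)."""
    nearest = sorted(
        ((math.dist(p, new_point), label) for p, label in zip(train_points, train_labels)),
        key=lambda dl: dl[0],
    )[:k]
    votes = Counter(label for _, label in nearest)
    top = max(votes.values())
    tied = [label for label, n in votes.items() if n == top]
    if len(tied) == 1:
        return tied[0]
    # Draw: sum each tied class's distances among the k nearest; smallest sum wins
    total_dist = defaultdict(float)
    for d, label in nearest:
        total_dist[label] += d
    return min(tied, key=lambda label: total_dist[label])

# Hypothetical toy data: two blue and two red points around the origin
points = [(1, 0), (0, 2), (0, -1.5), (-3, 0)]
labels = ["blue", "blue", "red", "red"]

# With K=4 the vote is 2 blue vs 2 red; blue's distances (1 + 2) beat red's (1.5 + 3)
print(knn_classify_tiebreak(points, labels, (0, 0), k=4))  # blue
```

Another option is distance-weighted voting, where closer neighbours count for more; scikit-learn exposes this through `weights="distance"` on `KNeighborsClassifier`.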

Pros:

  • Simple to understand and build
  • Only one main hyperparameter (K) to tune
  • Can be used for classification and regression


Cons:

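  • It’s not very fast at prediction time, since a naive implementation compares each new point against every training point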
  • Not so accurate for data with high dimensionality
  • Can be quite sensitive to outliers
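On the regression point from the pros: instead of voting, KNN regression averages the target values of the K nearest neighbours. Here is a minimal sketch, assuming Euclidean distance and an unweighted mean; the function name `knn_regress` and the house sizes and prices are made up for illustration. It also shows why KNN isn’t very fast: every prediction measures the distance to every training point.

```python
import math

def knn_regress(train_points, train_targets, new_point, k):
    """Predict a number as the mean target of the k nearest training points."""
    # Sort every training point by its distance to the new point (a full scan each time)
    nearest = sorted(
        zip(train_points, train_targets),
        key=lambda pt: math.dist(pt[0], new_point),
    )[:k]
    # Unweighted average of the neighbours' targets
    return sum(target for _, target in nearest) / len(nearest)

# Hypothetical 1-D data: house size (m^2) -> price (thousands)
sizes = [(50,), (60,), (70,), (80,), (120,)]
prices = [150, 180, 200, 230, 400]

print(knn_regress(sizes, prices, (65,), k=2))  # 190.0
```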

Why is it not so good with high-dimensional data? Well, as we move to more and more dimensions, data points that were close together become further apart, so they may no longer be neighbours. If we look at the charts below (which I found on a great website on this topic), you can see that the points are clustered together pretty closely on the 1D chart. As we move to 2D, some of those points move to the extremities of the new axis, pushing them further apart.

Look at our points at 0.4 in the original graph. As we move to two-dimensional data, those points have been cast to opposite ends of the Y axis, so they are no longer nearest neighbours. When we move to three dimensions, the problem is further exacerbated.
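You can reproduce the general effect (not the exact charts) with random data. This sketch assumes points drawn uniformly from a unit cube and measures the average distance from each point to its nearest neighbour as the number of dimensions grows while the number of points stays fixed. The function name `mean_nn_distance` and the choice of 100 points are arbitrary.

```python
import math
import random

def mean_nn_distance(n_points, dims, seed=0):
    """Average distance from each point to its nearest neighbour, for
    n_points drawn uniformly at random from the unit hypercube [0, 1]^dims."""
    rng = random.Random(seed)
    pts = [[rng.random() for _ in range(dims)] for _ in range(n_points)]
    total = 0.0
    for i, p in enumerate(pts):
        # Distance to the closest *other* point
        total += min(math.dist(p, q) for j, q in enumerate(pts) if j != i)
    return total / n_points

for d in (1, 2, 3, 10, 50):
    print(f"{d:>2} dimensions: mean nearest-neighbour distance = {mean_nn_distance(100, d):.3f}")
```

Expect the average nearest-neighbour distance to climb steadily with the number of dimensions, even though the number of points never changes.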

There is suddenly a lot of space between data points. The denser the data, the more accurately a distance-based method like KNN can classify, so the amount of data needs to grow exponentially with the number of dimensions to keep the same density. Otherwise, your nearest neighbour won’t be very near, which leaves the model open to severe misclassification. For a point to count as a near neighbour, it has to be close in every single dimension, because a large gap in any one dimension adds to the overall distance.
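To put a rough number on “exponentially”, suppose (as an illustrative assumption) we want a data point roughly every 0.1 along each axis of a space that is 1 unit long in every dimension. The number of points needed for d dimensions is then:

```latex
N(d) \approx \left(\frac{1}{0.1}\right)^{d} = 10^{d}
\qquad\Rightarrow\qquad
N(1) = 10,\quad N(2) = 100,\quad N(3) = 1000,\quad N(10) = 10^{10}
```

So going from 3 to 10 dimensions multiplies the data needed for the same density by ten million.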

The K-nearest neighbours algorithm requires data points to be close together. That becomes tough in a high-dimensional, sparse dataset, as we end up with a lot of empty space between our data points.
