vectors with a known classification, and a query vector that we wish to determine the
classification of. It also accepts a value for k, which dictates how many of the nearest
neighbours to the query point are considered when making the classification estimate.
Depending upon the problem the best value of k to use will vary, and the user should
optimise this in a fair way, although in many problems the value can be surprisingly small.
This next function has not been optimised in terms of speed of execution, given we are
primarily interested in clarity here. However, in the on-line material we also include a
speed-optimised version which uses NumPy arrays to avoid loops and repeated calls to
getFeatureDistance(). A notable simplification in the code below is that we only make a
simple decision if there is a tie in the scores, where competing classifications have equal
numbers among the nearest neighbours. Here a tie is broken by taking the category with
the closest single point to the query. Also, when there are only two categories, using an odd value for k avoids such ties altogether.
The function code begins with the definition and its input arguments. We then perform a
check to make sure k is small enough for the data set. After this, the next step is to fill the
starting values of the dists list, which records the distances and categories for all the
known (already classified) data points. The list is appended with small tuples of distance
and category (dist, cat), with the distance being first so that when we sort the list we sort
according to distance, but the categories remain paired with their corresponding distances.
The small tuple could also contain the feature vector from the known data input, if we
need the function to report what the closest known data points actually are, rather than just
the best classification. After the dists list is filled it is sorted so that it is in order of
increasing distance. The k closest of the known categories are then simply taken from the
start of the list using the appropriate slice notation dists[:k].
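Based on this description, a sketch of what the first part of the function could look like is given below, leading directly into the counting code that follows. This assumes the getFeatureDistance() helper mentioned earlier; the names of the local loop variables, the default value of k and the exact wording of the error are illustrative choices.

def kNearestNeighbour(knowns, query, k=7):

  if k >= len(knowns):
    raise Exception('Length of training data must be larger than k')

  dists = []
  for vector, cat in knowns:
    dist = getFeatureDistance(vector, query)
    dists.append((dist, cat)) # Distance first, so sorting is by distance

  dists.sort()
  closest = dists[:k]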
  counts = {}

  for dist, cat in closest:
    counts[cat] = counts.get(cat, 0) + 1

  bestCount = max(counts.values())
  bestCats = [cat for cat in counts if counts[cat] == bestCount]

  for dist, cat in closest:
    if cat in bestCats:
      return cat
The remainder of the function involves looking at the k closest data points to the query
and determining what the most popular category is. This is achieved by making a
dictionary for category counts, and then, as we loop through the closest of the known points, adding one to the count for each category encountered. Note that the .get() method is used so that a category's count starts at zero when there is no previous entry in the dictionary to add one to. With the categories of the closest points tallied, the best count is determined as the maximum of the dictionary's values. The best categories
are then determined by finding all those that have this maximum count, using Python’s
compact list comprehension notation. More than one category may have a maximal count,
indicating a tie. Potential ties are resolved by the final loop where the closest matches are
gone through in order, remembering that they are sorted to give the closest first. The first
category in the list from those with maximum count is then returned as the best category
prediction (whereupon the loop ceases). The first point in the closest list is not necessarily
the winner if other categories within the closest k are more populous.
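For comparison, a rough sketch of what a vectorised variant, in the spirit of the speed-optimised NumPy version mentioned earlier, might look like is given below. This is only an illustration, not the on-line version itself; here an ordinary Euclidean distance is assumed in place of getFeatureDistance(), and the function and variable names are arbitrary.

from numpy import array, sqrt

def kNearestNeighbourFast(knowns, query, k=7):

  if k >= len(knowns):
    raise Exception('Length of training data must be larger than k')

  # Separate the feature vectors and categories of the known data
  vectors = array([x[0] for x in knowns], dtype=float)
  cats = [x[1] for x in knowns]

  # Euclidean distances to the query for all known points at once
  deltas = vectors - array(query, dtype=float)
  dists = sqrt((deltas * deltas).sum(axis=1))

  # Indices of the k smallest distances, in increasing order
  order = dists.argsort()[:k]

  counts = {}
  for i in order:
    counts[cats[i]] = counts.get(cats[i], 0) + 1

  bestCount = max(counts.values())

  # As before, ties are broken by the closest single point
  for i in order:
    if counts[cats[i]] == bestCount:
      return cats[i]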
We can test the functions with some crude fictitious data. Here we have colour vectors
that are placed into only two named categories. We have tried to have about equal
numbers of well-spaced points for each category, so that the choice of inputs doesn’t
introduce much bias. We then test by running the function on a query colour. This example
shows how the input of known data with classifications is expected to be a list of smaller
lists (or tuples), each of which contains first a feature vector, and second a textual category
classification. Other Python data structures could be used, but they should naturally match
the programming of the prediction function. Also, in many situations you would need to
check the form and validity of the input before running the calculation.
knownClasses = [((1.0, 0.0, 0.0), 'warm'), # red
                ((0.0, 1.0, 0.0), 'cool'), # green
                ((0.0, 0.0, 1.0), 'cool'), # blue
                ((0.0, 1.0, 1.0), 'cool'), # cyan
                ((1.0, 1.0, 0.0), 'warm'), # yellow
                ((1.0, 0.0, 1.0), 'warm'), # magenta
                ((0.0, 0.0, 0.0), 'cool'), # black
                ((0.5, 0.5, 0.5), 'cool'), # grey
                ((1.0, 1.0, 1.0), 'cool'), # white
                ((1.0, 1.0, 0.5), 'warm'), # light yellow
                ((0.5, 0.0, 0.0), 'warm'), # maroon
                ((1.0, 0.5, 0.5), 'warm'), # pink
               ]
result = kNearestNeighbour(knownClasses, (0.7,0.7,0.2), k=3)
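The result can then be displayed with a simple print statement. The predicted category naturally depends on the distance measure used by getFeatureDistance(); with an ordinary Euclidean distance the dull yellow query above has grey, yellow and pink as its three closest points, and so would be classified as 'warm'.

print('Colour class:', result)

Also, given that the best value of k varies from problem to problem, one simple and fair way to compare candidate values is a leave-one-out test over the known data, where each point in turn is removed and predicted from the rest. A minimal sketch using the function and data above (only a rough guide with so few points) might be:

for k in (1, 3, 5, 7):
  correct = 0
  for i, (vector, cat) in enumerate(knownClasses):
    rest = knownClasses[:i] + knownClasses[i+1:]  # All points except the query
    if kNearestNeighbour(rest, vector, k=k) == cat:
      correct += 1

  print('k=%d accuracy: %d/%d' % (k, correct, len(knownClasses)))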