Chapter 18
k-means clustering
Finding structure in data without being told what to look for. The most widely-used clustering algorithm, watched step by step.
Imagine you have data and no labels. Just points. A spreadsheet of customers with their purchase histories. A folder of unlabelled photographs. A pile of text documents.
You suspect there's structure in there — that customers fall into a few natural types, that the photographs sort themselves into visual groups, that the documents are about a handful of topics. But no one has told you what the groups are. You have to discover them.
This is clustering: the problem of finding structure in data without being told what to look for. And the most widely-used clustering algorithm, the one nearly every textbook starts with, is k-means.
Two reasons to start here. First, k-means is genuinely useful — it runs in production at companies you have heard of, doing things like customer segmentation, image compression, and anomaly detection. Second, k-means is one of the most beautifully visualisable algorithms in all of machine learning. By the end of this chapter you will have watched it converge, iteration by iteration, on three different datasets, and you will have seen exactly how and why it sometimes fails.
§1 What is clustering?
Clustering is unsupervised learning. Recall that in supervised learning (Chapters 4 through 17), we had labelled examples: each input came with the correct output, and our job was to learn the function that maps from one to the other. In unsupervised learning there are no labels. We only have the inputs. The goal is to discover structure in those inputs alone.
The structure we are looking for, in this chapter, is groups. The intuition is simple. A good clustering is one where points in the same group look similar, and points in different groups look different. "Similar" usually means close together in some space. So clustering reduces to: given a cloud of points, partition them into groups such that nearby points end up in the same group and faraway points end up in different ones.
Some real applications:
- Customer segmentation. Group users by behaviour so a marketing team can design different campaigns for different segments.
- Image compression. Each pixel is a colour (three numbers). Cluster all pixels into 256 groups, then replace each pixel with its group's average colour. The result is an image that uses 256 colours instead of millions.
- Document organisation. Each document is a vector (we saw how in Chapter 17). Cluster the vectors, and you have automatically discovered topics.
- Anomaly detection. Most legitimate transactions cluster together. The ones far from any cluster are worth a second look.
What unites these applications is that we do not know the groups ahead of time. The algorithm finds them.
§2 The k-means algorithm
There are many ways to cluster. k-means is one. Its strategy is almost embarrassingly direct.
We start by picking a number, $K$, which is how many groups we expect to find. Choosing $K$ is its own problem; we come back to it in §6. Then we make a guess: "the centres of the groups are at these positions." Probably a terrible guess. We refine it.
The refinement loop has two steps:
- Assignment. Look at every point. Assign it to whichever centre is closest.
- Update. Each centre moves to the average position of the points assigned to it.
Then repeat. And repeat. And repeat. Until nothing changes.
That is the algorithm. In pseudocode:
choose K, initialise K centroids
repeat:
    assign each point to the nearest centroid
    move each centroid to the mean of its assigned points
until no centroid moves
It might not be obvious that this ever terminates, or that the answer is any good when it does. Both facts must be earned, and we earn them in §5. For now, look at what the algorithm actually does. Pick a dataset shape, choose $K$, and step through the iterations one by one.
A few things to notice as you play with it:
- The first iteration does the most work. Most points snap to roughly the right cluster the first time the assignment step runs. Subsequent iterations only shuffle the boundary cases.
- Convergence is fast. On the three-blobs dataset with reasonable initialisation, k-means typically converges in three or four iterations. Most real problems behave similarly.
- The dataset matters. Switch to "Two moons" with $K = 2$ and run. k-means fails, and not in a subtle way. We explain why in §8, and Chapter 19 introduces algorithms that succeed where k-means fails.
- Initialisation matters. Toggle the initialisation method between random and k-means++ and re-run a few times. The random option will occasionally find a worse answer than k-means++, especially on the anisotropic dataset. We see why in §7.
The rest of this chapter is just unpacking what you have already watched.
§3 The assignment step
The assignment step is the easy one. For each point $x_i$, we want to find the closest centroid $c_k$. "Closest" is measured by some distance function $d(x_i, c_k)$. The most common choice is Euclidean distance:
$$d(x_i, c_k) = \sqrt{\sum_{j=1}^{p} (x_{ij} - c_{kj})^2}$$
where $p$ is the number of features of each point. For our 2D visualisation, $p = 2$. The assignment rule then picks the cluster whose centroid is closest:
$$z_i = \arg\min_{k \in \{1, \dots, K\}} d(x_i, c_k)$$
A small efficiency note. The square root in Euclidean distance is monotonic: if $0 \le a < b$ then $\sqrt{a} < \sqrt{b}$. When all we want is the argmin (which centroid is closest), we can skip the square root and compare squared distances instead. Faster, and the answer is the same.
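The square-root shortcut can be checked in a few lines. This is a minimal sketch on made-up toy data (the points, centroids, and variable names are assumptions for illustration, not part of the chapter's widgets): it computes the assignment once from squared distances and once from true Euclidean distances, then compares the two.

```python
import numpy as np

# Hypothetical toy data: 6 points and 3 centroids in 2D.
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 2))
centroids = rng.normal(size=(3, 2))

# Pairwise differences via broadcasting: shape (6, 3, 2).
diffs = X[:, np.newaxis, :] - centroids[np.newaxis, :, :]

# Squared Euclidean distances, shape (6, 3). No square root taken.
sq_dists = (diffs ** 2).sum(axis=2)

# True Euclidean distances, for comparison only.
dists = np.sqrt(sq_dists)

# The index of the nearest centroid is the same either way.
assign_sq = sq_dists.argmin(axis=1)
assign_eu = dists.argmin(axis=1)
print(assign_sq)
print(np.array_equal(assign_sq, assign_eu))
```

The saving per distance is small, but the assignment step evaluates one distance for every point-centroid pair in every iteration, so it adds up.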
The geometric consequence of this assignment rule is one of the most elegant results in computational geometry. The plane gets divided into regions, one per centroid, where each region contains exactly the points closest to its centroid. This is called a Voronoi tessellation, and it looks like cells in a honeycomb. Every time you watch k-means run, you are watching the Voronoi cells reshape themselves.
§4 The update step
We have been hand-waving over why the centroid update is the mean. Time to earn it.
The update step says: each centroid moves to the average of the points currently assigned to it. But the average is an odd choice. Why not the median, or the centre of the bounding box, or whatever point is most central by some other measure? The answer is that the mean is the unique position that minimises the sum of squared distances to the cluster's points. Call the total of those squared distances, summed over every cluster, the within-cluster sum of squares (WCSS). If we want WCSS to decrease, the mean is provably the best place to put each centroid.
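The derivation is short. Suppose cluster $k$ contains points $x_1, \dots, x_m$. We want to choose a centroid position $c$ that minimises:
$$J(c) = \sum_{i=1}^{m} \lVert x_i - c \rVert^2$$
Take the gradient with respect to $c$ and set it to zero:
$$\nabla_c J(c) = -2 \sum_{i=1}^{m} (x_i - c) = 0$$
which solves to:
$$c = \frac{1}{m} \sum_{i=1}^{m} x_i$$
The mean. The Hessian is $2mI$, which is positive definite, so this is the unique global minimum, not a maximum or saddle point.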
You do not have to take the math on faith. The widget below shows the same fact experimentally. Drag the centroid around a single cluster of points; the running WCSS value updates live. The marked position is the mean. The minimum WCSS happens exactly there — and the further you drag the centroid from the mean, the worse WCSS gets, in proportion to the squared distance you have moved away.
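The same experiment can be run without the widget. The sketch below uses a made-up single cluster (the data and the helper name `wcss` are assumptions for illustration). It checks the identity behind the "proportional to the squared distance" remark: for $m$ points with mean $\bar{x}$, the sum of squared distances to any position $c$ equals the sum at the mean plus $m \lVert c - \bar{x} \rVert^2$.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical single cluster of 40 points in 2D.
points = rng.normal(loc=[1.0, -2.0], scale=0.7, size=(40, 2))
m = len(points)
mean = points.mean(axis=0)

def wcss(centre):
    """Sum of squared distances from every point to `centre`."""
    return ((points - centre) ** 2).sum()

# Move the centre away from the mean and compare against the identity
#   wcss(c) = wcss(mean) + m * ||c - mean||^2
for offset in ([0.0, 0.0], [0.5, 0.0], [1.0, 1.0], [-2.0, 3.0]):
    c = mean + np.array(offset)
    predicted = wcss(mean) + m * np.sum((c - mean) ** 2)
    print(offset, round(wcss(c), 4), round(predicted, 4))
```

The two printed columns should agree up to rounding, and the smallest value sits at zero offset, which is the mean.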
This is one of those rare moments where the theory and the experiment align so cleanly that they feel almost trivial. They are not trivial, though — the same logic generalises directly to weighted k-means, to fuzzy c-means, to mixture models, to the M-step of EM. All of them are minimising some sum of squared distances, and the answer at every step is some kind of mean.
§5 Why k-means converges
Two questions deserve answers before we trust this algorithm. First: does it always terminate, or could it run forever? Second: when it terminates, is the answer any good?
The first question has a clean answer. Look at what WCSS does over the course of the algorithm.
When the algorithm starts, WCSS has some value. The first assignment step picks the closest centroid for each point — and "closest" is exactly the choice that minimises that point's contribution to WCSS, given the current centroids. So the assignment step can only lower WCSS, or leave it unchanged. Then the update step moves each centroid to the mean of its assigned points — which, by §4, is exactly the position that minimises WCSS given the current assignments. So the update step can only lower WCSS, or leave it unchanged, too.
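Both steps are non-increasing, and WCSS is bounded below by zero, so the WCSS sequence converges. That doesn't immediately mean the algorithm stops, but observe: there are only finitely many ways to assign $n$ points to $K$ clusters ($K^n$ of them, in fact), and each assignment determines its centroids and therefore its WCSS. Because WCSS never increases, the algorithm can never return to an assignment with a higher WCSS, so WCSS can strictly decrease only finitely many times. Once it stops strictly decreasing, the assignments repeat, which means the algorithm has reached a fixed point where neither step changes anything. We terminate.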
Watch this empirically:
Two things to internalise from this view. The WCSS curve only ever points down. And the first iteration does almost all the work — the curve drops steeply, then levels off, often converging in three to six iterations even on substantial datasets.
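The monotone curve can also be reproduced numerically. The following is a sketch on assumed data (three Gaussian blobs; the helper names `assign` and `wcss` are made up for this example). It records WCSS after every assignment step and every update step, so the list it prints should never go up.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical data: three Gaussian blobs, 50 points each.
centres = np.array([[2.0, 2.0], [-2.0, 2.0], [0.0, -2.0]])
X = np.vstack([c + 0.5 * rng.normal(size=(50, 2)) for c in centres])

K = 3
centroids = X[rng.choice(len(X), K, replace=False)].copy()

def assign(X, centroids):
    """Index of the nearest centroid for every point."""
    sq = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return sq.argmin(axis=1)

def wcss(X, centroids, labels):
    """Total squared distance from each point to its assigned centroid."""
    return ((X - centroids[labels]) ** 2).sum()

history = []
labels = assign(X, centroids)
for step in range(50):
    history.append(wcss(X, centroids, labels))       # value after assignment
    new_centroids = centroids.copy()
    for k in range(K):
        members = X[labels == k]
        if len(members) > 0:                          # empty cluster keeps its centroid
            new_centroids[k] = members.mean(axis=0)
    history.append(wcss(X, new_centroids, labels))   # value after update
    new_labels = assign(X, new_centroids)
    centroids = new_centroids
    if np.array_equal(new_labels, labels):            # fixed point reached
        break
    labels = new_labels

print([round(v, 2) for v in history])
```

Stopping when the assignments stop changing is the same fixed-point criterion used in the termination argument above.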
That answers the "does it terminate?" question. The "is the answer any good?" question is more subtle, and the answer is: locally good, not globally good. k-means is doing coordinate descent on a non-convex objective. Coordinate descent finds local minima, but this objective has many of them: different starting points lead to different fixed points. We come back to this in §7.
§6 Choosing K
So far we have assumed someone told us $K$. In practice, nobody does. Choosing $K$ is one of the trickiest parts of using k-means well, and there is no clean automatic answer. There are two widely-used heuristics: the elbow method and the silhouette score.
The elbow method. Run k-means with $K = 1, 2, 3, \dots$ and plot the final WCSS as a function of $K$. The curve always goes down, because more clusters can always fit the data more tightly, but it usually has an elbow: a point where the rate of decrease drops sharply. The intuition is that before the elbow, each additional cluster captures genuine structure; after the elbow, each additional cluster is just splitting hairs. Pick $K$ at the elbow. This method is folk-wisdom grade: it works when there is an obvious elbow and fails when there isn't.
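An elbow sweep might look like the sketch below. The data, the helper `kmeans_wcss`, and the choice of five restarts per $K$ are all assumptions made for this example. Restarts are used because a single run can land in a poor local minimum, which would put a bump in the curve.

```python
import numpy as np

def kmeans_wcss(X, K, rng, n_iter=100):
    """Run plain k-means from a random start and return the final WCSS."""
    centroids = X[rng.choice(len(X), K, replace=False)].copy()
    for _ in range(n_iter):
        sq = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = sq.argmin(axis=1)
        new = centroids.copy()
        for k in range(K):
            members = X[labels == k]
            if len(members) > 0:          # leave an empty cluster's centroid in place
                new[k] = members.mean(axis=0)
        if np.allclose(new, centroids):
            break
        centroids = new
    # WCSS under the final centroids: squared distance to the nearest one.
    sq = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return sq.min(axis=1).sum()

rng = np.random.default_rng(0)
# Hypothetical data: three Gaussian blobs, 50 points each.
centres = np.array([[2.0, 2.0], [-2.0, 2.0], [0.0, -2.0]])
X = np.vstack([c + 0.5 * rng.normal(size=(50, 2)) for c in centres])

elbow = {}
for K in range(1, 9):
    # Keep the best of a few restarts for each K.
    elbow[K] = min(kmeans_wcss(X, K, rng) for _ in range(5))

for K, w in elbow.items():
    print(f"K={K}: WCSS={w:8.2f}")
```

On blob-shaped data like this, the printed values should fall steeply up to the true number of blobs and flatten afterwards.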
The silhouette score. For each point $i$, define $a(i)$ as the mean distance from $i$ to the other points in its own cluster, and $b(i)$ as the mean distance from $i$ to the points in the nearest other cluster. The silhouette of point $i$ is:
$$s(i) = \frac{b(i) - a(i)}{\max\{a(i),\, b(i)\}}$$
This sits between $-1$ and $+1$. Close to $+1$ means $i$ is much closer to its own cluster than to any other: well-clustered. Close to $0$ means $i$ is on the boundary. Negative means $i$ has been assigned to the wrong cluster. The overall silhouette score is the average over all points.
Unlike WCSS, the silhouette has a peak, not an elbow. Plot it as a function of $K$ and pick the $K$ that maximises it.
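The silhouette definition translates almost line for line into NumPy. This is a minimal sketch, not a library implementation: the function name `silhouette_score` is chosen for this example, it assumes at least two clusters, and it adopts the convention (an assumption here) that a point alone in its cluster scores 0. It builds the full pairwise distance matrix, so it is only suitable for small datasets.

```python
import numpy as np

def silhouette_score(X, labels):
    """Mean silhouette over all points, using mean Euclidean distances."""
    # Pairwise Euclidean distances, shape (n, n).
    D = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2))
    clusters = np.unique(labels)
    n = len(X)
    s = np.zeros(n)
    for i in range(n):
        own = labels == labels[i]
        n_own = own.sum()
        if n_own == 1:
            continue                      # assumed convention: singleton scores 0
        # a(i): mean distance to the *other* members (D[i, i] is 0).
        a = D[i, own].sum() / (n_own - 1)
        # b(i): smallest mean distance to the members of any other cluster.
        b = min(D[i, labels == c].mean() for c in clusters if c != labels[i])
        s[i] = (b - a) / max(a, b)
    return s.mean()

rng = np.random.default_rng(0)
# Hypothetical data: two tight, well-separated blobs of 30 points each.
X = np.vstack([rng.normal(size=(30, 2)) * 0.3 + [3, 3],
               rng.normal(size=(30, 2)) * 0.3 + [-3, -3]])
good = np.repeat([0, 1], 30)    # labels that follow the blobs
bad = np.tile([0, 1], 30)       # alternating labels that ignore the geometry

print(round(silhouette_score(X, good), 3))
print(round(silhouette_score(X, bad), 3))
```

Labels that follow the blobs should score close to $+1$, while the alternating labels should score near zero.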
The interactive lets you sweep $K$ from 2 to 8 across three datasets. On "three blobs" with true $K = 3$, both curves agree clearly: WCSS has a sharp elbow at $K = 3$ and silhouette peaks at $K = 3$. On "two moons" the methods disagree (and both lie a bit), which is a sign that something is wrong with k-means' assumptions for this data, a topic we return to in §8.
The most honest summary: when these heuristics agree, trust them. When they disagree, k-means may not be the right algorithm for your data.
§7 Initialisation matters
We saw in §5 that k-means converges to a local minimum, and the local minimum it lands in depends on where the algorithm started. Run k-means twice on the same data with the same $K$ but different starting centroids: sometimes you get the same answer, often you do not. This is a real problem: a bad initialisation leads to a bad clustering, and there is no way to tell from inside the algorithm that something has gone wrong.
The simplest initialisation strategy is to pick $K$ points at random from the dataset and use them as starting centroids. This is the "random" option in our widgets. It works most of the time, but every so often you sample two starting points from the same true cluster, and the algorithm cannot recover: one true cluster ends up split between two centroids, and another gets absorbed into one.
A much better strategy is k-means++, introduced by Arthur and Vassilvitskii in 2007. The idea: pick the first centroid uniformly at random. Then for each subsequent centroid, sample with probability proportional to the squared distance from the nearest centroid already chosen. New centroids are more likely to land far from existing ones, which makes it much harder to accidentally put two centroids in the same true cluster.
In pseudocode:
choose c_1 uniformly at random from X
for k = 2, ..., K:
    for each point x:
        d(x) = squared distance from x to its nearest existing centroid
    sample c_k from X with probability d(x) / sum(d)
That is the entire algorithm. It costs a single pass through the data for each new centroid, so $O(nKp)$ time in total for $n$ points in $p$ dimensions. In return you get dramatically more reliable convergence.
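The spreading effect can be measured directly, without running k-means at all. The sketch below is one way to do it, on assumed data with hypothetical helper names (`random_init`, `kmeanspp_init`, `min_gap`). It draws many sets of starting centroids with each method and compares the smallest gap between any two of them. For simplicity it recomputes all distances each round rather than keeping a running minimum.

```python
import numpy as np

def random_init(X, K, rng):
    """K distinct data points chosen uniformly at random."""
    return X[rng.choice(len(X), K, replace=False)]

def kmeanspp_init(X, K, rng):
    """k-means++ seeding, following the pseudocode above."""
    centroids = [X[rng.integers(len(X))]]             # c_1: uniform
    for _ in range(K - 1):
        C = np.array(centroids)
        # d(x): squared distance from each point to its nearest chosen centroid.
        d = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2).min(axis=1)
        centroids.append(X[rng.choice(len(X), p=d / d.sum())])
    return np.array(centroids)

def min_gap(C):
    """Smallest distance between any two centroids in C."""
    D = np.sqrt(((C[:, None, :] - C[None, :, :]) ** 2).sum(axis=2))
    return D[np.triu_indices(len(C), k=1)].min()

rng = np.random.default_rng(0)
# Hypothetical data: three well-separated blobs of 40 points each.
true_centres = np.array([[3.0, 3.0], [-3.0, 3.0], [0.0, -3.0]])
X = np.vstack([c + 0.5 * rng.normal(size=(40, 2)) for c in true_centres])

trials = 200
gap_random = np.mean([min_gap(random_init(X, 3, rng)) for _ in range(trials)])
gap_pp = np.mean([min_gap(kmeanspp_init(X, 3, rng)) for _ in range(trials)])
print(f"mean smallest gap, random:    {gap_random:.2f}")
print(f"mean smallest gap, k-means++: {gap_pp:.2f}")
```

A small gap means two starting centroids landed in the same blob, which is the failure described above. The k-means++ average should be clearly larger.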
Compare them side by side. Press "Play both" and watch the two methods converge on the same data. Press "New seed" a few times to see how much variance the random method actually has.
On the anisotropic dataset in particular, random initialisation will fairly
often find a noticeably worse answer than k-means++. In production code,
always use k-means++. It is the default in sklearn.cluster.KMeans, which
is why you may not have noticed it exists.
§8 Where k-means fails
So far k-means has looked excellent. It always converges, it has reasonable heuristics for choosing $K$, and with k-means++ it is reliably close to optimal. Time to be honest about what it cannot do.
The fundamental assumption baked into k-means is that clusters are isotropic and convex: roughly speaking, spherical blobs of similar size. The algorithm finds the partition of the plane that minimises within-cluster squared distance, which produces Voronoi cells, which are convex polygons. If your true clusters are not convex blobs, k-means cannot find them. No amount of running longer, picking a better $K$, or initialising better will fix this; the assumption is wrong at the root.
Three concrete failure modes you should recognise.
Non-convex clusters. The two-moons dataset is the canonical example. Two long curved bands interleaved. k-means with $K = 2$ slices straight through both bands, putting half of each moon in each cluster. It is doing exactly what we asked (minimising within-cluster squared distance subject to the partition being convex), and the answer is wrong. Chapter 19's density-based methods (DBSCAN, OPTICS) solve this by abandoning the convexity assumption entirely.
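The two-moons failure is easy to reproduce. The sketch below builds its own two-moons data from two arcs (an assumed generator written for this example, not the widget's dataset), runs plain k-means with $K = 2$, and measures how often the clusters agree with the true moons.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200
theta = rng.uniform(0.0, np.pi, size=n)

# Hypothetical two-moons generator: an upper arc and a shifted, flipped lower arc.
upper = np.column_stack([np.cos(theta), np.sin(theta)])
lower = np.column_stack([1.0 - np.cos(theta), 0.5 - np.sin(theta)])
X = np.vstack([upper, lower]) + 0.05 * rng.normal(size=(2 * n, 2))
truth = np.repeat([0, 1], n)

# Plain k-means with K = 2 and random initialisation.
K = 2
centroids = X[rng.choice(len(X), K, replace=False)].copy()
for _ in range(100):
    sq = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = sq.argmin(axis=1)
    new = np.array([X[labels == k].mean(axis=0) for k in range(K)])
    if np.allclose(new, centroids):
        break
    centroids = new

# Cluster ids are arbitrary, so score the better of the two labelings.
agree = (labels == truth).mean()
accuracy = max(agree, 1.0 - agree)
print(f"agreement with the true moons: {accuracy:.2f}")
```

With two centroids the boundary is a single straight line, and no straight line separates two interleaved arcs, so the agreement stays well short of perfect however the run is initialised.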
Unequal cluster spreads. k-means implicitly assumes clusters have similar variances. When one cluster is much wider than another, the wider cluster's WCSS dominates the objective, and the algorithm will sometimes split the wide cluster into multiple smaller ones while combining smaller true clusters into one. Gaussian mixture models (Chapter 22's optional extensions) handle this by letting each cluster have its own covariance.
Outliers. Because WCSS uses squared distance, a single outlier can pull a centroid far from the rest of its cluster. There are robust variants, such as k-medians (which uses absolute distance) and k-medoids (which uses actual data points as centres), but vanilla k-means is sensitive.
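A quick way to see the outlier effect is to compare how far the mean and the coordinate-wise median move when one extreme point is added. The data below is made up for illustration, and the coordinate-wise median stands in for the kind of centre a k-medians variant would use.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical cluster of 50 points around the origin, plus one far-away point.
cluster = rng.normal(loc=[0.0, 0.0], scale=0.5, size=(50, 2))
outlier = np.array([[40.0, 40.0]])
with_outlier = np.vstack([cluster, outlier])

# How far does each notion of "centre" move when the outlier is added?
mean_shift = np.linalg.norm(with_outlier.mean(axis=0) - cluster.mean(axis=0))
median_shift = np.linalg.norm(
    np.median(with_outlier, axis=0) - np.median(cluster, axis=0)
)
print(f"mean moved by   {mean_shift:.3f}")
print(f"median moved by {median_shift:.3f}")
```

The mean moves by roughly the outlier's distance divided by the number of points, while the median barely moves.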
If you take one thing from this section: every clustering algorithm encodes an assumption about what counts as a cluster. k-means assumes blobs. When the data is blobs, it is excellent. When the data is anything else, you need a different algorithm.
§9 Complexity
The cost of one full k-means run is straightforward to count.
The assignment step computes the distance from each of $n$ points to each of $K$ centroids in $p$-dimensional space: that is $O(nKp)$ multiplications and additions. The update step is one pass through all points to accumulate sums, so $O(np)$ operations. Together, one iteration costs $O(nKp)$.
The total cost is $O(nKpT)$, where $T$ is the number of iterations until convergence. In practice $T$ is very small (usually less than 20, often less than 10), so k-means is effectively linear in $n$ once you fix the other parameters. That is why it scales so well.
Memory cost is $O(np + Kp)$: the data plus the centroids. The frame-by-frame history we have been visualising costs an extra factor of $T$, but that is a teaching aid, not a production concern.
Two practical implications. First: k-means is one of the few unsupervised algorithms that runs comfortably on truly large datasets. Tens of millions of points are routine on a laptop; billions are routine in a distributed setting. Second: it is much cheaper than methods that compute pairwise distances ($O(n^2)$) such as hierarchical clustering or DBSCAN. When $n$ is large and your data is roughly blob-shaped, k-means is hard to beat.
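One practical wrinkle when chasing these bounds in NumPy: computing distances by broadcasting builds an intermediate array with one entry per point, centroid, and feature, which is more memory than the analysis above calls for. A common alternative, sketched here with hypothetical function names, expands $\lVert x - c \rVert^2 = \lVert x \rVert^2 - 2\,x \cdot c + \lVert c \rVert^2$ so that the largest array is the $n \times K$ distance matrix itself.

```python
import numpy as np

def sq_dists_broadcast(X, C):
    """(n, K) squared distances via an (n, K, p) intermediate array."""
    return ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2)

def sq_dists_matmul(X, C):
    """Same values from ||x||^2 - 2 x.c + ||c||^2, with no (n, K, p) array."""
    x2 = (X ** 2).sum(axis=1)[:, None]     # (n, 1)
    c2 = (C ** 2).sum(axis=1)[None, :]     # (1, K)
    sq = x2 - 2.0 * (X @ C.T) + c2         # (n, K)
    return np.maximum(sq, 0.0)             # clip tiny negatives caused by rounding

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 16))            # hypothetical data: n = 1000, p = 16
C = rng.normal(size=(8, 16))               # hypothetical centroids: K = 8

a = sq_dists_broadcast(X, C)
b = sq_dists_matmul(X, C)
print(a.shape, b.shape)
print(np.allclose(a, b))
```

The arithmetic cost is still proportional to $nKp$; the change is in peak memory, and in handing the bulk of the work to a single matrix multiplication.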
§10 Implementing it yourself
It is one thing to watch an algorithm. It is another to write it. The implementation below is the entire k-means algorithm in about a dozen lines of NumPy: no other libraries, no sklearn, just arrays and arithmetic. Click Run and watch it converge. Then change something: the value of $K$, the random seed, the number of points in each cluster. The Python is genuinely running in your browser; the first run takes a few seconds to download the runtime, and every run after that is instant.
import numpy as np

# 1. Make three Gaussian blobs of 50 points each.
np.random.seed(42)
true_centres = np.array([[2.0, 2.0], [-2.0, 2.0], [0.0, -2.0]])
X = np.vstack([
    true_centres[k] + np.random.randn(50, 2) * 0.5
    for k in range(3)
])

# 2. Initialise K centroids by sampling K points from X.
K = 3
rng = np.random.default_rng(0)
centroids = X[rng.choice(len(X), K, replace=False)].copy()

# 3. Iterate: assign, update, repeat.
for step in range(20):
    # Assignment step: each point goes to its nearest centroid.
    distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)
    assignments = np.argmin(distances, axis=1)

    # Update step: each centroid moves to the mean of its assigned points.
    new_centroids = np.array([
        X[assignments == k].mean(axis=0)
        for k in range(K)
    ])

    # Convergence check.
    if np.allclose(new_centroids, centroids):
        print(f"Converged after {step + 1} iterations.\n")
        break
    centroids = new_centroids

print("Final centroids:")
for k, c in enumerate(centroids):
    print(f"  cluster {k}: ({c[0]:+.3f}, {c[1]:+.3f})")
print("\nTrue centres were:")
for k, c in enumerate(true_centres):
    print(f"  cluster {k}: ({c[0]:+.3f}, {c[1]:+.3f})")
Two things worth noticing in this implementation.
First, the assignment step is one line. X[:, np.newaxis] - centroids uses
broadcasting to compute the difference between every point and every centroid
at once — the result has shape (n_points, K, 2). np.linalg.norm(..., axis=2)
collapses the last axis (the 2D feature dimension) into a single distance,
giving an (n_points, K) matrix of distances. np.argmin(..., axis=1) then
picks the closest centroid per point. The whole vectorised step does the same
work as a nested loop, far faster.
Second, the recovered centroids will not match the true centres exactly,
and they shouldn't — the data is noisy, and the algorithm only knows about
the noisy version. What is striking is how close it gets. Try increasing the
noise (* 0.5 → * 1.2) and watch the recovered centroids drift further
from the truth; try decreasing it (* 0.5 → * 0.1) and watch them snap
back. This is your first encounter with a recurring theme in unsupervised
learning: the algorithm can only ever recover the structure that is actually
present in the data.
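One edge case the short implementation does not guard against is a cluster that ends up with no points, in which case the mean of an empty selection is NaN. A possible guard, sketched here with a hypothetical helper `update_centroids`, is to leave an empty cluster's centroid where it was. Re-seeding it from the data is another common choice; neither is claimed to be what the chapter's widgets do.

```python
import numpy as np

def update_centroids(X, assignments, centroids):
    """Mean of each cluster's points; an empty cluster keeps its old centroid."""
    new_centroids = centroids.copy()
    for k in range(len(centroids)):
        members = X[assignments == k]
        if len(members) > 0:
            new_centroids[k] = members.mean(axis=0)
    return new_centroids

# Hypothetical toy case in which cluster 2 receives no points.
X = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 10.0]])
centroids = np.array([[0.5, 0.0], [10.0, 10.0], [100.0, 100.0]])
assignments = np.array([0, 0, 1])
print(update_centroids(X, assignments, centroids))
```

With well-separated blobs and centroids initialised from data points this rarely triggers, but it becomes more likely as $K$ grows.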
§11 Problems
Five problems, ordered from conceptual through implementation. The first is think-only; the rest run in the editor. Worked solutions are hidden behind the Show solution toggles — try each problem first before peeking.
Problem 1 — Why not just set K = n?
Imagine you set $K$ equal to $n$, the number of data points. What does WCSS evaluate to at convergence? What does that tell you about using WCSS alone to choose $K$?
Show solution
When $K = n$, each point becomes its own cluster, and each cluster's mean is the single point inside it. The distance from every point to its centroid is zero, so WCSS is exactly zero.
This is the lowest WCSS achievable on any dataset, and it is reached by an algorithm that has learned nothing about the data's structure. The "clustering" just memorises the points. This is why WCSS alone cannot tell us the right $K$: bigger $K$ is always "better" in WCSS terms, so optimising WCSS without any other constraint sends you straight to $K = n$. Heuristics like the elbow method and the silhouette score exist precisely to break this degeneracy: they reward tight clusters while penalising excessive splitting.
The deeper lesson generalises: any loss function that strictly improves as you add capacity needs an external constraint (a held-out test set, a penalty term, a heuristic) to choose how much capacity to use. We will see this theme again in regression with regularisation, in trees with depth, in neural networks with everything.
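The degenerate case in this solution takes only a few lines to confirm. The sketch uses made-up data and places one centroid on every point, which is the configuration the solution describes.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(25, 2))     # hypothetical dataset with n = 25 points

# With K = n, every point can serve as its own centroid.
centroids = X.copy()
sq = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
assignments = sq.argmin(axis=1)

# Each point sits at distance zero from its own centroid.
wcss = ((X - centroids[assignments]) ** 2).sum()
print(wcss)   # 0.0
```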
Problem 2 — Implement the assignment step
Given a dataset and a set of centroids, compute the cluster assignment
for every point. Aim for a vectorised solution — no Python for loop over
points.
import numpy as np
np.random.seed(42)
X = np.vstack([
    np.array([2, 2]) + np.random.randn(20, 2) * 0.5,
    np.array([-2, 2]) + np.random.randn(20, 2) * 0.5,
    np.array([0, -2]) + np.random.randn(20, 2) * 0.5,
])
centroids = np.array([[1.8, 1.9], [-2.1, 2.2], [0.1, -1.7]])
# YOUR CODE: compute an array `assignments` of shape (60,) where
# assignments[i] is the index of the centroid closest to X[i].
assignments = ... # TODO
# Sanity check: how many points went to each cluster?
unique, counts = np.unique(assignments, return_counts=True)
for k, c in zip(unique, counts):
    print(f"cluster {k}: {c} points")
Show solution
Broadcasting subtracts every centroid from every point, then we take the norm and the argmin:
distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)
assignments = np.argmin(distances, axis=1)
You should see roughly twenty points in each cluster.
Problem 3 — Implement k-means++ initialisation
Pick $K$ starting centroids from $X$ using the k-means++ rule from §7: the first uniformly at random, each subsequent one sampled with probability proportional to its squared distance from the nearest existing centroid.
import numpy as np
np.random.seed(42)
true_centres = np.array([[3, 3], [-3, 3], [0, -3]])
X = np.vstack([c + np.random.randn(40, 2) * 0.5 for c in true_centres])
K = 3
rng = np.random.default_rng(0)
# Implement k-means++:
# 1. Pick the first centroid uniformly from X.
# 2. For each of the next K - 1 centroids, sample with probability
#    proportional to squared distance from the nearest existing centroid.
centroids = []  # TODO: fill with K starting points
for i, c in enumerate(centroids):
    print(f"centroid {i}: ({c[0]:+.2f}, {c[1]:+.2f})")
Show solution
Track existing centroids, recompute squared distances each round, and sample
with rng.choice weighted by those distances:
centroids = [X[rng.integers(len(X))]]
for _ in range(K - 1):
    diffs = X[:, None, :] - np.array(centroids)[None, :, :]
    sq = (diffs ** 2).sum(axis=2)
    nearest_sq = sq.min(axis=1)
    probs = nearest_sq / nearest_sq.sum()
    centroids.append(X[rng.choice(len(X), p=probs)])
Compared to picking three random points, k-means++ will reliably space the centroids near the three true cluster centres.
Problem 4 — Vectorise the update step
The from-scratch implementation in §10 has a Python loop over $k$ in the update step. Replace it with a single vectorised expression. Hint: build a one-hot matrix of assignments and use a matrix multiply.
import numpy as np
np.random.seed(42)
X = np.random.randn(100, 2)
K = 4
assignments = np.random.randint(0, K, size=100)
# Slow reference (do not change):
slow = np.array([X[assignments == k].mean(axis=0) for k in range(K)])
# YOUR CODE: produce `fast` with the same values, no Python loop over K.
fast = ... # TODO
print("match:", np.allclose(slow, fast))
print("slow:")
print(slow)
print("fast:")
print(fast)
Show solution
The one-hot matrix has shape (n, K); multiplying its transpose by X gives
unnormalised sums, and dividing by the cluster sizes gives means:
one_hot = np.eye(K)[assignments] # (n, K)
sums = one_hot.T @ X # (K, p)
counts = one_hot.sum(axis=0).reshape(K, 1) # (K, 1)
fast = sums / counts
This avoids the explicit loop over $k$ and is significantly faster for large $K$, because it pushes the whole thing into BLAS-backed matrix multiplication.
Problem 5 — Putting it together
Combine your work from problems 2 and 4 (and optionally 3) into a complete k-means implementation. Run it on an anisotropic dataset and report the WCSS at convergence.
import numpy as np
# Anisotropic data: three elongated clusters.
np.random.seed(42)
centres = np.array([[2, 2], [-2, 2], [0, -2]])
R = np.array([[np.cos(0.6), -np.sin(0.6)], [np.sin(0.6), np.cos(0.6)]])
X = np.vstack([
    c + (np.random.randn(50, 2) * np.array([1.5, 0.3])) @ R.T
    for c in centres
])
K = 3
rng = np.random.default_rng(0)
centroids = X[rng.choice(len(X), K, replace=False)].copy()
for step in range(50):
    # YOUR CODE: assignment step (use the vectorised version from Problem 2).
    assignments = ...  # TODO
    # YOUR CODE: update step (use the vectorised version from Problem 4).
    new_centroids = ...  # TODO
    shift = np.linalg.norm(new_centroids - centroids, axis=1).max()
    if shift < 1e-4:
        print(f"Converged at iteration {step + 1}.")
        break
    centroids = new_centroids
wcss = ((X - centroids[assignments]) ** 2).sum()
print(f"Final WCSS: {wcss:.2f}")
print("Centroids:")
for k, c in enumerate(centroids):
    print(f"  {k}: ({c[0]:+.2f}, {c[1]:+.2f})")
Show solution
Plug in the assignment and update steps from problems 2 and 4:
distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)
assignments = np.argmin(distances, axis=1)
one_hot = np.eye(K)[assignments]
sums = one_hot.T @ X
counts = one_hot.sum(axis=0).reshape(K, 1)
new_centroids = sums / counts
You should see convergence in around 5–10 iterations and a final WCSS in the low hundreds. Try replacing the random initialisation with your k-means++ implementation from problem 3 — the answer should be at least as good and often noticeably better.
Next: Chapter 19 — Hierarchical and density-based clustering.