Chapter 18

k-means clustering

Finding structure in data without being told what to look for. The most widely used clustering algorithm, watched step by step.

Imagine you have data and no labels. Just points. A spreadsheet of customers with their purchase histories. A folder of unlabelled photographs. A pile of text documents.

You suspect there's structure in there — that customers fall into a few natural types, that the photographs sort themselves into visual groups, that the documents are about a handful of topics. But no one has told you what the groups are. You have to discover them.

This is clustering: the problem of finding structure in data without being told what to look for. And the most widely used clustering algorithm, the one nearly every textbook starts with, is k-means.

Two reasons to start here. First, k-means is genuinely useful — it runs in production at companies you have heard of, doing things like customer segmentation, image compression, and anomaly detection. Second, k-means is one of the most beautifully visualisable algorithms in all of machine learning. By the end of this chapter you will have watched it converge, iteration by iteration, on three different datasets, and you will have seen exactly how and why it sometimes fails.

§1 What is clustering?

Clustering is unsupervised learning. Recall that in supervised learning (Chapters 4 through 17), we had labelled examples: each input came with the correct output, and our job was to learn the function that maps from one to the other. In unsupervised learning there are no labels. We only have the inputs. The goal is to discover structure in those inputs alone.

The structure we are looking for, in this chapter, is groups. The intuition is simple. A good clustering is one where points in the same group look similar, and points in different groups look different. "Similar" usually means close together in some space. So clustering reduces to: given a cloud of points, partition them into groups such that nearby points end up in the same group and faraway points end up in different ones.

Real applications include the ones mentioned above: customer segmentation, image compression, and anomaly detection. What unites them is that we do not know the groups ahead of time. The algorithm finds them.

§2 The k-means algorithm

There are many ways to cluster. k-means is one. Its strategy is almost embarrassingly direct.

We start by picking a number, $K$, which is how many groups we expect to find. Choosing $K$ is its own problem; we come back to it in §6. Then we make a guess: "the centres of the $K$ groups are at these $K$ positions." Probably a terrible guess. We refine it.

The refinement loop has two steps:

  1. Assignment. Look at every point. Assign it to whichever centre is closest.
  2. Update. Each centre moves to the average position of the points assigned to it.

Then repeat. And repeat. And repeat. Until nothing changes.

That is the algorithm. In pseudocode:

choose K, initialise K centroids
repeat:
    assign each point to the nearest centroid
    move each centroid to the mean of its assigned points
until no centroid moves
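If you prefer code to pseudocode, the loop might look like this as a NumPy function. This is a sketch and not the chapter's reference implementation (that comes in §10). The function name kmeans, the max_iter and seed parameters, and the rule that an empty cluster keeps its previous centroid are our own choices. The sketch stops when no assignment changes. Because centroids are recomputed from the assignments, that condition is equivalent to "no centroid moves".

```python
import numpy as np

def kmeans(X, K, max_iter=100, seed=0):
    """Plain k-means following the pseudocode: assign, update, repeat."""
    rng = np.random.default_rng(seed)
    # Initialise: K distinct data points chosen at random.
    centroids = X[rng.choice(len(X), K, replace=False)].astype(float)
    labels = None
    for it in range(1, max_iter + 1):
        # Assignment: index of the nearest centroid for every point.
        sq_dist = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = sq_dist.argmin(axis=1)
        # No assignment changed, so no centroid would move: stop.
        if labels is not None and np.array_equal(new_labels, labels):
            return centroids, labels, it      # it = number of assignment passes
        labels = new_labels
        # Update: move each centroid to the mean of its assigned points.
        for k in range(K):
            members = X[labels == k]
            if len(members) > 0:              # an empty cluster keeps its old centroid
                centroids[k] = members.mean(axis=0)
    return centroids, labels, max_iter

# Usage on a made-up 4-point dataset with two obvious pairs.
X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centroids, labels, n_iter = kmeans(X, K=2)
print(labels, centroids, n_iter)
```

On this toy dataset the two pairs always end up in separate clusters, whichever two points the seed picks as starting centroids. Only the cluster numbering can differ between seeds.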

It might not be obvious that this ever terminates, or that the answer is any good when it does. Both questions deserve proper answers, and §5 gives them. For now, look at what the algorithm actually does. Pick a dataset shape, choose $K$, and step through the iterations one by one.

Figure 18.1 — k-means convergence on synthetic data, with readouts for the iteration count, the WCSS (inertia) and the algorithm's status. Modify any control and the algorithm re-runs from scratch.

Play with it for a while. The rest of this chapter unpacks what you have just watched.

§3 The assignment step

The assignment step is the easy one. For each point $x_i$, we want to find the closest centroid $\mu_k$. "Closest" is measured by some distance function $d$. The most common choice is Euclidean distance:

$$d(x_i, \mu_k) = \sqrt{\sum_{j=1}^{p} (x_{ij} - \mu_{kj})^2}$$

where $p$ is the number of features of each point. For our 2D visualisation, $p = 2$. The assignment rule then picks the cluster whose centroid is closest:

$$c_i = \arg\min_{k \in \{1, \dots, K\}} \; d(x_i, \mu_k)$$

A small efficiency note. The square root is monotonically increasing on non-negative numbers: $d(a, b) < d(c, e)$ exactly when $d(a, b)^2 < d(c, e)^2$. When all we want is the argmin, that is, which centroid is closest, we can skip the square root and compare squared distances instead. It is faster, and the answer is the same.
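Here is a quick check of that claim on random made-up points and centroids. For every point, the argmin over squared distances picks the same centroid as the argmin over true Euclidean distances.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(1000, 2))          # 1000 made-up points, p = 2
centroids = rng.normal(size=(5, 2))     # K = 5 made-up centroids

diff = X[:, None, :] - centroids[None, :, :]   # shape (n, K, p)
sq_dist = (diff ** 2).sum(axis=2)              # squared distances, shape (n, K)
dist = np.sqrt(sq_dist)                        # Euclidean distances, shape (n, K)

# Same nearest centroid for every point, with or without the square root.
same = np.array_equal(sq_dist.argmin(axis=1), dist.argmin(axis=1))
print(same)
```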

The geometric consequence of this assignment rule is one of the most elegant structures in computational geometry. The plane (in general, the whole $p$-dimensional feature space) gets divided into regions, one per centroid, where each region contains exactly the points closest to its centroid. This is called a Voronoi tessellation, and in 2D it looks like the cells of a honeycomb. Every time you watch k-means run, you are watching the Voronoi cells reshape themselves.
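You can draw a coarse Voronoi tessellation without any geometry library. Assign every point of a grid to its nearest centroid, then print one letter per cell. The three centroid positions and the grid extent below are arbitrary values chosen for illustration.

```python
import numpy as np

# Three hypothetical centroids in the plane.
centroids = np.array([[-2.0, 1.0], [2.0, 1.5], [0.0, -2.0]])

# A coarse grid covering [-4, 4] x [-4, 4]; top row first, for printing.
xs = np.linspace(-4, 4, 33)
ys = np.linspace(4, -4, 17)
gx, gy = np.meshgrid(xs, ys)                    # both have shape (17, 33)
grid = np.column_stack([gx.ravel(), gy.ravel()])

# Assign every grid point to its nearest centroid (squared distance suffices).
sq = ((grid[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
cells = sq.argmin(axis=1).reshape(gy.shape)

# Each letter marks one Voronoi cell: A, B, C for centroids 0, 1, 2.
for row in cells:
    print("".join("ABC"[k] for k in row))
```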

§4 The update step

We have been hand-waving over why the centroid update is the mean. Time to earn it.

The update step says: each centroid moves to the average of the points currently assigned to it. But why the average, rather than the median, the centre of the bounding box, or whatever point is most central by some other measure? The answer lies in what k-means is minimising: WCSS, the within-cluster sum of squares (also called inertia), which is the total squared distance from every point to the centroid of its cluster. The mean is the unique position that minimises the sum of squared distances to a cluster's points. If we want WCSS to decrease, the mean is provably the best place to put each centroid.
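For reference, here is the objective written out. Let $C_k$ be the set of points assigned to cluster $k$, $\mu_k$ its centroid, and $c_i$ the cluster index of point $x_i$ as in §3. WCSS can then be written in two equivalent ways:

$$\mathrm{WCSS} = \sum_{k=1}^{K} \sum_{x_i \in C_k} \|x_i - \mu_k\|^2 = \sum_{i=1}^{n} \|x_i - \mu_{c_i}\|^2$$

Each centroid $\mu_k$ appears in only one inner sum. Minimising WCSS over the centroids therefore splits into $K$ independent problems, one per cluster, and that is the problem the derivation below solves.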

The derivation is short. Suppose cluster $C_k$ contains points $x_1, \dots, x_m$. We want to choose a centroid position $\mu$ that minimises:

$$L(\mu) = \sum_{i=1}^{m} \|x_i - \mu\|^2$$

Take the gradient with respect to $\mu$ and set it to zero:

$$\nabla_\mu L = -2 \sum_{i=1}^{m} (x_i - \mu) = 0$$

which solves to:

$$\mu = \frac{1}{m} \sum_{i=1}^{m} x_i$$

The mean. The Hessian is $2mI$, which is positive definite everywhere, so $L$ is strictly convex and this is the unique global minimum, not a maximum or saddle point.
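The derivation can also be checked numerically. Write $\bar{x}$ for the mean. Expanding $L$ around it gives $L(c) = L(\bar{x}) + m\|c - \bar{x}\|^2$ for any point $c$, because the cross term vanishes (the $x_i - \bar{x}$ sum to zero). The sketch below runs on one made-up Gaussian cluster. It confirms that random candidate centres all do worse than the mean, and that the excess is exactly $m$ times the squared distance from the mean.

```python
import numpy as np

def sum_sq(points, c):
    """L(c): total squared distance from every point to a candidate centre c."""
    return ((points - c) ** 2).sum()

rng = np.random.default_rng(0)
pts = rng.normal(loc=[3.0, -1.0], scale=0.8, size=(50, 2))   # one made-up cluster, m = 50
m = len(pts)
mean = pts.mean(axis=0)
best = sum_sq(pts, mean)

# Random candidate centres: each does worse than the mean, by exactly m * ||c - mean||^2.
for c in rng.normal(loc=mean, scale=2.0, size=(5, 2)):
    excess = sum_sq(pts, c) - best
    predicted = m * ((c - mean) ** 2).sum()
    print(f"excess over minimum {excess:9.4f}   m * ||c - mean||^2 {predicted:9.4f}")
```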

You do not have to take the math on faith. The widget below shows the same fact experimentally. Drag the centroid around a single cluster of points; the running WCSS value updates live. The marked position is the mean. The minimum WCSS happens exactly there — and the further you drag the centroid from the mean, the worse WCSS gets, in proportion to the squared distance you have moved away.

Figure 18.2 — WCSS as a function of centroid position for a single cluster. Drag the centroid and the readouts show the WCSS at its current position, the WCSS at the mean, and their ratio. The dashed rings are contours of the WCSS surface: concentric circles centred on the mean, at 1.5×, 2.5× and 4× the minimum WCSS.

This is one of those rare moments where the theory and the experiment align so cleanly that they feel almost trivial. They are not trivial, though. The same logic generalises directly to weighted k-means, to fuzzy c-means, to mixture models, and to the M-step of EM. All of them minimise some (possibly weighted) sum of squared distances, and the answer at every step is some kind of (possibly weighted) mean.

§5 Why k-means converges

Two questions deserve answers before we trust this algorithm. First: does it always terminate, or could it run forever? Second: when it terminates, is the answer any good?

The first question has a clean answer. Look at what WCSS does over the course of the algorithm.

When the algorithm starts, WCSS has some value. Each assignment step picks the closest centroid for each point, and "closest" is exactly the choice that minimises that point's contribution to WCSS, given the current centroids. So the assignment step can only lower WCSS or leave it unchanged. Then the update step moves each centroid to the mean of its assigned points, which, by §4, is exactly the position that minimises WCSS given the current assignments. So the update step, too, can only lower WCSS or leave it unchanged.

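Both steps are non-increasing, and WCSS is bounded below by zero, so the WCSS sequence converges. That doesn't immediately mean the algorithm stops. But there are only finitely many ways to assign $n$ points to $K$ clusters ($K^n$ of them, in fact), and after each update step the WCSS is fully determined by the current assignment. Every iteration that strictly lowers WCSS must therefore have moved to an assignment never visited before, so strict decreases can happen only finitely many times. Suppose an iteration fails to lower WCSS, and ties are broken consistently (for example, a point keeps its current centroid when two are equally close). Then no assignment changes and no centroid moves: the algorithm has reached a fixed point. We terminate.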

Watch this empirically:

Figure 18.3 — Clustering state alongside its WCSS, plotted against iteration. Each iteration the WCSS goes down, never up.

Two things to internalise from this view. The WCSS curve only ever points down. And the first iteration does almost all the work — the curve drops steeply, then levels off, often converging in three to six iterations even on substantial datasets.
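A few lines of NumPy make the same point as Figure 18.3. The sketch below runs plain k-means from random starting points and records WCSS after every update step. The data is made up (three Gaussian blobs), and the positions, spread, seed and 50-iteration cap are arbitrary. By the argument above, the recorded values can never increase.

```python
import numpy as np

rng = np.random.default_rng(7)
centres = np.array([[0.0, 0.0], [5.0, 0.0], [2.5, 4.0]])
X = np.vstack([c + rng.normal(scale=0.8, size=(100, 2)) for c in centres])

K = 3
centroids = X[rng.choice(len(X), K, replace=False)].copy()
history = []
for step in range(50):
    # Assignment step (squared distances are enough for the argmin).
    labels = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
    # Update step; an empty cluster (rare) keeps its previous centroid.
    new = np.array([X[labels == k].mean(axis=0) if np.any(labels == k) else centroids[k]
                    for k in range(K)])
    history.append(((X - new[labels]) ** 2).sum())   # WCSS after this update
    if np.allclose(new, centroids):
        break
    centroids = new

for i, w in enumerate(history, start=1):
    print(f"iteration {i}: WCSS = {w:.2f}")
```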

That answers the "does it terminate?" question. The "is the answer any good?" question is more subtle, and the answer is: locally good, not globally good. k-means is doing coordinate descent on a non-convex objective. It alternates between optimising the assignments with the centroids fixed and optimising the centroids with the assignments fixed. Coordinate descent settles into local minima, and this objective has many of them: different starting points lead to different fixed points. We come back to this in §7.

§6 Choosing K

So far we have assumed someone told us $K$. In practice, nobody does. Choosing $K$ is one of the trickiest parts of using k-means well, and there is no clean automatic answer. Two widely used heuristics are the elbow method and the silhouette score.

The elbow method. Run k-means with $K = 1, 2, 3, \dots$ and plot the final WCSS as a function of $K$. The curve goes down, because more clusters can always fit the data at least as tightly. Strictly, the best achievable WCSS never increases when you add a cluster; an individual run can break this if it gets stuck in a poor local minimum. But the curve usually has an elbow: a point where the rate of decrease drops sharply. The intuition is that before the elbow, each additional cluster captures genuine structure; after the elbow, each additional cluster is just splitting hairs. Pick $K$ at the elbow. This method is folk-wisdom grade: it works when there is an obvious elbow and fails when there isn't.
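Here is a minimal elbow sweep in NumPy. Several choices are illustrative assumptions: the data (three made-up Gaussian blobs), the helper name run_kmeans, and keeping the best of five random starts for each $K$ to soften the local-minimum problem discussed in §7. Printing WCSS for $K = 1$ to $6$ on data like this, you would expect a sharp bend at $K = 3$.

```python
import numpy as np

def run_kmeans(X, K, rng, iters=100):
    """One k-means run from K random data points; returns the final WCSS."""
    C = X[rng.choice(len(X), K, replace=False)].copy()
    for _ in range(iters):
        labels = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
        new = np.array([X[labels == k].mean(axis=0) if np.any(labels == k) else C[k]
                        for k in range(K)])
        if np.allclose(new, C):
            break
        C = new
    labels = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
    return ((X - C[labels]) ** 2).sum()

rng = np.random.default_rng(3)
centres = np.array([[0.0, 0.0], [6.0, 0.0], [3.0, 5.0]])   # three made-up blobs
X = np.vstack([c + rng.normal(scale=0.7, size=(80, 2)) for c in centres])

# Best of five random starts per K.
wcss = {K: min(run_kmeans(X, K, rng) for _ in range(5)) for K in range(1, 7)}
for K, w in wcss.items():
    print(f"K = {K}: WCSS = {w:9.2f}")
```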

The silhouette score. For each point $i$, define $a(i)$ as the mean distance from $i$ to the other points in its own cluster. Define $b(i)$ as the mean distance from $i$ to the points of the nearest other cluster, that is, the smallest such mean over all the clusters $i$ does not belong to. The silhouette of point $i$ is:

$$s(i) = \frac{b(i) - a(i)}{\max(a(i), b(i))}$$

This sits between $-1$ and $1$. Close to $1$ means $i$ is much closer to its own cluster than to any other: well clustered. Close to $0$ means $i$ sits on the boundary between two clusters. Negative means $i$ is, on average, closer to another cluster than to its own, which usually means it has been assigned to the wrong cluster. The overall silhouette score is the average over all points.
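The definition translates directly into code. The sketch below computes the mean silhouette from a full pairwise distance matrix. That costs $O(n^2)$ memory, so it is only meant for small examples. Two details are our own assumptions: the function name, and the convention that a point alone in its cluster scores $s(i) = 0$. The second is a common convention, since $a(i)$ is undefined for such a point.

```python
import numpy as np

def silhouette(X, labels):
    """Mean silhouette over all points, straight from the definition of s(i)."""
    clusters = np.unique(labels)
    if len(clusters) < 2:
        raise ValueError("the silhouette needs at least two clusters")
    D = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2))  # (n, n) distances
    s = np.zeros(len(X))
    for i in range(len(X)):
        own = labels == labels[i]
        if own.sum() == 1:
            continue                                 # singleton: s(i) = 0 by convention
        a = D[i, own].sum() / (own.sum() - 1)        # mean distance to the *other* members
        b = min(D[i, labels == c].mean() for c in clusters if c != labels[i])
        s[i] = (b - a) / max(a, b)
    return s.mean()

X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
print(silhouette(X, np.array([0, 0, 1, 1])))   # about 0.90: each close pair is one cluster
print(silhouette(X, np.array([0, 1, 0, 1])))   # about -0.45: each point grouped with a far-away partner
```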

Unlike WCSS, the silhouette has a peak, not an elbow. Plot it as a function of $K$ and pick the $K$ that maximises it. (It is undefined for $K = 1$, because $b(i)$ needs at least one other cluster.)

Figure 18.4 — Two panels, each for $K$ from 2 to 8: WCSS by $K$ (lower is tighter) and silhouette by $K$ (higher is better separated). As you sweep $K$, the WCSS falls monotonically while the silhouette peaks. The peak (or the elbow) is your best $K$.

The interactive figure lets you sweep $K$ from 2 to 8 across three datasets. On "three blobs", with true $K = 3$, both curves agree clearly: WCSS has a sharp elbow at $K = 3$ and the silhouette peaks at $K = 3$. On "two moons" the methods disagree, and neither is fully trustworthy. That is a sign that something is wrong with k-means' assumptions for this data, a topic we return to in §8.

The most honest summary: when these heuristics agree, trust them. When they disagree, k-means may not be the right algorithm for your data.

§7 Initialisation matters

We saw in §5 that k-means converges to a local minimum, and the local minimum it lands in depends on where the algorithm started. Run k-means twice on the same data with the same $K$ but different starting centroids: sometimes you get the same answer, often you do not. This is a real problem. A bad initialisation leads to a bad clustering, and there is no way to tell from inside the algorithm that something has gone wrong.

The simplest initialisation strategy is to pick $K$ points at random from the dataset and use them as starting centroids. This is the "random" option in our widgets. It works most of the time, but every so often you sample two starting points from the same true cluster, and the algorithm often cannot recover. That true cluster ends up split between two centroids, while two other true clusters get merged under a single centroid.
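To see an unrecoverable start without relying on luck, the sketch below fixes the starting points by hand instead of sampling them. It puts two in the first of three tight, well-separated blobs and one in the second, which is exactly the kind of draw random initialisation occasionally produces. The 12-point dataset and the helper kmeans_from are made up for illustration.

```python
import numpy as np

def kmeans_from(X, start, iters=100):
    """Plain k-means from the given starting centroids; returns (labels, WCSS)."""
    C = np.array(start, dtype=float)
    for _ in range(iters):
        labels = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
        new = np.array([X[labels == k].mean(axis=0) if np.any(labels == k) else C[k]
                        for k in range(len(C))])
        if np.allclose(new, C):
            break
        C = new
    labels = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
    return labels, ((X - C[labels]) ** 2).sum()

offsets = np.array([[-0.5, 0.0], [0.5, 0.0], [0.0, -0.5], [0.0, 0.5]])
true_centres = np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]])
X = np.vstack([c + offsets for c in true_centres])   # 12 points, 3 tight blobs

# Unlucky draw: two starting points in the first blob, one in the second.
_, bad = kmeans_from(X, [X[0], X[1], X[4]])
# Lucky draw: one starting point in each blob.
_, good = kmeans_from(X, [X[0], X[4], X[8]])
print(f"unlucky start: WCSS = {bad:.3f}")   # 202.667
print(f"lucky start:   WCSS = {good:.3f}")  # 3.000
```

The unlucky run still converges, because the assignments stop changing. But its final WCSS is far higher than the lucky run's: the first blob stays split between two centroids while the other two blobs share one.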

A much better strategy is k-means++, introduced by Arthur and Vassilvitskii in 2007. The idea: pick the first centroid uniformly at random. Then for each subsequent centroid, sample with probability proportional to the squared distance from the nearest centroid already chosen. New centroids are more likely to land far from existing ones, which makes it much harder to accidentally put two centroids in the same true cluster.

In pseudocode:

choose c_1 uniformly at random from X
for each point x:
    d(x) = squared distance from x to c_1
for k = 2, ..., K:
    sample c_k from X with probability d(x) / sum(d)
    for each point x:
        d(x) = min(d(x), squared distance from x to c_k)

That is the entire algorithm. It costs a single pass through the data for each new centroid, so $O(nK)$ time. In return you get dramatically more reliable convergence.
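Here is a NumPy sketch of k-means++ seeding. It assumes X is an (n, p) array with at least K distinct points and that rng is a numpy.random.Generator. The function name kmeans_pp_init and the blob data are hypothetical. A naive version recomputes each point's distance to every chosen centroid at each round. This sketch instead keeps $d(x)$ up to date with np.minimum, so each new centroid costs one pass over the data, matching the $O(nK)$ figure above.

```python
import numpy as np

def kmeans_pp_init(X, K, rng):
    """k-means++ seeding, keeping each point's nearest squared distance up to date."""
    n = len(X)
    centroids = [X[rng.integers(n)]]                    # c_1: uniform over X
    d = ((X - centroids[0]) ** 2).sum(axis=1)           # d(x) for every point
    for _ in range(1, K):
        idx = rng.choice(n, p=d / d.sum())              # P(x) proportional to d(x)
        centroids.append(X[idx])
        # One pass over the data: only the newest centroid can lower d(x).
        d = np.minimum(d, ((X - X[idx]) ** 2).sum(axis=1))
    return np.array(centroids)

rng = np.random.default_rng(0)
true_centres = np.array([[3.0, 3.0], [-3.0, 3.0], [0.0, -3.0]])
X = np.vstack([c + rng.normal(scale=0.5, size=(40, 2)) for c in true_centres])
print(kmeans_pp_init(X, 3, rng))
```

Once a point is chosen, its $d(x)$ drops to zero, so it can never be drawn again. The $K$ starting centroids are therefore always distinct data points.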

Compare them side by side. Press Play both and watch the two methods converge on the same data. Press New seed a few times to see how much variance the random method actually has.

Figure 18.5 — Same data, same $K$, two different initialisations side by side: random initialisation and k-means++. Press New seed a few times: random sometimes finds the right answer, sometimes does not. k-means++ is far more reliable.

On the anisotropic dataset in particular, random initialisation will fairly often find a noticeably worse answer than k-means++. In production code, always use k-means++. It is the default in sklearn.cluster.KMeans, which is why you may not have noticed it exists.

§8 Where k-means fails

So far k-means has looked excellent. It always converges, it has reasonable heuristics for choosing $K$, and with k-means++ it is usually close to optimal. Time to be honest about what it cannot do.

The fundamental assumption baked into k-means is that clusters are isotropic and convex: roughly speaking, spherical blobs of similar size. The algorithm looks for the partition of the plane that minimises within-cluster squared distance. That partition is made of Voronoi cells, which are convex polygons. If your true clusters are not convex blobs, k-means cannot find them. No amount of running longer, picking a better $K$, or initialising better will fix this; the assumption is wrong at the root.
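For $K = 2$ the convexity is especially easy to see. Expanding the squares, $\|x - \mu_1\|^2 < \|x - \mu_0\|^2$ holds exactly when $2(\mu_1 - \mu_0) \cdot x > \|\mu_1\|^2 - \|\mu_0\|^2$. That is a linear inequality, so the two Voronoi cells are half-planes separated by a straight line. The sketch below runs plain k-means with $K = 2$ on made-up two-moons data; the moon parametrisation, noise level and starting points are our own choices. It confirms that a single linear test reproduces every label.

```python
import numpy as np

rng = np.random.default_rng(0)
t = rng.uniform(0, np.pi, size=200)
upper = np.column_stack([np.cos(t), np.sin(t)])             # upper moon
lower = np.column_stack([1 - np.cos(t), 0.5 - np.sin(t)])   # lower moon, interleaved
X = np.vstack([upper, lower]) + rng.normal(scale=0.05, size=(400, 2))
true = np.repeat([0, 1], 200)                               # which moon each point came from

# Plain k-means with K = 2, started from one point of each moon.
C = X[[0, 200]].copy()
for _ in range(100):
    labels = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
    new = np.array([X[labels == k].mean(axis=0) for k in range(2)])
    if np.allclose(new, C):
        break
    C = new
labels = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)

# Closer to C[1] exactly when x . 2(C1 - C0) > ||C1||^2 - ||C0||^2: a straight line.
w = 2 * (C[1] - C[0])
b = (C[1] ** 2).sum() - (C[0] ** 2).sum()
linear = (X @ w > b).astype(int)
print("half-plane test reproduces every label:", np.array_equal(linear, labels))

agreement = max((labels == true).mean(), (labels != true).mean())
print(f"agreement with the true moons: {agreement:.0%}")
```

Whatever agreement it reports, no straight line can separate two interleaved crescents, so some points of each moon must land on the wrong side.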

Three concrete failure modes you should recognise.

Non-convex clusters. The two-moons dataset is the canonical example: two long curved bands, interleaved. k-means with $K = 2$ slices straight through both bands, putting half of each moon in each cluster. It is doing exactly what we asked, minimising within-cluster squared distance subject to the partition being convex, and the answer is wrong. Chapter 19's density-based methods (DBSCAN, OPTICS) solve this by abandoning the convexity assumption entirely.

Unequal cluster spreads. k-means implicitly assumes clusters have similar variances. When one cluster is much wider than another, the wider cluster's WCSS dominates the objective, and the algorithm will sometimes split the wide cluster into multiple smaller ones while combining smaller true clusters into one. Gaussian mixture models (Chapter 22's optional extensions) handle this by letting each cluster have its own covariance.

Outliers. Because WCSS uses squared distance, a single outlier can pull a centroid far from the rest of its cluster. There are robust variants — k-medoids, k-medians — that use absolute distance or actual data points as centres, but vanilla k-means is sensitive.
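Here is a tiny made-up example of the outlier effect. The coordinate-wise median printed last is the kind of centre k-medians uses. The four-point cluster and the outlier at (100, 100) are invented for illustration.

```python
import numpy as np

cluster = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
with_outlier = np.vstack([cluster, [[100.0, 100.0]]])

print("mean without outlier:", cluster.mean(axis=0))             # [0.5 0.5]
print("mean with outlier:   ", with_outlier.mean(axis=0))        # [20.4 20.4]
print("median with outlier: ", np.median(with_outlier, axis=0))  # [1. 1.]
```

A single far-away point drags the mean about forty times further than the median moves.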

If you take one thing from this section: every clustering algorithm encodes an assumption about what counts as a cluster. k-means assumes blobs. When the data is blobs, it is excellent. When the data is anything else, you need a different algorithm.

§9 Complexity

The cost of one full k-means run is straightforward to count.

The assignment step computes the distance from each of $n$ points to each of $K$ centroids in $p$-dimensional space, which is $n \times K \times p$ multiplications and additions. The update step is one pass through all $n$ points to accumulate sums, so $n \times p$ operations. Together, one iteration costs $O(n \cdot K \cdot p)$.

The total cost is $O(n \cdot K \cdot p \cdot i)$, where $i$ is the number of iterations until convergence. In practice $i$ is very small (usually less than 20, often less than 10), so k-means is effectively linear in $n$ once you fix the other parameters. That is why it scales so well.

Memory cost is $O(np + Kp)$: the data plus the centroids. The frame-by-frame history we have been visualising costs an extra factor of $i$, but that is a teaching aid, not a production concern.
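The count above hides one memory detail. The broadcasting shortcut used in §10, X[:, np.newaxis] - centroids, materialises an $n \times K \times p$ temporary array. A common alternative expands $\|x - \mu\|^2 = \|x\|^2 - 2\,x \cdot \mu + \|\mu\|^2$, so the only large intermediate is the $n \times K$ matrix product. The sketch below checks that both approaches give the same distances; the function name is ours and the random data is made up. The expansion can lose a little precision through cancellation, which is why tiny negative results are clipped to zero.

```python
import numpy as np

def sq_distances(X, C):
    """(n, K) squared Euclidean distances via ||x||^2 - 2 x.mu + ||mu||^2."""
    x_sq = (X ** 2).sum(axis=1)[:, None]        # (n, 1)
    c_sq = (C ** 2).sum(axis=1)[None, :]        # (1, K)
    d = x_sq - 2.0 * (X @ C.T) + c_sq           # (n, K); no (n, K, p) temporary
    return np.maximum(d, 0.0)                   # clip tiny negatives from rounding

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 16))
C = rng.normal(size=(8, 16))
broadcast = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2)   # (n, K, p) temporary
print(np.allclose(sq_distances(X, C), broadcast))
```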

Two practical implications. First: k-means is one of the few unsupervised algorithms that runs comfortably on truly large datasets. Tens of millions of points are routine on a laptop; billions are routine in a distributed setting. Second: it is much cheaper than methods that work from all pairwise distances, which cost $O(n^2)$. Examples are standard agglomerative hierarchical clustering, and DBSCAN without a spatial index. When $n$ is large and your data is roughly blob-shaped, k-means is hard to beat.

§10 Implementing it yourself

It is one thing to watch an algorithm. It is another to write it. The implementation below is the entire k-means algorithm in about a dozen lines of NumPy: no sklearn, no other libraries, just arrays and arithmetic. Click Run and watch it converge. Then change something: the value of $K$, the random seed, the number of points in each cluster. The Python is genuinely running in your browser. The first run takes a few seconds to download the runtime, and every run after that is instant.

Python · runs in browser
import numpy as np

# 1. Make three Gaussian blobs of 50 points each.
np.random.seed(42)
true_centres = np.array([[2.0, 2.0], [-2.0, 2.0], [0.0, -2.0]])
X = np.vstack([
  true_centres[k] + np.random.randn(50, 2) * 0.5
  for k in range(3)
])

# 2. Initialise K centroids by sampling K points from X.
K = 3
rng = np.random.default_rng(0)
centroids = X[rng.choice(len(X), K, replace=False)].copy()

# 3. Iterate: assign, update, repeat.
for step in range(20):
  # Assignment step: each point goes to its nearest centroid.
  distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)
  assignments = np.argmin(distances, axis=1)

  # Update step: each centroid moves to the mean of its assigned points.
  new_centroids = np.array([
      X[assignments == k].mean(axis=0)
      for k in range(K)
  ])

  # Convergence check.
  if np.allclose(new_centroids, centroids):
      print(f"Converged after {step + 1} iterations.\n")
      break
  centroids = new_centroids

print("Final centroids:")
for k, c in enumerate(centroids):
  print(f"  cluster {k}: ({c[0]:+.3f}, {c[1]:+.3f})")
print(f"\nTrue centres were:")
for k, c in enumerate(true_centres):
  print(f"  cluster {k}: ({c[0]:+.3f}, {c[1]:+.3f})")
Listing 18.1 — k-means from scratch in NumPy.

Two things worth noticing in this implementation.

First, the assignment step is one line. X[:, np.newaxis] - centroids uses broadcasting to compute the difference between every point and every centroid at once — the result has shape (n_points, K, 2). np.linalg.norm(..., axis=2) collapses the last axis (the 2D feature dimension) into a single distance, giving an (n_points, K) matrix of distances. np.argmin(..., axis=1) then picks the closest centroid per point. The whole vectorised step does the same work as a nested loop, far faster.
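To see the shapes described above, here is the same three-step pipeline on deliberately tiny made-up arrays (five points, three centroids), printing each intermediate shape.

```python
import numpy as np

X = np.zeros((5, 2))           # 5 points in 2D
centroids = np.ones((3, 2))    # K = 3 centroids

diff = X[:, np.newaxis] - centroids           # (5, 1, 2) - (3, 2) broadcasts to (5, 3, 2)
distances = np.linalg.norm(diff, axis=2)      # collapse the feature axis: (5, 3)
assignments = np.argmin(distances, axis=1)    # nearest centroid per point: (5,)
print(X[:, np.newaxis].shape, diff.shape, distances.shape, assignments.shape)
# (5, 1, 2) (5, 3, 2) (5, 3) (5,)
```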

Second, the recovered centroids will not match the true centres exactly, and they shouldn't: the data is noisy, and the algorithm only knows about the noisy version. What is striking is how close it gets. Try increasing the noise (change * 0.5 to * 1.2 in the data generation) and watch the recovered centroids drift further from the truth. Then try decreasing it (change * 0.5 to * 0.1) and watch them snap back. This is your first encounter with a recurring theme in unsupervised learning: the algorithm can only ever recover the structure that is actually present in the data.

§11 Problems

Five problems, ordered from conceptual through implementation. The first is think-only; the rest run in the editor. Worked solutions are hidden behind the Show solution toggles — try each problem first before peeking.

Problem 1 — Why not just set K = n?

Imagine you set $K$ equal to $n$, the number of data points. What does WCSS evaluate to at convergence? What does that tell you about using WCSS alone to choose $K$?

Show solution

When $K = n$, each point becomes its own cluster, and each cluster's mean is the single point inside it. The distance from every point to its centroid is zero, so WCSS is exactly zero.

This is the lowest WCSS achievable on any dataset, and it is reached by an algorithm that has learned nothing about the data's structure. The "clustering" just memorises the points. This is why WCSS alone cannot tell us the right $K$: bigger $K$ is always "better" in WCSS terms, so optimising WCSS without any other constraint sends you straight to $K = n$. Heuristics like the elbow method and the silhouette score exist precisely to break this degeneracy. They reward tight clusters while penalising excessive splitting.

The deeper lesson generalises: any loss function that strictly improves as you add capacity needs an external constraint (a held-out test set, a penalty term, a heuristic) to choose how much capacity to use. We will see this theme again in regression with regularisation, in trees with depth, in neural networks with everything.

Problem 2 — Implement the assignment step

Given a dataset $X$ and a set of centroids, compute the cluster assignment for every point. Aim for a vectorised solution, with no Python for loop over points.

Python · runs in browser
import numpy as np

np.random.seed(42)
X = np.vstack([
  np.array([2, 2]) + np.random.randn(20, 2) * 0.5,
  np.array([-2, 2]) + np.random.randn(20, 2) * 0.5,
  np.array([0, -2]) + np.random.randn(20, 2) * 0.5,
])
centroids = np.array([[1.8, 1.9], [-2.1, 2.2], [0.1, -1.7]])

# YOUR CODE: compute an array `assignments` of shape (60,) where
# assignments[i] is the index of the centroid closest to X[i].
assignments = ...  # TODO

# Sanity check: how many points went to each cluster?
unique, counts = np.unique(assignments, return_counts=True)
for k, c in zip(unique, counts):
  print(f"cluster {k}: {c} points")
Show solution

Broadcasting subtracts every centroid from every point, then we take the norm and the argmin:

distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)
assignments = np.argmin(distances, axis=1)

You should see roughly twenty points in each cluster.

Problem 3 — Implement k-means++ initialisation

Pick $K = 3$ starting centroids from $X$ using the k-means++ rule from §7: the first uniformly at random, each subsequent one sampled with probability proportional to its squared distance from the nearest existing centroid.

Python · runs in browser
import numpy as np

np.random.seed(42)
true_centres = np.array([[3, 3], [-3, 3], [0, -3]])
X = np.vstack([c + np.random.randn(40, 2) * 0.5 for c in true_centres])

K = 3
rng = np.random.default_rng(0)

# Implement k-means++:
# 1. Pick the first centroid uniformly from X.
# 2. For each of the next K - 1 centroids, sample with probability
#    proportional to squared distance from the nearest existing centroid.
centroids = []  # TODO: fill with K starting points

for i, c in enumerate(centroids):
  print(f"centroid {i}: ({c[0]:+.2f}, {c[1]:+.2f})")
Show solution

Track existing centroids, recompute squared distances each round, and sample with rng.choice weighted by those distances:

centroids = [X[rng.integers(len(X))]]
for _ in range(K - 1):
    diffs = X[:, None, :] - np.array(centroids)[None, :, :]
    sq = (diffs ** 2).sum(axis=2)
    nearest_sq = sq.min(axis=1)
    probs = nearest_sq / nearest_sq.sum()
    centroids.append(X[rng.choice(len(X), p=probs)])

Compared to picking three random points, k-means++ will, with high probability, place one starting centroid in each of the three true clusters.

Problem 4 — Vectorise the update step

The from-scratch implementation in §10 has a Python loop over $K$ in the update step. Replace it with a single vectorised expression. Hint: build a one-hot matrix of assignments and use a matrix multiply.

Python · runs in browser
import numpy as np

np.random.seed(42)
X = np.random.randn(100, 2)
K = 4
assignments = np.random.randint(0, K, size=100)

# Slow reference (do not change):
slow = np.array([X[assignments == k].mean(axis=0) for k in range(K)])

# YOUR CODE: produce `fast` with the same values, no Python loop over K.
fast = ...  # TODO

print("match:", np.allclose(slow, fast))
print("slow:")
print(slow)
print("fast:")
print(fast)
Show solution

The one-hot matrix has shape (n, K); multiplying its transpose by X gives unnormalised sums, and dividing by the cluster sizes gives means:

one_hot = np.eye(K)[assignments]          # (n, K)
sums = one_hot.T @ X                       # (K, p)
counts = one_hot.sum(axis=0).reshape(K, 1) # (K, 1)
fast = sums / counts

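This avoids the explicit loop over $K$ and is significantly faster for large $K$, because it pushes the whole thing into BLAS-backed matrix multiplication. Like the slow reference, it produces a row of NaNs (with a runtime warning) if some cluster has no points, since that row's count is zero.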

Problem 5 — Putting it together

Combine your work from problems 2 and 4 (and optionally 3) into a complete k-means implementation. Run it on an anisotropic dataset and report the WCSS at convergence.

Python · runs in browser
import numpy as np

# Anisotropic data: three elongated clusters.
np.random.seed(42)
centres = np.array([[2, 2], [-2, 2], [0, -2]])
R = np.array([[np.cos(0.6), -np.sin(0.6)], [np.sin(0.6), np.cos(0.6)]])
X = np.vstack([
  c + (np.random.randn(50, 2) * np.array([1.5, 0.3])) @ R.T
  for c in centres
])

K = 3
rng = np.random.default_rng(0)
centroids = X[rng.choice(len(X), K, replace=False)].copy()

for step in range(50):
  # YOUR CODE: assignment step (use the vectorised version from Problem 2).
  assignments = ...  # TODO

  # YOUR CODE: update step (use the vectorised version from Problem 4).
  new_centroids = ...  # TODO

  shift = np.linalg.norm(new_centroids - centroids, axis=1).max()
  if shift < 1e-4:
      print(f"Converged at iteration {step + 1}.")
      break
  centroids = new_centroids

wcss = ((X - centroids[assignments]) ** 2).sum()
print(f"Final WCSS: {wcss:.2f}")
print("Centroids:")
for k, c in enumerate(centroids):
  print(f"  {k}: ({c[0]:+.2f}, {c[1]:+.2f})")
Show solution

Plug in the assignment and update steps from problems 2 and 4:

distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)
assignments = np.argmin(distances, axis=1)

one_hot = np.eye(K)[assignments]
sums = one_hot.T @ X
counts = one_hot.sum(axis=0).reshape(K, 1)
new_centroids = sums / counts

You should see convergence in around 5–10 iterations and a final WCSS in the low hundreds. Try replacing the random initialisation with your k-means++ implementation from problem 3. The answer is usually at least as good, and often noticeably better.


Next: Chapter 19 — Hierarchical and density-based clustering.