So, why is attention quadratic?

We hear that attention is unscalable because it scales quadratically. When I was a beginner in ML, I used to wonder what this was all about.

What is quadratic about it, and why? This simple explanation is aimed at technical folks outside AI who have some linear algebra knowledge. For simplicity, I'm assuming plain attention (no multi-head), batch size 1, and no KV caching (btw, rumor is Anthropic did not have KV caching in its early days).

*Figure: Self-attention (Q, K, V) inside a Transformer decoder block, with masking, scaling, and softmax.*

The diagram above shows where attention happens in the decoder block, along with the attention formula. Look at the Q @ Kᵀ part of the formula: Q, K, and V all have shape (sequence_length, hidden_dim), so when Q and Kᵀ are matmul'd, the result is a matrix of shape (sequence_length, sequence_length). There it is: a matrix whose size grows with the *square* of the sequence length.
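If it's easier to see in code, here's a minimal NumPy sketch of plain single-head attention (toy shapes of my choosing, no masking yet):

```python
import numpy as np

L, d = 6, 8  # sequence length, hidden dim (toy values)
rng = np.random.default_rng(0)

Q = rng.standard_normal((L, d))  # (L, d)
K = rng.standard_normal((L, d))  # (L, d)
V = rng.standard_normal((L, d))  # (L, d)

scores = Q @ K.T / np.sqrt(d)    # (L, L)  <-- the quadratic part
print(scores.shape)              # (6, 6)

# Softmax over each row, then use the weights to mix the values.
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ V                # back to (L, d)
```

Notice that only the intermediate `scores` (and `weights`) matrix is (L, L); the inputs and output stay (L, d).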

During inference, the sequence grows by only one token per step, yet both the time to compute the score matrix and the memory to store it grow as O(L²). So when the context length doubles, the attention cost quadruples.
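To make that concrete, here's a back-of-the-envelope count of multiply-adds for just the Q @ Kᵀ matmul (the hidden dim d = 64 is my assumption, a typical per-head size):

```python
d = 64  # assumed hidden dim
for L in (1_000, 2_000, 4_000):
    flops = L * L * d  # each of the L*L scores is a dot product of length d
    print(f"L={L}: ~{flops:,} multiply-adds")
# L=1000: ~64,000,000 multiply-adds
# L=2000: ~256,000,000 multiply-adds
# L=4000: ~1,024,000,000 multiply-adds
```

Doubling L quadruples the cost; that's the quadratic scaling everyone complains about.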

*Figure: Causal self-attention mask; each token attends only to itself and past tokens (lower-triangular).*
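And here's a sketch of that causal mask in NumPy: masked positions are set to -inf before the softmax, so they end up with zero attention weight. Note the mask itself is also (L, L), so it doesn't help with the quadratic memory.

```python
import numpy as np

L = 5
mask = np.tril(np.ones((L, L), dtype=bool))  # True = "may attend"
print(mask.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]

# Applied before softmax: exp(-inf) = 0, so future tokens get zero weight.
scores = np.random.randn(L, L)
scores = np.where(mask, scores, -np.inf)
```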

Hope this makes everything clear.