Implement scaled dot-product attention in NumPy. Then extend it to multi-head attention. Walk me through the shapes at each step.
Your implementation computes each projection with its own matrix multiply (X @ W_q, and likewise for K and V). In PyTorch's actual nn.MultiheadAttention, the Q, K, V projections are often fused into one matrix when the embedding dims match. How would you implement that, and why might it be faster?
tldr
Attention = softmax(QK^T / sqrt(d_k)) @ V. The sqrt(d_k) scaling keeps the logits in a range where softmax stays well-behaved: for roughly unit-variance inputs, a d_k-dimensional dot product has variance about d_k, so dividing by sqrt(d_k) normalizes it back toward unit variance and prevents softmax from saturating. Multi-head splits the representation into num_heads subspaces, runs attention independently in each, then concatenates, which lets different heads learn different relationship patterns. Shapes: after splitting, the stacked tensor is (num_heads, seq_len, d_k) with d_k = d_model / num_heads, so each head sees a (seq_len, d_k) slice. Fused QKV projection merges three matmuls into one larger one, trading three small kernel launches for a single big GEMM that uses the GPU better.
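A minimal NumPy sketch of both pieces, with the shapes annotated at each step (the weight and function names are illustrative, not from any particular library):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    # Q: (..., seq_len_q, d_k), K: (..., seq_len_k, d_k), V: (..., seq_len_k, d_v)
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)  # (..., seq_len_q, seq_len_k)
    weights = softmax(scores, axis=-1)                  # each row sums to 1
    return weights @ V                                  # (..., seq_len_q, d_v)

def multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads):
    # X: (seq_len, d_model); each W_*: (d_model, d_model)
    seq_len, d_model = X.shape
    d_k = d_model // num_heads

    def split_heads(M):
        # (seq_len, d_model) -> (num_heads, seq_len, d_k)
        return M.reshape(seq_len, num_heads, d_k).transpose(1, 0, 2)

    Q, K, V = split_heads(X @ W_q), split_heads(X @ W_k), split_heads(X @ W_v)
    heads = scaled_dot_product_attention(Q, K, V)       # (num_heads, seq_len, d_k)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o                                 # (seq_len, d_model)

# Quick shape check with random weights:
rng = np.random.default_rng(0)
X = rng.standard_normal((10, 64))                  # seq_len=10, d_model=64
W = [rng.standard_normal((64, 64)) for _ in range(4)]
out = multi_head_attention(X, *W, num_heads=8)     # -> (10, 64)
```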
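And a sketch of the fused projection: store the three weight matrices side by side, do one matmul, then slice. PyTorch does something similar, keeping a single in_proj_weight when the Q/K/V dims match; one large GEMM means one kernel launch instead of three and better hardware utilization:

```python
def fused_qkv(X, W_qkv):
    # W_qkv: (d_model, 3 * d_model), the three projection matrices
    # concatenated, so one big matmul replaces three smaller ones.
    QKV = X @ W_qkv                        # (seq_len, 3 * d_model)
    Q, K, V = np.split(QKV, 3, axis=-1)    # three (seq_len, d_model) slices
    return Q, K, V
```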
follow-up
- FlashAttention is described as "IO-aware" attention — what memory bottleneck is it solving, and how does tiling help? (A tiled online-softmax sketch follows this list.)
- How would you modify this implementation to support cross-attention (Q from one sequence, K and V from another), as used in encoder-decoder transformers? (See the cross-attention sketch below.)
- Implement positional encoding in NumPy — both the original sinusoidal version and RoPE (rotary position embeddings). (See the sketch at the end.)
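On the FlashAttention question: standard attention materializes the full (seq_len, seq_len) score matrix in slow HBM, and shuttling that matrix back and forth dominates the runtime. Tiling processes K/V a block at a time in fast on-chip memory, maintaining running softmax statistics so the full matrix never exists. A NumPy sketch of that online-softmax idea (2-D inputs only; the tile size and names are illustrative, and this shows the algorithm, not FlashAttention's actual kernel):

```python
def attention_tiled(Q, K, V, tile=64):
    # Never builds the (seq_len_q, seq_len_k) score matrix: for each K/V tile,
    # update a running row max, running denominator, and running weighted sum.
    d_k = Q.shape[-1]
    m = np.full(Q.shape[0], -np.inf)            # running row max of scores
    l = np.zeros(Q.shape[0])                    # running softmax denominator
    acc = np.zeros((Q.shape[0], V.shape[-1]))   # running weighted sum of V
    for s in range(0, K.shape[0], tile):
        S = Q @ K[s:s+tile].T / np.sqrt(d_k)    # (seq_len_q, tile)
        m_new = np.maximum(m, S.max(axis=-1))
        scale = np.exp(m - m_new)               # rescale old stats to new max
        P = np.exp(S - m_new[:, None])
        l = l * scale + P.sum(axis=-1)
        acc = acc * scale[:, None] + P @ V[s:s+tile]
        m = m_new
    return acc / l[:, None]
```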
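For cross-attention, one way to adapt the earlier sketch: take Q from one sequence and K, V from another, so the score matrix becomes (seq_len_q, seq_len_kv). A minimal sketch reusing scaled_dot_product_attention from above (names are illustrative):

```python
def cross_attention(X_q, X_kv, W_q, W_k, W_v):
    # X_q: (seq_len_q, d_model), e.g. decoder states supplying the queries.
    # X_kv: (seq_len_kv, d_model), e.g. encoder states supplying keys/values.
    Q = X_q @ W_q                    # (seq_len_q, d_k)
    K = X_kv @ W_k                   # (seq_len_kv, d_k)
    V = X_kv @ W_v                   # (seq_len_kv, d_v)
    return scaled_dot_product_attention(Q, K, V)   # (seq_len_q, d_v)
```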
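And a sketch of both positional encodings (d_model assumed even; function names are mine). The sinusoidal version follows the original transformer formula and is added to the embeddings; RoPE instead rotates each (even, odd) pair of Q and K dimensions by a position-dependent angle before the dot product, here using the interleaved-pair convention:

```python
def sinusoidal_encoding(seq_len, d_model):
    # PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(...)
    pos = np.arange(seq_len)[:, None]                  # (seq_len, 1)
    i = np.arange(0, d_model, 2)[None, :]              # (1, d_model/2)
    angles = pos / (10000 ** (i / d_model))            # (seq_len, d_model/2)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

def rope(x, base=10000):
    # x: (seq_len, d) with d even — typically Q or K, not the raw embeddings.
    seq_len, d = x.shape
    pos = np.arange(seq_len)[:, None]                  # (seq_len, 1)
    theta = base ** (-np.arange(0, d, 2) / d)[None, :] # (1, d/2) frequencies
    angles = pos * theta                               # (seq_len, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]                    # paired dimensions
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin                 # 2-D rotation per pair
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```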