Flash Attention from a Beginner's Point of View

23rd March, 2025

Transformers are amazing, right? Their attention mechanisms—like self-attention and multi-head attention—make them super powerful for understanding context in text, translations, and more. But there’s a catch: traditional attention can be a memory hog and slowpoke, especially when dealing with long sequences. Enter Flash Attention, a clever optimization that makes attention faster and more efficient without sacrificing accuracy. Let’s dive in!

This is just Part I of the Flash Attention series. In upcoming parts, we will dive deeper into Flash Attention’s internals and its implementation in PyTorch.

What is Flash Attention and Why is it Needed?

Attention, as we know, computes a weighted sum of values (V) based on how relevant each key (K) is to a query (Q). The problem? For a sequence of length N, the attention mechanism needs to create an N × N attention score matrix. That’s a lot of memory! For example, if N = 10,000 (think a long document), you’re storing 100 million numbers just for that matrix. GPUs, which power these models, choke on this because they have limited memory, and moving data back and forth slows everything down.
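
To get a feel for the scale, here is a quick back-of-the-envelope calculation (the sequence length and data types are just illustrative):

```python
# Rough memory cost of materializing the full N x N attention score matrix
# for a single attention head (illustrative numbers only).
N = 10_000                 # sequence length
entries = N * N            # 100 million scores

for dtype, bytes_per_entry in [("fp32", 4), ("fp16", 2)]:
    gigabytes = entries * bytes_per_entry / 1e9
    print(f"{dtype}: {gigabytes:.2f} GB")
# fp32: 0.40 GB, fp16: 0.20 GB. And that is per head, per layer,
# per example in the batch.
```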

Flash Attention is a way to compute attention without ever fully building that giant N × N matrix in memory. Instead, it processes the data in smaller chunks (or "tiles") and does the math on-the-fly. It’s like solving a puzzle piece by piece instead of laying out the whole picture at once. This saves memory, speeds things up, and lets Transformers handle much longer sequences—like entire books—without crashing.

Why is it Needed?

The short answer: the score matrix grows quadratically with the sequence length, so for long inputs it simply doesn’t fit in fast GPU memory, and shuttling it back and forth to main GPU memory dominates the runtime. Flash Attention removes that bottleneck by never writing the full matrix out in the first place.

How Does Flash Attention Work?

Let’s break it down:

Flash Attention says, “Why store that huge score matrix?” Instead, it processes Q, K, and V in small blocks, computes each block’s scores on the fly, and folds that block’s contribution into a running output before moving on.

Here’s the trick: it avoids storing the full attention score matrix by recomputing intermediate results when necessary and using clever math to keep the softmax accurate across blocks. The heavy lifting happens in the GPU’s fast on-chip memory (SRAM), keeping round trips to the slower main memory (HBM) to a minimum.

Step-by-Step:

  1. Input: Q, K, and V matrices (say, each is N × d, where d is the embedding size).

  2. Tiling: Break them into smaller chunks (e.g., blocks of size B × d, where B << N).

  3. Inner Loop: For each block of Q, compute attention with all blocks of K and V, but only store tiny temporary results.

  4. Softmax Trick: Use a running normalization to combine results across blocks without ever needing the full matrix.

  5. Output: Build the final attention output block-by-block.

The result? Same output as regular attention, but way less memory and faster computation.
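
To make these steps concrete, here is a minimal, unoptimized PyTorch sketch of the idea. It is purely illustrative: the function name and block size are my own choices, for brevity it only tiles K and V (a full implementation also tiles Q in an outer loop), and the real Flash Attention kernel is fused GPU code, not a Python loop.

```python
import torch

def flash_attention_sketch(Q, K, V, block_size=128):
    """Block-wise attention with an online softmax (illustrative only).

    Q, K, V: (N, d) tensors. Returns the same result as
    softmax(Q @ K.T / sqrt(d)) @ V, without materializing the N x N matrix.
    """
    N, d = Q.shape
    scale = d ** -0.5

    out = torch.zeros_like(Q)                    # running (unnormalized) output
    row_max = torch.full((N, 1), float("-inf"))  # running max per query row
    row_sum = torch.zeros(N, 1)                  # running softmax denominator

    for start in range(0, N, block_size):        # loop over K/V blocks
        K_blk = K[start:start + block_size]      # (B, d)
        V_blk = V[start:start + block_size]      # (B, d)

        scores = (Q @ K_blk.T) * scale           # (N, B) block of scores

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescale the old accumulators
        probs = torch.exp(scores - new_max)        # (N, B), numerically safe

        row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
        out = out * correction + probs @ V_blk
        row_max = new_max

    return out / row_sum                         # normalize once at the end


# Sanity check against the textbook formula (small sizes for speed)
torch.manual_seed(0)
Q = torch.randn(512, 64); K = torch.randn(512, 64); V = torch.randn(512, 64)
reference = torch.softmax(Q @ K.T / 64 ** 0.5, dim=-1) @ V
assert torch.allclose(flash_attention_sketch(Q, K, V), reference, atol=1e-4)
```

The key point is that `scores` only ever holds an N × B block, never the full N × N matrix.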

Mathematical Representation

Let’s keep it simple but precise. Regular attention computes:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
\]

where \( Q \), \( K \), and \( V \) are the \( N \times d \) query, key, and value matrices (for a single head), and \( d_k \) is the key dimension; dividing by \( \sqrt{d_k} \) keeps the scores well-scaled before the softmax.
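
For reference, the formula above maps almost line-for-line onto code. This is a toy single-head version (the function name is mine) without batching, masking, or dropout:

```python
import math
import torch

def naive_attention(Q, K, V):
    """Textbook scaled dot-product attention; materializes the full N x N matrix."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / math.sqrt(d_k)        # (N, N) score matrix (the memory hog)
    weights = torch.softmax(scores, dim=-1)  # row-wise softmax
    return weights @ V                       # (N, d) output
```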

Flash Attention avoids materializing \( QK^T \) fully. Instead, it computes the scores one block at a time and folds each block’s contribution into a running output.

The magic is in the incremental softmax, which uses two extra variables (a max and a sum) to stitch everything together accurately without the full matrix.
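
Concretely, for each query row, one common way to write the running update after seeing the \( j \)-th block of keys and values, with \( S_j = \frac{Q K_j^{T}}{\sqrt{d_k}} \), is (the exact bookkeeping varies between implementations):

\[
\begin{aligned}
m^{(j)} &= \max\!\left(m^{(j-1)},\ \operatorname{rowmax}(S_j)\right) \\
\ell^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \operatorname{rowsum}\!\left(e^{\,S_j - m^{(j)}}\right) \\
o^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,o^{(j-1)} + e^{\,S_j - m^{(j)}}\,V_j
\end{aligned}
\]

starting from \( m^{(0)} = -\infty \), \( \ell^{(0)} = 0 \), \( o^{(0)} = 0 \). Here \( m \) is the running row-wise maximum (for numerical stability) and \( \ell \) is the running softmax denominator: exactly the max and sum mentioned above. After the last block, dividing \( o \) row-wise by \( \ell \) gives the same output as regular attention.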

Flash Attention in the Transformer Architecture

Flash Attention fits right into the Transformer’s attention layers: wherever the model computes scaled dot-product attention (self-attention, multi-head attention), the standard implementation can be swapped for Flash Attention without changing the model’s outputs.

It’s plug-and-play: same Transformer, just faster and leaner.
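
If you just want to use it, recent PyTorch versions (2.0+) already expose this: torch.nn.functional.scaled_dot_product_attention can dispatch to a fused Flash Attention kernel when the hardware, dtype, and shapes allow it. Which backend is actually picked depends on your GPU, dtype, and PyTorch version, so treat this as a sketch rather than a guarantee:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim): half precision on a CUDA GPU is the
# typical setup where the fused Flash Attention backend can be selected.
q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 4096, 64])
```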

Why Is Flash Attention a Big Deal?

Three things, all following from the ideas above: it is exact (the same output as standard attention, not an approximation), it avoids the quadratic N × N memory cost of the score matrix, and it cuts slow HBM traffic, which makes training and inference on long sequences both leaner and faster.

Intuition: Flash Attention as a Spotlight

Remember our theater analogy? In regular attention, the spotlight operator (Attention) scans the entire stage (all tokens) at once, writing down a huge list of who’s important (the N × N matrix) before deciding where to shine the light. With a big stage, that list gets unwieldy, and the operator runs out of paper (memory).

With Flash Attention, the operator is smarter. They only look at one section of the stage at a time (a block), figure out who’s important in that moment, and adjust the spotlight right away. They keep a tiny notepad (fast GPU memory) with just enough info to move to the next section. The play (computation) goes on smoothly, and the audience (model) still sees the full story—no one notices the operator’s trick!