Two major operations in modern LLM inference are prefill and decode. It’s important to know what they are and how they differ. Most modern LLMs are variants of GPT, a decoder-only model. These models take an input (a prompt), process it, sample the next token, and then feed the previous input plus the sampled token back in as the next input, in an autoregressive manner. This is called “decoding”, and it happens one token at a time.

What is Prefill?

Prefill happens when the input (prompt) is given to the model. All prompt tokens are processed in parallel, much like in a Transformer encoder layer. Unlike the encoders of BERT or T5, which use bidirectional attention, decoder-only models like GPT must apply causal masking: each token may attend only to itself and the tokens before it. This matches how the model is trained. Without the mask, a position could “cheat” by looking at the very token it is supposed to predict, and the model would never learn genuine next-token prediction.
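For illustration, here is a minimal pure-Python sketch of a causal mask (the function names are our own, not from any library). Entry `[i][j]` is allowed only when `j <= i`; disallowed scores are typically set to $-\infty$ before the softmax so that they receive zero weight.

```python
def causal_mask(n):
    """Return an n x n boolean mask: mask[i][j] is True if token i may attend to token j."""
    # Token i sees itself and every earlier token (j <= i), never a later one.
    return [[j <= i for j in range(n)] for i in range(n)]


def apply_mask(scores, mask):
    """Replace disallowed attention scores with -inf so softmax gives them zero weight."""
    return [
        [s if allowed else float("-inf") for s, allowed in zip(score_row, mask_row)]
        for score_row, mask_row in zip(scores, mask)
    ]


for row in causal_mask(4):
    print("".join("1" if allowed else "0" for allowed in row))
# 1000
# 1100
# 1110
# 1111
```

The lower-triangular shape is the whole idea: row $i$ may only look at columns $0$ to $i$.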

Prefill is generally compute-bound: because every prompt token is processed at once, each weight loaded from memory is reused across many tokens, and self-attention has to be computed for each and every token in the prompt.
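A back-of-the-envelope calculation shows why processing many tokens together is compute-bound. The sketch below assumes a single hypothetical 4096 x 4096 projection stored in fp16 (2 bytes per parameter) and counts only weight traffic. These numbers are illustrative assumptions, not measurements. Each weight read from memory contributes 2 FLOPs per token, so the FLOPs-per-byte ratio grows linearly with the number of tokens processed at once.

```python
def matmul_intensity(n_tokens, d_in, d_out, bytes_per_param=2):
    """Approximate FLOPs per byte of weights read when an (n_tokens x d_in) activation
    matrix is multiplied by a (d_in x d_out) weight matrix.

    Simplifying assumption: only the weight bytes are counted; activation traffic,
    caches, and kernel details are ignored.
    """
    flops = 2 * n_tokens * d_in * d_out            # one multiply and one add per weight per token
    weight_bytes = bytes_per_param * d_in * d_out  # every weight is read once (fp16 by default)
    return flops / weight_bytes


# Hypothetical 4096 x 4096 projection with fp16 weights.
for n in (1, 512, 2048):
    print(n, matmul_intensity(n, 4096, 4096))
# 1 1.0
# 512 512.0
# 2048 2048.0
```

With thousands of prompt tokens, the hardware spends most of its time doing arithmetic rather than waiting on memory; a single-token pass sits at the opposite extreme.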

Ignoring minor operations such as splitting and concatenating heads in multi-head attention, a single forward pass involves two major operations: 1) projecting each token into Q, K, and V, and 2) computing attention, $\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$. In the prefill stage, we perform 1 and 2 for all prompt tokens at once, with the causal mask applied.
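The two steps can be sketched for a single attention head in plain Python. This is a toy under simplifying assumptions (one head, no batching, no output projection, random weights, already-embedded input vectors), not how production kernels compute attention. It only shows that in prefill every prompt token goes through steps 1 and 2 in the same pass.

```python
import math
import random


def matvec(W, x):
    """Multiply a weight matrix W (list of rows, shape d_out x d_in) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_self_attention(X, W_q, W_k, W_v):
    """Single-head causal self-attention over every token of X (the prefill case).

    Row i of the result is softmax(q_i K^T / sqrt(d_k)) V restricted to keys 0..i,
    which is exactly what adding -inf above the diagonal of the score matrix does.
    """
    # 1) Project every token into its query, key, and value vectors.
    Q = [matvec(W_q, x) for x in X]
    K = [matvec(W_k, x) for x in X]
    V = [matvec(W_v, x) for x in X]
    d_k = len(K[0])

    outputs = []
    for i, q in enumerate(Q):
        # 2) Scaled dot-product scores against keys 0..i only (causal mask).
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d_k) for j in range(i + 1)]
        weights = softmax(scores)
        # Weighted sum of the visible value vectors.
        outputs.append([sum(w * V[j][c] for j, w in enumerate(weights)) for c in range(len(V[0]))])
    return outputs


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d_model, d_k = 8, 4
W_q, W_k, W_v = rand_matrix(d_k, d_model), rand_matrix(d_k, d_model), rand_matrix(d_k, d_model)
prompt = rand_matrix(5, d_model)  # 5 prompt tokens, already embedded
out = causal_self_attention(prompt, W_q, W_k, W_v)
print(len(out), len(out[0]))  # 5 4
```

Because of the mask, the first token can only attend to itself, so its output is simply its own value vector.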

What is Decode?

As mentioned above, decode happens when we concatenate the sampled token to the sequence so far (the prompt plus any previously generated tokens) and use the result as the next input to the model. Decoding continues until the model samples `<eos>` (a special token indicating the end of a sequence) or the sequence reaches a maximum length, specified either by the user or by the LLM.
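The loop just described can be sketched in a few lines of Python. `toy_next_token_logits` is a hypothetical stand-in for a real model’s forward pass, and greedy (argmax) selection stands in for sampling. Note that this naive loop feeds the entire sequence back into the model at every step.

```python
EOS = 0  # hypothetical id of the <eos> token


def toy_next_token_logits(tokens, vocab_size=10):
    """Hypothetical stand-in for an LLM forward pass: one logit per vocabulary entry.
    It favours (last token + 1) mod vocab_size, so sequences count upward and wrap to EOS."""
    favourite = (tokens[-1] + 1) % vocab_size
    return [1.0 if t == favourite else 0.0 for t in range(vocab_size)]


def generate(prompt, max_new_tokens):
    tokens = list(prompt)
    for _ in range(max_new_tokens):                   # stop at the maximum length...
        logits = toy_next_token_logits(tokens)         # the whole sequence is fed back in
        next_token = max(range(len(logits)), key=logits.__getitem__)  # greedy choice
        tokens.append(next_token)                      # concatenate and repeat
        if next_token == EOS:                          # ...or when <eos> is produced
            break
    return tokens


print(generate([7], max_new_tokens=10))  # [7, 8, 9, 0]
print(generate([1], max_new_tokens=3))   # [1, 2, 3, 4]
```

The first call stops because it produced EOS; the second stops because it hit the length limit.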

One may ask: why does decode differ from prefill, if both essentially feed an input into a decoder model and process it? Without a KV cache, they would in fact be the same.

The KV cache stores the key (K) and value (V) vectors of tokens that have already been processed, whether during prefill or in earlier decode steps. By saving these K and V values as decoding proceeds, we can reuse them and compute Q, K, and V only for the most recently sampled token.
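A minimal sketch of this idea for one attention layer with a single head, in plain Python. The `KVCache` class and `decode_step` function are hypothetical names for illustration, not any framework’s API; a real cache keeps one such pair of K and V stores per layer (and per head), usually as preallocated tensors.

```python
import math
import random


def matvec(W, x):
    """Multiply a weight matrix W (list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over lists of keys and values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]


class KVCache:
    """Hypothetical single-layer, single-head KV cache: one K and one V vector per token."""

    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)


def decode_step(x_new, cache, W_q, W_k, W_v):
    """Process only the newest token: compute its q, k, v, cache k and v, then attend."""
    q = matvec(W_q, x_new)
    cache.append(matvec(W_k, x_new), matvec(W_v, x_new))
    return attend(q, cache.keys, cache.values)


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d = 4
W_q, W_k, W_v = rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
tokens = rand_matrix(6, d)  # 5 prompt tokens + 1 newly sampled token (already embedded)

cache = KVCache()
for x in tokens[:5]:  # stands in for prefill, which fills the cache with the prompt's K and V
    cache.append(matvec(W_k, x), matvec(W_v, x))
out = decode_step(tokens[5], cache, W_q, W_k, W_v)  # only one new q, k, v is computed
```

The result matches recomputing attention for the last token from scratch over all six tokens, but the decode step projects only the newest token.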

{{< note >}}

$Q.$ How is reusing the KV cache possible?

$A.$ Because of the causal mask, a token’s hidden state at every layer, and therefore its Q, K, and V, depends only on that token and the tokens before it. Newly generated tokens during decoding therefore cannot affect the K and V values of earlier tokens. In other words, once a token’s K and V are computed, they never change (this would not hold with bidirectional attention, where tokens also attend to later tokens).

$Q.$ Why KV Cache, not QKV Cache?

$A.$ Because a token’s Q is only needed at the step where that token is the query. In the decoding stage, we compute the Q of the most recently sampled token and attend with it over all previous K and V values. Once the next token is sampled, that Q is no longer needed: its sole purpose was to compute the attention output for that position, which ultimately yields the probabilities of the next token. In contrast, every past K and V is used again by the Q of each future token, so they must be preserved. This is why we cache K and V but not Q.

{{< /note >}}
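The first answer above can be checked with a toy two-layer setup: append one token and compare the keys that the second layer computes for the earlier tokens. The second layer is the interesting one, because its inputs have already mixed information across tokens (in the first layer, K and V are plain projections of each token’s own embedding). Everything below is an illustrative pure-Python sketch with random weights and a residual connection; `attention_layer` and `layer2_keys` are hypothetical names, and the bidirectional case is included only for contrast.

```python
import math
import random


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_layer(X, W_q, W_k, W_v, causal):
    """One single-head attention layer with a residual connection.
    causal=True: token i attends to tokens 0..i; causal=False: to all tokens."""
    Q = [matvec(W_q, x) for x in X]
    K = [matvec(W_k, x) for x in X]
    V = [matvec(W_v, x) for x in X]
    out = []
    for i, q in enumerate(Q):
        visible = range(i + 1) if causal else range(len(X))
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(len(q)) for j in visible]
        weights = softmax(scores)
        mixed = [sum(w * V[j][c] for w, j in zip(weights, visible)) for c in range(len(V[0]))]
        out.append([x_c + m_c for x_c, m_c in zip(X[i], mixed)])  # residual connection
    return out


def layer2_keys(X, layer1, W_k2, causal):
    """K vectors the second layer would compute (and a KV cache would store)."""
    H = attention_layer(X, *layer1, causal=causal)
    return [matvec(W_k2, h) for h in H]


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d = 4
layer1 = (rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d))
W_k2 = rand_matrix(d, d)
prefix = rand_matrix(3, d)
extended = prefix + rand_matrix(1, d)  # append one newly generated token

for causal in (True, False):
    before = layer2_keys(prefix, layer1, W_k2, causal)
    after = layer2_keys(extended, layer1, W_k2, causal)[:3]
    print("causal" if causal else "bidirectional", before == after)
# causal True
# bidirectional False
```

Under the causal mask the cached keys of the first three tokens stay valid after a new token arrives; with bidirectional attention they would all need recomputing.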

Without a KV cache, each decode step would be the same as “prefilling with the next token added”, which is very inefficient because every token has to be recomputed from the start. The following figure from Sebastian Raschka clearly visualizes this redundant computation when no KV cache is used:

![Decoding without a KV cache (figure by Sebastian Raschka)](image.png)
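The size of this waste can be estimated by counting how many token positions pass through the model. The sketch below is a simplification that counts positions only, not FLOPs (attention cost also grows with sequence length), and the prompt and output lengths in the example are arbitrary. Generating $G$ tokens from a prompt of length $P$ takes $G$ forward passes; without a cache the total is $GP + \frac{G(G-1)}{2}$, with a cache it is $P + G - 1$.

```python
def tokens_processed(prompt_len, new_tokens, use_kv_cache):
    """Count token positions pushed through the model to generate `new_tokens` tokens.

    The first forward pass is the prefill over the prompt. With a KV cache, every
    later pass processes only the newest token; without one, each pass re-processes
    the whole sequence so far (prompt + tokens generated so far).
    """
    total = prompt_len  # first forward pass: prefill
    for step in range(1, new_tokens):
        total += 1 if use_kv_cache else prompt_len + step
    return total


print(tokens_processed(1000, 100, use_kv_cache=True))   # 1099
print(tokens_processed(1000, 100, use_kv_cache=False))  # 104950
```

For a 1000-token prompt and 100 generated tokens, skipping the cache means processing roughly 95 times as many token positions.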