Anatomy of an LLM: Internals of DeepSeek-V3's Attention Implementation

In its section on Multi-Head Latent Attention, the DeepSeek-V3 paper presents a fairly complex-looking set of equations:

\[\mathbf{c}_t^{KV} = W^{DKV}\mathbf{h}_t \tag{1}\] \[\left[\mathbf{k}_{t,1}^{C};\mathbf{k}_{t,2}^{C};...;\mathbf{k}_{t,n_h}^{C}\right] = \mathbf{k}_{t}^{C} = W^{UK}\mathbf{c}_t^{KV} \tag{2}\] \[\mathbf{k}_t^{R} = \text{RoPE}\left( W^{KR}\mathbf{h}_t \right) \tag{3}\] \[\mathbf{k}_{t,i} = \left[ \mathbf{k}_{t,i}^C;\mathbf{k}_t^R \right] \tag{4}\] \[\left[\mathbf{v}_{t,1}^{C};\mathbf{v}_{t,2}^{C};...;\mathbf{v}_{t,n_h}^{C}\right] = \mathbf{v}_{t}^{C} = W^{UV}\mathbf{c}_t^{KV} \tag{5}\] \[\mathbf{o}_{t,i} = \sum_{j=1}^{t}\text{Softmax}_j\left( \frac{\mathbf{q}_{t,i}^T\mathbf{k}_{j,i}}{\sqrt{d_h+d^R_h}} \right)\mathbf{v}_{j,i}^C \tag{6}\]

This post attempts to break these equations down to show where they’re coming from, and hopefully provide some helpful intuition.

Overview

DeepSeek builds on the standard transformer architecture, where, given a series of input vectors $\mathbf{h}_t$, with each vector representing a token, we want to calculate three vectors: the query $\mathbf{q}$, key $\mathbf{k}$, and value $\mathbf{v}$, using a set of weight matrices. We then apply the attention operation, which involves taking the dot product between the query of one token and the keys of the other tokens to get an idea of how “relevant” the tokens are to each other. Finally, we use the softmax-normalized scores to take a weighted sum of the value vectors.

Typically, we’d start by taking the input token, $\mathbf{h}_t$, and calculating $\mathbf{q}_t$, $\mathbf{k}_t$, $\mathbf{v}_t$:

\[\mathbf{q}_t = W^Q\mathbf{h}_t\] \[\mathbf{k}_t = W^K\mathbf{h}_t\] \[\mathbf{v}_t = W^V\mathbf{h}_t\]
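
As a point of reference, here’s a toy sketch of the standard attention computation for a single head (the shapes and names here are illustrative, not taken from DeepSeek’s code):

import torch
import torch.nn.functional as F

T, D = 10, 64                          # toy sequence length and embedding dimension
W_Q, W_K, W_V = (torch.randn(D, D) for _ in range(3))
h = torch.randn(T, D)                  # one row per input token

q, k, v = h @ W_Q.T, h @ W_K.T, h @ W_V.T            # [T, D] each
scores = (q @ k.T) / D ** 0.5                        # query-key dot products, scaled
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))   # each token attends only to itself and earlier tokens
out = F.softmax(scores, dim=-1) @ v                  # softmax weights applied to the values

The last two steps (softmax over the scaled dot products, then a weighted sum over the values) are exactly what equation $(6)$ expresses per head.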

DeepSeek makes a few modifications to this process.

Low-Rank Joint Compression

First, it leverages Low-Rank Joint Compression to reduce the size of the matrices involved. So instead of using a single matrix $W^K$, it decomposes the matrix into a product of two smaller matrices:

\[W^K = W^{UK}W^{DK}\]

Instead of multiplying the input directly by $W^K$, we first multiply the input token $\mathbf{h}_t$ by the down-projection matrix $W^{DK}$, which reduces the dimension of the input, and then apply an up-projection matrix $W^{UK}$ to restore the original dimension.

We also combine the key and value matrices ($W^K$ and $W^V$) into a single matrix. Putting this together, we get the first equation:

\[\mathbf{c}_t^{KV} = W^{DKV}\mathbf{h}_t \tag{1}\]

So we take the input token, and multiply by a combined key and value down-projection matrix $W^{DKV}$ to get $\mathbf{c}_t^{KV}$, which is a compressed combination of the key and value vectors.

Next, we apply the up-projection matrices for keys and values to $\mathbf{c}_t^{KV}$, which give us equations $(2)$ and $(5)$:

\[\left[\mathbf{k}_{t,1}^{C};\mathbf{k}_{t,2}^{C};...;\mathbf{k}_{t,n_h}^{C}\right] = \mathbf{k}_{t}^{C} = W^{UK}\mathbf{c}_t^{KV} \tag{2}\] \[\left[\mathbf{v}_{t,1}^{C};\mathbf{v}_{t,2}^{C};...;\mathbf{v}_{t,n_h}^{C}\right] = \mathbf{v}_{t}^{C} = W^{UV}\mathbf{c}_t^{KV} \tag{5}\]
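
Putting equations $(1)$, $(2)$, and $(5)$ together, the whole compression path looks roughly like this (toy dimensions and variable names of my own choosing, not DeepSeek-V3’s actual sizes):

import torch

D, d_c, n_h, d_h = 1024, 128, 4, 64   # toy model dim, compressed dim, heads, head dim

W_DKV = torch.randn(d_c, D)           # joint key/value down-projection
W_UK = torch.randn(n_h * d_h, d_c)    # key up-projection
W_UV = torch.randn(n_h * d_h, d_c)    # value up-projection

h_t = torch.randn(D)                  # input vector for one token
c_t = W_DKV @ h_t                     # equation (1): compressed joint key/value latent
k_C = (W_UK @ c_t).view(n_h, d_h)     # equation (2): one key vector per head
v_C = (W_UV @ c_t).view(n_h, d_h)     # equation (5): one value vector per head

Note that $\mathbf{c}_t^{KV}$ has only $d_c$ dimensions, much smaller than the $n_h \cdot d_h$ dimensions of the full keys and values it expands into.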

Multi-Head Attention

Note that the up-projection matrix is sized so that instead of getting a single key vector as output, we get a whole series of them - $n_h$ of them specifically. This is part of the multi-head attention mechanism, where instead of calculating $\mathbf{q}$, $\mathbf{k}$, and $\mathbf{v}$ once for each input token, we do it multiple times in parallel, and then combine the result in the end.
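
The “combine the result in the end” step is the standard multi-head pattern: the per-head outputs are concatenated and mapped back to the model dimension by an output projection. A toy sketch (the $W_O$ here is the usual output projection, which doesn’t appear in the equations excerpted above):

import torch

n_h, d_h, D = 4, 64, 1024               # toy number of heads, head dim, model dim
o_heads = torch.randn(n_h, d_h)         # one attention output per head for a single token
W_O = torch.randn(D, n_h * d_h)         # output projection
u_t = W_O @ o_heads.reshape(n_h * d_h)  # concatenate heads, project back to model dim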

Position Encoding

Now we move on to positional encoding. The idea is that we’d like the queries and keys to encode some information about the positions of the input tokens. DeepSeek uses RoPE (Rotary Position Embedding), which we can think of as a function that takes a vector along with its position in the sequence and rotates it in a way that encodes that position:

\[\mathbf{k}_t^{R} = \text{RoPE}\left( W^{KR}\mathbf{h}_t \right) \tag{3}\]

The DeepSeek implementation doesn’t apply RoPE to the keys calculated in equation $(2)$. Instead, a separate matrix $W^{KR}$ is applied to $\mathbf{h}_t$ to produce a key vector $\mathbf{k}_t^{R}$ that is then position-encoded. This position-encoded vector is concatenated with each of the per-head key vectors from equation $(2)$, which gives us equation $(4)$. Note that $\mathbf{k}_t^{R}$ carries no head index $i$, so the same position-encoded part is shared across all of the heads:

\[\mathbf{k}_{t,i} = \left[ \mathbf{k}_{t,i}^C;\mathbf{k}_t^R \right] \tag{4}\]

Note that there is no position encoding applied to the value vectors. So both queries and keys have position encoding applied, but not the values.
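
To make RoPE a little more concrete, here’s a minimal sketch of one common formulation (the “rotate-half” variant). It’s meant to illustrate the idea rather than reproduce the apply_rotary_emb function used below, which applies the rotation via the precomputed freqs_cis tensor:

import torch

def rope(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: [T, D] with D even, pos: [T] integer positions.
    # Each pair of dimensions is rotated by an angle proportional to the
    # token's position, at a frequency that varies across the dimensions.
    half = x.shape[-1] // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)  # [D/2]
    angles = pos[:, None].float() * freqs[None, :]                     # [T, D/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

Because each position is mapped to a rotation, the dot product between a rotated query and a rotated key depends on their relative offset, which is exactly the property we want from a position encoding.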

Implementation

The implementation is fairly straightforward, except there’s a slight change of terminology: down-projection matrices are labeled $a$ (rather than $D$) and up-projection matrices are labeled $b$ (rather than $U$). So we start by declaring the matrices $W^Q_a$ and $W^Q_b$ for the query calculation:

self.wq_a = Linear(self.dim, self.q_lora_rank)  # down-projection: dim -> q_lora_rank
self.q_norm = RMSNorm(self.q_lora_rank)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # up-projection: q_lora_rank -> n_heads * qk_head_dim

The input tensor x has dimensions $[B, T, D]$ where $B$ is the batch size, $T$ is the number of tokens in the batch, and $D$ is the embedding dimension. $W^Q_a$ is a linear transformation from the embedding dimension $D$ (or self.dim), to the LoRA dimension (self.q_lora_rank).

$W^Q_b$ is a linear transformation from the LoRA dimension to a series of vectors, one per attention head. So the total size of the result of the transformation is the number of heads (self.n_heads) multiplied by the dimension per head (self.qk_head_dim).

These transformations are then applied to the input:

q = self.wq_b(self.q_norm(self.wq_a(x)))
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)

This gives us a tensor with dimensions $[B, T, H * D_H]$, which is then reshaped into $[B, T, H, D_H]$, where $H$ is the number of heads and $D_H$ is the per-head dimension. The per-head dimension is then split into two parts (of size qk_nope_head_dim and qk_rope_head_dim), with RoPE applied only to the second part. The two parts are then concatenated back together to form the $\mathbf{q}$ tensor:

q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
q = torch.cat([q_nope, q_pe], dim=-1)

Next, we compute the key and value tensors. These are computed together, using a single set of $W^{KV}$ matrices:

self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # down-projection, plus extra dims for the position-encoded key
self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))  # up-projection for keys and values

$W^{KV}_a$ is a linear transformation that reduces the input dimension to kv_lora_rank + qk_rope_head_dim. The kv_lora_rank dimensions will go through the rest of the LoRA transformation, while the remaining qk_rope_head_dim dimensions will have position encoding applied.

The output dimension of $W^{KV}_b$ is set so that each attention head gets a tensor with two parts: self.qk_nope_head_dim dimensions used for the non-position-encoded key path, and self.v_head_dim dimensions used for the $\mathbf{v}$ vector, which only goes through the LoRA transformation - no position encoding is applied to values.

Next, we apply these matrices:

kv = self.wkv_a(x)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

First, $W_a^{KV}$ is applied to reduce the input dimension to kv_lora_rank + qk_rope_head_dim. Next, we split off the slice of size qk_rope_head_dim and apply RoPE to it; the unsqueeze(2) call adds a head dimension of size 1, since this position-encoded key part is shared across all heads. Then:

kv = self.wkv_b(self.kv_norm(kv))
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)

The leftover dimensions, of size kv_lora_rank, are then normalized and up-projected using $W^{KV}_b$ to n_local_heads * (qk_nope_head_dim + v_head_dim) dimensions, reshaped so that each head gets its own slice, and finally split into $\mathbf{k}$ and $\mathbf{v}$ components. The shared, position-encoded k_pe is expanded across all heads and concatenated onto k_nope, so $\mathbf{k}$ is composed of one part that has had positional encoding applied and one that has not, exactly as in equation $(4)$.
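
With $\mathbf{q}$, $\mathbf{k}$, and $\mathbf{v}$ assembled, the attention step itself corresponds to equation $(6)$. A rough sketch of that final step, following the shape conventions above but omitting details the real implementation also handles (KV caching, the causal mask, and the final output projection):

scores = torch.einsum("bshd,bthd->bsht", q, k) / self.qk_head_dim ** 0.5  # per-head query-key dot products
scores = scores.softmax(dim=-1)                                           # normalize over key positions
x = torch.einsum("bsht,bthd->bshd", scores, v)                            # weighted sum of values per head

The scale $1/\sqrt{\texttt{qk\_head\_dim}}$ is the $1/\sqrt{d_h + d_h^R}$ factor in equation $(6)$, since qk_head_dim is the sum of qk_nope_head_dim and qk_rope_head_dim.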