MLA 和 MHA 的 Kernel 区别
MLA(Multi-head Latent Attention)和 MHA(Multi-Head Attention)在公式上看起来差别不大,但落到 Kernel 实现上,内存访问模式和计算特征完全不同。这篇从实现角度把差异算清楚。
内存布局和 KV Cache 大小
MHA
Q = X @ W_Q # [batch, seq, n_heads * head_dim]
K = X @ W_K # [batch, seq, n_kv_heads * head_dim]
V = X @ W_V # [batch, seq, n_kv_heads * head_dim]
缓存内容:完整的 K_proj 和 V_proj。
每 token 每 layer 的 KV Cache:2 × n_kv_heads × head_dim。
以 DeepSeek-V3 级别为例(假设 n_kv_heads=8, head_dim=128):约 1.0 MB/token/layer。
MLA
c_KV = X @ W_DKV # [batch, seq, latent_dim]
K = c_KV @ W_UK # [batch, seq, n_heads * head_dim]
V = c_KV @ W_UV # [batch, seq, n_heads * head_dim]
缓存内容:压缩向量 c_KV。
每 token 每 layer 的 KV Cache:1 × latent_dim。
以 DeepSeek-V2 为例(latent_dim=512):约 0.1 MB/token/layer。
压缩比约 10x。
HBM 布局差异
| MHA | MLA | |
|---|---|---|
| 缓存内容 | K_proj, V_proj | c_KV |
| 布局 | [layer, batch, n_kv_heads, seq, head_dim] | [layer, batch, seq, latent_dim] |
| 连续性 | 按 head 分片,跨 head 不连续 | 紧凑连续,对 coalesced access 更友好 |
Decode 阶段的 HBM 访问模式
Decode 阶段每次只生成 1 个 token,但要用到所有历史 token 的 KV。这是纯 memory-bound 场景。
sequenceDiagram
participant HBM as HBM
participant SRAM as SRAM
participant REG as Register
Note over HBM,REG: MHA Decode
loop seq_len times
HBM->>SRAM: load K[i] (head_dim)
HBM->>SRAM: load V[i] (head_dim)
SRAM->>REG: K/V
REG->>REG: dot(Q, K)
end
Note over HBM,REG: MLA Decode
loop seq_len times
HBM->>SRAM: load cKV[i] (latent_dim)
SRAM->>REG: cKV
REG->>REG: up-project K/V
REG->>REG: dot(Q, K)
end
MHA Decode
for each token:
load Q[head_dim]
for i in range(seq_len):
load K[i][head_dim] # 从 HBM 读
load V[i][head_dim] # 从 HBM 读
compute score = dot(Q, K[i])
softmax(scores)
for i in range(seq_len):
load V[i][head_dim] # 再次读 V!
out += scores[i] * V[i]
HBM 读取量:seq_len × n_kv_heads × head_dim × 2(K 和 V 各读一次)。
seq_len=32K, head_dim=128, n_kv_heads=8:约 64 MB。
MLA Decode
for each token:
load Q[head_dim]
for i in range(seq_len):
load ckv[i][latent_dim] # 读 1/10 数据
// 在 SRAM / Register 内解压
K = ckv[i] @ W_UK # [latent_dim] × [latent_dim, head_dim]
V = ckv[i] @ W_UV
compute score = dot(Q, K)
softmax(scores)
for i in range(seq_len):
load ckv[i][latent_dim]
V = ckv[i] @ W_UV
out += scores[i] * V
HBM 读取量:seq_len × latent_dim。
seq_len=32K, latent_dim=512:约 16 MB。
核心变化:HBM 读取量减少 4x,额外计算是在 SRAM / Register 内完成的小矩阵乘。
FlashAttention 适配的差异
MHA 的 FlashAttention
for block_k in range(0, N, BLOCK_SIZE):
K_tile = load_HBM_to_SRAM(K[block_k:block_k+BLOCK])
V_tile = load_HBM_to_SRAM(V[block_k:block_k+BLOCK])
for block_q in range(0, N, BLOCK_SIZE):
Q_tile = load_HBM_to_SRAM(Q[block_q:block_q+BLOCK])
S = Q_tile @ K_tile^T
# ...
Decode 时的尴尬:Q 只有 1 个 token([1, head_dim]),K/V tile 是 [BLOCK, head_dim]。SRAM 利用率极低,退化为逐块读 HBM。
MLA 的 FlashAttention
for block_ckv in range(0, N, BLOCK_SIZE):
CKV_tile = load_HBM_to_SRAM(CKV_cache[block_ckv:block_ckv+BLOCK])
// 在 SRAM 内解压
K_tile = upproject(CKV_tile, W_UK)
V_tile = upproject(CKV_tile, W_UV)
q = load_Q(Q[cur_pos])
S = q @ K_tile^T
优势:
- 同样 BLOCK_SIZE 下,HBM 读取量只有 1/10
- 解压后的 K/V 驻留 SRAM,可被多个 Q 复用
- DeepSeek 开源的 FlashMLA kernel 就是针对 MLA 专门优化的
Roofline 分析
用具体数字算计算强度。假设 batch=1, seq_len=32K。
MHA
读取:K_cache + V_cache = 2 × 32K × 8 × 128 × 2 bytes = 128 MB
FLOPs: Q@K^T + Softmax@V ≈ 4 × 32K × 128 × 128 = 2.1 GFLOPs
计算强度 I = 2.1e9 / 128e6 ≈ 16 FLOPs/byte
A100 带宽墙 I_roof ≈ 156(FP16),I = 16 << 156,深度 memory-bound。
MLA(理论值)
读取:cKV_cache = 32K × 512 × 2 = 32 MB(降 4x)
FLOPs:
Up-project K: 2 × 32K × 512 × 8192 = 270 GFLOPs
Up-project V: 2 × 32K × 512 × 16384 = 540 GFLOPs
Attention: ≈ 2.1 GFLOPs
总 FLOPs ≈ 812 GFLOPs
I = 812e9 / 32e6 ≈ 25,375 FLOPs/byte
25K >> 156,看起来是 compute-bound。但这里有几个陷阱。
实际计算强度
陷阱 1:Up-project 的矩阵形状效率低
MLA 的 up-project 不是一次漂亮的大 GEMM,而是 N=32K 个独立的 [1, 512] × [512, 8192] 小矩阵乘。或者说 batched 成 [32768, 512] × [512, 8192],但 M=32768, K=512 对 Tensor Core 不友好,实际利用率可能只有 15-20%。
陷阱 2:W_UK / W_UV 也要从 HBM 读
W_UK([512, 8192],约 8MB)和 W_UV(约 16MB)在多层切换时可能被踢出 L2 cache,需要从 HBM 重新加载。
修正后:
有效 FLOPs(按 15% 算):812 × 0.15 ≈ 120 GFLOPs
总带宽:cKV 32MB + W_UK 8MB + W_UV 16MB = 56 MB
I_effective = 120e9 / 56e6 ≈ 2,140
2,140 仍然大于 156,但已经没有理论值那么夸张了。
更关键的观察:当 batch size 很小时,MLA 确实大幅降低了 bandwidth 需求,但说”从 memory-bound 变成 compute-bound”过于乐观。更准确的说法是:MLA 把 HBM 瓶颈从”绝对不够”降到了”勉强够用”,用 SRAM 内的计算替代了昂贵的 HBM 读取。
实际 Kernel 层面的代价
| 代价项 | MHA | MLA | 影响 |
|---|---|---|---|
| HBM 读取 | 128 MB | 32 MB | 降 4x |
| 小矩阵乘效率 | 无 | [64,512]×[512,8192] | Tensor Core 利用率低 |
| 寄存器压力 | 中 | 高 | up-project 需要更多寄存器 |
| SRAM 容量 | Q+K_tile+V_tile | Q+ckv_tile+K_tile+V_tile | tiling 块大小被迫减小 |
| L2 cache 抖动 | 低 | 高 | W_UK/W_UV 占满 cache |
量化友好性
| 量化策略 | MHA | MLA |
|---|---|---|
| KV Cache 量化 | 对 K/V 做 FP8/INT8,有损精度 | 对 c_KV 量化,latent 空间更鲁棒 |
| 权重量化 | W_K, W_V 量化 | W_DKV, W_UK, W_UV 量化,矩阵更小 |
| Kernel 融合 | 反量化 + load | 反量化 cKV → upproject,融合度更高 |
MLA 的 c_KV 量化有个优势:latent 空间经过训练压缩,数值分布更集中,量化误差对最终 K/V 的影响比直接量化 K/V 更小。
总结
MHA 的 Kernel 是”从 HBM 读大张量,在 SRAM 内做点积”。MLA 的 Kernel 是”从 HBM 读小张量,在 SRAM 内先解压再点积”。前者被 HBM 带宽掐住,后者用少量寄存器计算换回了数量级的内存效率。
但 MLA 的收益不是”脱离 memory-bound”,而是把与 seq_len 成正比的 HBM 瓶颈,替换成了与 seq_len 成正比的 compute 瓶颈。后者的缓解手段(batching、更好的 kernel fusion、下一代 Tensor Core)比前者的物理带宽上限要灵活。