Sealessland logo Sealessland
LLM Inference

MLA 和 MHA 的 Kernel 区别

从 Kernel 实现角度对比 MLA 和 MHA,包括内存访问、计算强度和量化影响。

MLA 和 MHA 的 Kernel 区别

MLA(Multi-head Latent Attention)和 MHA(Multi-Head Attention)在公式上看起来差别不大,但落到 Kernel 实现上,内存访问模式和计算特征完全不同。这篇从实现角度把差异算清楚。

内存布局和 KV Cache 大小

MHA

Q = X @ W_Q      # [batch, seq, n_heads * head_dim]
K = X @ W_K      # [batch, seq, n_kv_heads * head_dim]
V = X @ W_V      # [batch, seq, n_kv_heads * head_dim]

缓存内容:完整的 K_proj 和 V_proj。

每 token 每 layer 的 KV Cache:2 × n_kv_heads × head_dim

以 DeepSeek-V3 级别为例(假设 n_kv_heads=8, head_dim=128):约 1.0 MB/token/layer。

MLA

c_KV = X @ W_DKV           # [batch, seq, latent_dim]
K = c_KV @ W_UK            # [batch, seq, n_heads * head_dim]
V = c_KV @ W_UV            # [batch, seq, n_heads * head_dim]

缓存内容:压缩向量 c_KV。

每 token 每 layer 的 KV Cache:1 × latent_dim

以 DeepSeek-V2 为例(latent_dim=512):约 0.1 MB/token/layer。

压缩比约 10x

HBM 布局差异

MHAMLA
缓存内容K_proj, V_projc_KV
布局[layer, batch, n_kv_heads, seq, head_dim][layer, batch, seq, latent_dim]
连续性按 head 分片,跨 head 不连续紧凑连续,对 coalesced access 更友好

Decode 阶段的 HBM 访问模式

Decode 阶段每次只生成 1 个 token,但要用到所有历史 token 的 KV。这是纯 memory-bound 场景。

sequenceDiagram
    participant HBM as HBM
    participant SRAM as SRAM
    participant REG as Register
    
    Note over HBM,REG: MHA Decode
    loop seq_len times
        HBM->>SRAM: load K[i] (head_dim)
        HBM->>SRAM: load V[i] (head_dim)
        SRAM->>REG: K/V
        REG->>REG: dot(Q, K)
    end
    
    Note over HBM,REG: MLA Decode
    loop seq_len times
        HBM->>SRAM: load cKV[i] (latent_dim)
        SRAM->>REG: cKV
        REG->>REG: up-project K/V
        REG->>REG: dot(Q, K)
    end

MHA Decode

for each token:
    load Q[head_dim]
    
    for i in range(seq_len):
        load K[i][head_dim]      # 从 HBM 读
        load V[i][head_dim]      # 从 HBM 读
        compute score = dot(Q, K[i])
    
    softmax(scores)
    
    for i in range(seq_len):
        load V[i][head_dim]      # 再次读 V!
        out += scores[i] * V[i]

HBM 读取量seq_len × n_kv_heads × head_dim × 2(K 和 V 各读一次)。

seq_len=32K, head_dim=128, n_kv_heads=8:约 64 MB。

MLA Decode

for each token:
    load Q[head_dim]
    
    for i in range(seq_len):
        load ckv[i][latent_dim]   # 读 1/10 数据
        
        // 在 SRAM / Register 内解压
        K = ckv[i] @ W_UK          # [latent_dim] × [latent_dim, head_dim]
        V = ckv[i] @ W_UV
        
        compute score = dot(Q, K)
    
    softmax(scores)
    
    for i in range(seq_len):
        load ckv[i][latent_dim]
        V = ckv[i] @ W_UV
        out += scores[i] * V

HBM 读取量seq_len × latent_dim

seq_len=32K, latent_dim=512:约 16 MB。

核心变化:HBM 读取量减少 4x,额外计算是在 SRAM / Register 内完成的小矩阵乘。

FlashAttention 适配的差异

MHA 的 FlashAttention

for block_k in range(0, N, BLOCK_SIZE):
    K_tile = load_HBM_to_SRAM(K[block_k:block_k+BLOCK])
    V_tile = load_HBM_to_SRAM(V[block_k:block_k+BLOCK])
    
    for block_q in range(0, N, BLOCK_SIZE):
        Q_tile = load_HBM_to_SRAM(Q[block_q:block_q+BLOCK])
        S = Q_tile @ K_tile^T
        # ...

Decode 时的尴尬:Q 只有 1 个 token([1, head_dim]),K/V tile 是 [BLOCK, head_dim]。SRAM 利用率极低,退化为逐块读 HBM。

MLA 的 FlashAttention

for block_ckv in range(0, N, BLOCK_SIZE):
    CKV_tile = load_HBM_to_SRAM(CKV_cache[block_ckv:block_ckv+BLOCK])
    
    //SRAM 内解压
    K_tile = upproject(CKV_tile, W_UK)
    V_tile = upproject(CKV_tile, W_UV)
    
    q = load_Q(Q[cur_pos])
    S = q @ K_tile^T

优势

  • 同样 BLOCK_SIZE 下,HBM 读取量只有 1/10
  • 解压后的 K/V 驻留 SRAM,可被多个 Q 复用
  • DeepSeek 开源的 FlashMLA kernel 就是针对 MLA 专门优化的

Roofline 分析

用具体数字算计算强度。假设 batch=1, seq_len=32K。

MHA

读取:K_cache + V_cache = 2 × 32K × 8 × 128 × 2 bytes = 128 MB
FLOPs: Q@K^T + Softmax@V ≈ 4 × 32K × 128 × 128 = 2.1 GFLOPs

计算强度 I = 2.1e9 / 128e6 ≈ 16 FLOPs/byte

A100 带宽墙 I_roof ≈ 156(FP16),I = 16 << 156,深度 memory-bound。

MLA(理论值)

读取:cKV_cache = 32K × 512 × 2 = 32 MB(降 4x)

FLOPs:
  Up-project K: 2 × 32K × 512 × 8192 = 270 GFLOPs
  Up-project V: 2 × 32K × 512 × 16384 = 540 GFLOPs
  Attention: ≈ 2.1 GFLOPs
  总 FLOPs ≈ 812 GFLOPs

I = 812e9 / 32e6 ≈ 25,375 FLOPs/byte

25K >> 156,看起来是 compute-bound。但这里有几个陷阱。

实际计算强度

陷阱 1:Up-project 的矩阵形状效率低

MLA 的 up-project 不是一次漂亮的大 GEMM,而是 N=32K 个独立的 [1, 512] × [512, 8192] 小矩阵乘。或者说 batched 成 [32768, 512] × [512, 8192],但 M=32768, K=512 对 Tensor Core 不友好,实际利用率可能只有 15-20%。

陷阱 2:W_UK / W_UV 也要从 HBM 读

W_UK([512, 8192],约 8MB)和 W_UV(约 16MB)在多层切换时可能被踢出 L2 cache,需要从 HBM 重新加载。

修正后:

有效 FLOPs(按 15% 算):812 × 0.15 ≈ 120 GFLOPs
总带宽:cKV 32MB + W_UK 8MB + W_UV 16MB = 56 MB
I_effective = 120e9 / 56e6 ≈ 2,140

2,140 仍然大于 156,但已经没有理论值那么夸张了。

更关键的观察:当 batch size 很小时,MLA 确实大幅降低了 bandwidth 需求,但说”从 memory-bound 变成 compute-bound”过于乐观。更准确的说法是:MLA 把 HBM 瓶颈从”绝对不够”降到了”勉强够用”,用 SRAM 内的计算替代了昂贵的 HBM 读取

实际 Kernel 层面的代价

代价项MHAMLA影响
HBM 读取128 MB32 MB降 4x
小矩阵乘效率[64,512]×[512,8192]Tensor Core 利用率低
寄存器压力up-project 需要更多寄存器
SRAM 容量Q+K_tile+V_tileQ+ckv_tile+K_tile+V_tiletiling 块大小被迫减小
L2 cache 抖动W_UK/W_UV 占满 cache

量化友好性

量化策略MHAMLA
KV Cache 量化对 K/V 做 FP8/INT8,有损精度对 c_KV 量化,latent 空间更鲁棒
权重量化W_K, W_V 量化W_DKV, W_UK, W_UV 量化,矩阵更小
Kernel 融合反量化 + load反量化 cKV → upproject,融合度更高

MLA 的 c_KV 量化有个优势:latent 空间经过训练压缩,数值分布更集中,量化误差对最终 K/V 的影响比直接量化 K/V 更小。

总结

MHA 的 Kernel 是”从 HBM 读大张量,在 SRAM 内做点积”。MLA 的 Kernel 是”从 HBM 读小张量,在 SRAM 内先解压再点积”。前者被 HBM 带宽掐住,后者用少量寄存器计算换回了数量级的内存效率。

但 MLA 的收益不是”脱离 memory-bound”,而是把与 seq_len 成正比的 HBM 瓶颈,替换成了与 seq_len 成正比的 compute 瓶颈。后者的缓解手段(batching、更好的 kernel fusion、下一代 Tensor Core)比前者的物理带宽上限要灵活。