Sealessland logo Sealessland
LLM Inference

TileLang 学习笔记

整理 TileLang 的核心概念和写法,作为学习过程中的记录。

TileLang 学习笔记

TileLang 的定位卡在 CUDA 和 Triton 中间:比 CUDA 开发快,比 Triton 控制精度高。这篇整理核心概念和写法,方便查阅。

TileLang 是什么

TileLang 是 MLC 团队基于 TVM 编译栈做的 GPU/CPU Kernel DSL。核心特点是:

  • 用 Python 写 kernel,语法接近 Triton
  • 显式控制数据在 Global / Shared / Register 三级内存之间的流动
  • 编译时特化所有参数,运行时零开销
  • 基于 TVM,所以后端支持 CUDA、Metal、Vulkan 等

和 Triton 的主要区别:Triton 的 tile 管理是半自动的,编译器帮你决定一些细节;TileLang 要求你显式指定 tile size 和数据搬运路径,控制更细。

核心概念:Tile 和内存层级

TileLang 的名字来源于 Tile(瓦片)——把大矩阵切成固定大小的块,逐块处理。这是高性能 GPU kernel 的基础思想,但 TileLang 要求在代码里显式表达这个过程。

三级内存层级:

层级位置速度TileLang 操作
Global Memory (HBM)显卡显存最慢tl.copy()
Shared Memory (SRAM)SM 片上很快tl.load() / tl.store()
Register线程私有最快直接参与计算
graph LR
    HBM["Global Memory (HBM)<br/>tl.copy()"] --> SRAM["Shared Memory (SRAM)<br/>tl.load() / tl.store()"]
    SRAM --> REG["Register<br/>直接计算"]
    
    style HBM fill:#ffebee,stroke:#c62828
    style SRAM fill:#e3f2fd,stroke:#1565c0
    style REG fill:#e8f5e9,stroke:#2e7d32

关键规则:数据必须从 Global → Shared → Register 逐级搬运,不能跳过。这个约束看起来麻烦,但其实是帮你写出高性能 kernel 的——显式管理内存的 kernel,通常比隐式管理的快。

一个完整例子:矩阵乘法

import tilelang as tl

@tl.jit
def matmul(A: tl.Buffer, B: tl.Buffer, C: tl.Buffer):
    # 定义 tile 大小
    with tl.Tile([128, 128]):
        # 从 Global Memory 拷贝到 Shared Memory
        # 这一步是 coalesced load,利用全局内存带宽
        shared_A = tl.copy(A)
        shared_B = tl.copy(B)
        
        # 累加器放在 Register
        acc = tl.zeros([128, 128])
        
        # K 维度分块迭代
        for k in tl.range(0, K, 32):
            # 从 Shared Memory 加载到 Register
            local_A = tl.load(shared_A)
            local_B = tl.load(shared_B)
            
            # 在 Register 内计算
            acc += local_A * local_B
        
        # 写回 Global Memory
        tl.store(C, acc)

这段代码的结构很清晰:

  1. tl.copy() 负责大块数据的 Global → Shared 搬运
  2. tl.load() 负责 Shared → Register 的细粒度加载
  3. 计算在 Register 内完成
  4. tl.store() 把结果写回

编译后,tl.copy() 变成 ld.global → st.shared 的 PTX 指令,tl.load() 变成 ld.shared,没有 Python runtime 介入。

内存布局和对齐

TileLang 中 Shared Memory 的 layout 很重要,直接影响 bank conflict。

# 好的 layout:行优先,连续访问
shared_A = tl.copy(A, layout=tl.RowMajor())

# 如果需要转置,显式指定
shared_B = tl.copy(B, layout=tl.ColMajor())

TileLang 允许你在 tl.copy() 时指定 layout,编译器会根据 layout 生成无 bank conflict 的访问模式。这一点比手写 CUDA 方便,不用自己算 offset。

和 CUDA / Triton / Cutlass 的对比

维度TileLangTritonCutlassCUDA
开发速度
控制精度极高极高
性能天花板接近 Cutlass接近 Cutlass最高最高
学习曲线平缓陡峭陡峭
硬件覆盖TVM 后端NVIDIANVIDIANVIDIA

选 TileLang 的场景

  • 需要写融合算子(如 fused MLP、MoE routing),Triton 内置实现不够用
  • 需要跨平台(除了 NVIDIA,还要支持 Metal、Vulkan)
  • 对内存层级有精确控制需求,但不想写 C++ Cutlass

不选 TileLang 的场景

  • 只是写标准 GEMM,直接用 cuBLAS 最快
  • 快速验证想法,Triton 更轻量
  • 团队里没有 TVM 经验,调试编译问题成本高

TVM 到 TileLang 的编译链路

graph TD
    A[Python Kernel Code] -->|TileLang Frontend| B[Relay / Relax]
    B --> C[TensorIR / TIR]
    C -->|Schedule Transform| D[Optimized TIR]
    D -->|Codegen| E[CUDA / Metal / Vulkan]
    E --> F[Binary .so]
    
    C -.->|AutoTVM / MetaSchedule| G[Search Best Config]
    G -.-> D

TileLang 的 Python 代码先 lower 到 TVM 的图 IR,再 lower 到 TIR,经过 schedule 优化后生成目标代码。理解这个链路对调试编译错误很重要。

学习路径和踩坑记录

入门顺序

  1. 先理解 TVM 的基础概念:Relay/Relax、TIR、Schedule。TileLang 是站在 TVM 肩膀上的,不了解 TVM 的话,编译错误很难看懂
  2. 从简单 kernel 开始:向量加法、矩阵乘法,熟悉 tl.copy / tl.load / tl.store 的语义
  3. 学内存层级优化:尝试改 tile size,观察性能变化,理解为什么某个 size 更快
  4. 写融合算子:比如把两个线性层 + activation 融合成一个 kernel
  5. 看 MLC-LLM 的源码:实际项目中 TileLang 是怎么用的

常见坑

编译错误信息难读。TileLang 的 Python 前端会 lower 到 TIR,再 lower 到目标后端。如果某个操作不合法,错误信息可能是 TIR 级别的,和 Python 代码的对应关系不明显。建议先写小例子验证,再逐步扩展。

tile size 选不对性能差很远。TileLang 不会自动选最优 tile size,需要自己测或者接 AutoTVM。比如 64×64 的 tile size,性能可能只有最优的 30%,改成 128×128 后才接近 cuBLAS。

Shared Memory 容量限制。SM 的 Shared Memory 通常只有 100KB 左右(Hopper),两个 128×128 的 FP32 tile 就占 128KB,超了。需要根据实际情况调整 tile size 或用 FP16/BF16。

一个实际的融合算子例子

@tl.jit
def fused_mlp(x: tl.Buffer, w1: tl.Buffer, w2: tl.Buffer, out: tl.Buffer):
    with tl.Tile([128, 512]):
        # 加载输入和权重到 Shared Memory
        shared_x = tl.copy(x)
        shared_w1 = tl.copy(w1)
        
        # 第一个线性层
        hidden = tl.matmul(shared_x, shared_w1)
        
        # Activation(SiLU)
        hidden = tl.silu(hidden)
        
        # 第二个线性层
        shared_w2 = tl.copy(w2)
        result = tl.matmul(hidden, shared_w2)
        
        tl.store(out, result)

这个例子展示了 TileLang 的核心价值:把多个操作融合成一个 kernel,中间结果不写出到 HBM。如果用 PyTorch eager 模式,这段计算至少要 launch 4 个 kernel(linear1 → silu → linear2),每个之间都要读写 HBM。TileLang 融合后只有一个 kernel,中间数据驻留在 Shared Memory / Register。

FlashAttention 的实现

用 TileLang 写 FlashAttention 是进阶练习,能覆盖更多内存层级控制技巧。

为什么需要 FlashAttention

标准 Attention 的计算过程:

S = Q @ K^T
P = softmax(S)
O = P @ V

如果序列长度 N=32KN = 32KSSPP 的维度都是 [N,N][N, N],FP32 下需要 4GB 显存。N=1MN = 1M 时就是 4TB,不可能存进 HBM。

FlashAttention 的核心是不存中间矩阵。通过 tiling 把 Q/K/V 分块加载到 SRAM,在 SRAM 内完成 softmax 和输出计算,只把最终结果写回 HBM。

Online Softmax 的关键

FlashAttention 不能简单地对每个 block 单独做 softmax 再拼接,因为 softmax 的分母依赖于全局最大值。解决方法是online softmax:维护两个累加器,row-wise 的 max(m)和 sum(l),遇到新 block 时增量更新。

具体公式:

m_new = max(m_prev, max(S_block))
l_new = l_prev * exp(m_prev - m_new) + sum(exp(S_block - m_new))
O_new = O_prev * exp(m_prev - m_new) + exp(S_block - m_new) @ V_block

每加载一个 K/V block,就用上面的公式修正之前的输出。这样不需要知道未来的 block,可以逐块处理。

TileLang 实现

@tl.jit
def flash_attention(Q: tl.Buffer, K: tl.Buffer, V: tl.Buffer, O: tl.Buffer):
    Br = 128   # Q tile 的行数
    Bc = 128   # K/V tile 的行数
    
    with tl.Tile([Br, d]):
        # 加载 Q tile 到 Shared Memory
        q_tile = tl.copy(Q)
        
        # 累加器和 softmax 统计量放在 Register
        acc = tl.zeros([Br, d])
        m = tl.full([Br], -float('inf'))
        l = tl.zeros([Br])
        
        # 外层循环遍历 K/V
        for j in tl.range(0, seq_len, Bc):
            k_tile = tl.copy(K[j:j+Bc])
            v_tile = tl.copy(V[j:j+Bc])
            
            # S = Q @ K^T / sqrt(d)
            s = tl.matmul(q_tile, tl.transpose(k_tile)) / sqrt_d
            
            # Online softmax
            m_new = tl.max(s, axis=1)
            m_prev = m
            m = tl.maximum(m_prev, m_new)
            
            # 修正累加器
            scale_prev = tl.exp(m_prev - m)
            p = tl.exp(s - m)
            l = l * scale_prev + tl.sum(p, axis=1)
            acc = acc * scale_prev + tl.matmul(p, v_tile)
        
        # 最后归一化
        o = acc / l
        tl.store(O, o)

和矩阵乘法的区别

FlashAttention 和矩阵乘法有几个关键区别:

1. 双重循环结构

矩阵乘法是 K 维度上单纯累加:acc += A * B。FlashAttention 的 K 维度循环里,每次都要做 softmax 修正,依赖关系更复杂。

2. Shared Memory 预算更紧

一个 block 内要同时驻留 Q tile、K tile、V tile,加上中间结果 S:

Q: 128 × 64 × 2 = 16 KB (FP16)
K: 128 × 64 × 2 = 16 KB
V: 128 × 64 × 2 = 16 KB
S: 128 × 128 × 4 = 64 KB (FP32,避免精度损失)
────────────────────────────
Total: ~112 KB

Hopper 的 Shared Memory 是 228 KB,所以还能塞得下。但如果 d = 128,或者想增大 tile size,就会超预算。这时候要么减小 tile size,要么用 FP16 存 S(有精度风险)。

3. Online softmax 的数值稳定性

exp(s - m) 里的减法不能省。如果直接 exp(s),数值会爆炸。漏掉这一步会导致输出全是 NaN。

FlashAttention-2 的改进

FlashAttention-2 在算法上和 v1 没有本质区别,主要是调度优化:

  • 减少 non-matmul FLOPs:v1 里 softmax 的 online 统计和 matmul 是串行的,v2 通过更好的 warp 划分让它们更并行
  • 更好的 occupancy:通过调整 block 和 warp 的分配,让更多 warp 同时活跃

用 TileLang 写 v2 时,核心逻辑是一样的,差异主要在 tl.Tile 的参数配置和 thread binding 上。

待优化点

上面这个 FlashAttention 实现能跑通,但性能还没调到最优。主要卡在:

  1. tile size 的选择:Br=128, Bc=128 不一定是最优配置,需要根据 seq_len 和 head_dim 动态调整
  2. FP16 精度:用 FP16 存中间结果能省显存,但长序列下 softmax 的数值精度会出问题
  3. 和 TVM schedule 的交互:TileLang 的 tl.range 循环,编译器默认的调度策略可能不是最优,需要手动调 split、reorder