TileLang 学习笔记
TileLang 的定位卡在 CUDA 和 Triton 中间:比 CUDA 开发快,比 Triton 控制精度高。这篇整理核心概念和写法,方便查阅。
TileLang 是什么
TileLang 是 MLC 团队基于 TVM 编译栈做的 GPU/CPU Kernel DSL。核心特点是:
- 用 Python 写 kernel,语法接近 Triton
- 显式控制数据在 Global / Shared / Register 三级内存之间的流动
- 编译时特化所有参数,运行时零开销
- 基于 TVM,所以后端支持 CUDA、Metal、Vulkan 等
和 Triton 的主要区别:Triton 的 tile 管理是半自动的,编译器帮你决定一些细节;TileLang 要求你显式指定 tile size 和数据搬运路径,控制更细。
核心概念:Tile 和内存层级
TileLang 的名字来源于 Tile(瓦片)——把大矩阵切成固定大小的块,逐块处理。这是高性能 GPU kernel 的基础思想,但 TileLang 要求在代码里显式表达这个过程。
三级内存层级:
| 层级 | 位置 | 速度 | TileLang 操作 |
|---|---|---|---|
| Global Memory (HBM) | 显卡显存 | 最慢 | tl.copy() |
| Shared Memory (SRAM) | SM 片上 | 很快 | tl.load() / tl.store() |
| Register | 线程私有 | 最快 | 直接参与计算 |
graph LR
HBM["Global Memory (HBM)<br/>tl.copy()"] --> SRAM["Shared Memory (SRAM)<br/>tl.load() / tl.store()"]
SRAM --> REG["Register<br/>直接计算"]
style HBM fill:#ffebee,stroke:#c62828
style SRAM fill:#e3f2fd,stroke:#1565c0
style REG fill:#e8f5e9,stroke:#2e7d32
关键规则:数据必须从 Global → Shared → Register 逐级搬运,不能跳过。这个约束看起来麻烦,但其实是帮你写出高性能 kernel 的——显式管理内存的 kernel,通常比隐式管理的快。
一个完整例子:矩阵乘法
import tilelang as tl
@tl.jit
def matmul(A: tl.Buffer, B: tl.Buffer, C: tl.Buffer):
# 定义 tile 大小
with tl.Tile([128, 128]):
# 从 Global Memory 拷贝到 Shared Memory
# 这一步是 coalesced load,利用全局内存带宽
shared_A = tl.copy(A)
shared_B = tl.copy(B)
# 累加器放在 Register
acc = tl.zeros([128, 128])
# K 维度分块迭代
for k in tl.range(0, K, 32):
# 从 Shared Memory 加载到 Register
local_A = tl.load(shared_A)
local_B = tl.load(shared_B)
# 在 Register 内计算
acc += local_A * local_B
# 写回 Global Memory
tl.store(C, acc)
这段代码的结构很清晰:
tl.copy()负责大块数据的 Global → Shared 搬运tl.load()负责 Shared → Register 的细粒度加载- 计算在 Register 内完成
tl.store()把结果写回
编译后,tl.copy() 变成 ld.global → st.shared 的 PTX 指令,tl.load() 变成 ld.shared,没有 Python runtime 介入。
内存布局和对齐
TileLang 中 Shared Memory 的 layout 很重要,直接影响 bank conflict。
# 好的 layout:行优先,连续访问
shared_A = tl.copy(A, layout=tl.RowMajor())
# 如果需要转置,显式指定
shared_B = tl.copy(B, layout=tl.ColMajor())
TileLang 允许你在 tl.copy() 时指定 layout,编译器会根据 layout 生成无 bank conflict 的访问模式。这一点比手写 CUDA 方便,不用自己算 offset。
和 CUDA / Triton / Cutlass 的对比
| 维度 | TileLang | Triton | Cutlass | CUDA |
|---|---|---|---|---|
| 开发速度 | 中 | 快 | 慢 | 慢 |
| 控制精度 | 高 | 中 | 极高 | 极高 |
| 性能天花板 | 接近 Cutlass | 接近 Cutlass | 最高 | 最高 |
| 学习曲线 | 中 | 平缓 | 陡峭 | 陡峭 |
| 硬件覆盖 | TVM 后端 | NVIDIA | NVIDIA | NVIDIA |
选 TileLang 的场景:
- 需要写融合算子(如 fused MLP、MoE routing),Triton 内置实现不够用
- 需要跨平台(除了 NVIDIA,还要支持 Metal、Vulkan)
- 对内存层级有精确控制需求,但不想写 C++ Cutlass
不选 TileLang 的场景:
- 只是写标准 GEMM,直接用 cuBLAS 最快
- 快速验证想法,Triton 更轻量
- 团队里没有 TVM 经验,调试编译问题成本高
TVM 到 TileLang 的编译链路
graph TD
A[Python Kernel Code] -->|TileLang Frontend| B[Relay / Relax]
B --> C[TensorIR / TIR]
C -->|Schedule Transform| D[Optimized TIR]
D -->|Codegen| E[CUDA / Metal / Vulkan]
E --> F[Binary .so]
C -.->|AutoTVM / MetaSchedule| G[Search Best Config]
G -.-> D
TileLang 的 Python 代码先 lower 到 TVM 的图 IR,再 lower 到 TIR,经过 schedule 优化后生成目标代码。理解这个链路对调试编译错误很重要。
学习路径和踩坑记录
入门顺序
- 先理解 TVM 的基础概念:Relay/Relax、TIR、Schedule。TileLang 是站在 TVM 肩膀上的,不了解 TVM 的话,编译错误很难看懂
- 从简单 kernel 开始:向量加法、矩阵乘法,熟悉
tl.copy/tl.load/tl.store的语义 - 学内存层级优化:尝试改 tile size,观察性能变化,理解为什么某个 size 更快
- 写融合算子:比如把两个线性层 + activation 融合成一个 kernel
- 看 MLC-LLM 的源码:实际项目中 TileLang 是怎么用的
常见坑
编译错误信息难读。TileLang 的 Python 前端会 lower 到 TIR,再 lower 到目标后端。如果某个操作不合法,错误信息可能是 TIR 级别的,和 Python 代码的对应关系不明显。建议先写小例子验证,再逐步扩展。
tile size 选不对性能差很远。TileLang 不会自动选最优 tile size,需要自己测或者接 AutoTVM。比如 64×64 的 tile size,性能可能只有最优的 30%,改成 128×128 后才接近 cuBLAS。
Shared Memory 容量限制。SM 的 Shared Memory 通常只有 100KB 左右(Hopper),两个 128×128 的 FP32 tile 就占 128KB,超了。需要根据实际情况调整 tile size 或用 FP16/BF16。
一个实际的融合算子例子
@tl.jit
def fused_mlp(x: tl.Buffer, w1: tl.Buffer, w2: tl.Buffer, out: tl.Buffer):
with tl.Tile([128, 512]):
# 加载输入和权重到 Shared Memory
shared_x = tl.copy(x)
shared_w1 = tl.copy(w1)
# 第一个线性层
hidden = tl.matmul(shared_x, shared_w1)
# Activation(SiLU)
hidden = tl.silu(hidden)
# 第二个线性层
shared_w2 = tl.copy(w2)
result = tl.matmul(hidden, shared_w2)
tl.store(out, result)
这个例子展示了 TileLang 的核心价值:把多个操作融合成一个 kernel,中间结果不写出到 HBM。如果用 PyTorch eager 模式,这段计算至少要 launch 4 个 kernel(linear1 → silu → linear2),每个之间都要读写 HBM。TileLang 融合后只有一个 kernel,中间数据驻留在 Shared Memory / Register。
FlashAttention 的实现
用 TileLang 写 FlashAttention 是进阶练习,能覆盖更多内存层级控制技巧。
为什么需要 FlashAttention
标准 Attention 的计算过程:
S = Q @ K^T
P = softmax(S)
O = P @ V
如果序列长度 , 和 的维度都是 ,FP32 下需要 4GB 显存。 时就是 4TB,不可能存进 HBM。
FlashAttention 的核心是不存中间矩阵。通过 tiling 把 Q/K/V 分块加载到 SRAM,在 SRAM 内完成 softmax 和输出计算,只把最终结果写回 HBM。
Online Softmax 的关键
FlashAttention 不能简单地对每个 block 单独做 softmax 再拼接,因为 softmax 的分母依赖于全局最大值。解决方法是online softmax:维护两个累加器,row-wise 的 max(m)和 sum(l),遇到新 block 时增量更新。
具体公式:
m_new = max(m_prev, max(S_block))
l_new = l_prev * exp(m_prev - m_new) + sum(exp(S_block - m_new))
O_new = O_prev * exp(m_prev - m_new) + exp(S_block - m_new) @ V_block
每加载一个 K/V block,就用上面的公式修正之前的输出。这样不需要知道未来的 block,可以逐块处理。
TileLang 实现
@tl.jit
def flash_attention(Q: tl.Buffer, K: tl.Buffer, V: tl.Buffer, O: tl.Buffer):
Br = 128 # Q tile 的行数
Bc = 128 # K/V tile 的行数
with tl.Tile([Br, d]):
# 加载 Q tile 到 Shared Memory
q_tile = tl.copy(Q)
# 累加器和 softmax 统计量放在 Register
acc = tl.zeros([Br, d])
m = tl.full([Br], -float('inf'))
l = tl.zeros([Br])
# 外层循环遍历 K/V
for j in tl.range(0, seq_len, Bc):
k_tile = tl.copy(K[j:j+Bc])
v_tile = tl.copy(V[j:j+Bc])
# S = Q @ K^T / sqrt(d)
s = tl.matmul(q_tile, tl.transpose(k_tile)) / sqrt_d
# Online softmax
m_new = tl.max(s, axis=1)
m_prev = m
m = tl.maximum(m_prev, m_new)
# 修正累加器
scale_prev = tl.exp(m_prev - m)
p = tl.exp(s - m)
l = l * scale_prev + tl.sum(p, axis=1)
acc = acc * scale_prev + tl.matmul(p, v_tile)
# 最后归一化
o = acc / l
tl.store(O, o)
和矩阵乘法的区别
FlashAttention 和矩阵乘法有几个关键区别:
1. 双重循环结构
矩阵乘法是 K 维度上单纯累加:acc += A * B。FlashAttention 的 K 维度循环里,每次都要做 softmax 修正,依赖关系更复杂。
2. Shared Memory 预算更紧
一个 block 内要同时驻留 Q tile、K tile、V tile,加上中间结果 S:
Q: 128 × 64 × 2 = 16 KB (FP16)
K: 128 × 64 × 2 = 16 KB
V: 128 × 64 × 2 = 16 KB
S: 128 × 128 × 4 = 64 KB (FP32,避免精度损失)
────────────────────────────
Total: ~112 KB
Hopper 的 Shared Memory 是 228 KB,所以还能塞得下。但如果 d = 128,或者想增大 tile size,就会超预算。这时候要么减小 tile size,要么用 FP16 存 S(有精度风险)。
3. Online softmax 的数值稳定性
exp(s - m) 里的减法不能省。如果直接 exp(s),数值会爆炸。漏掉这一步会导致输出全是 NaN。
FlashAttention-2 的改进
FlashAttention-2 在算法上和 v1 没有本质区别,主要是调度优化:
- 减少 non-matmul FLOPs:v1 里 softmax 的 online 统计和 matmul 是串行的,v2 通过更好的 warp 划分让它们更并行
- 更好的 occupancy:通过调整 block 和 warp 的分配,让更多 warp 同时活跃
用 TileLang 写 v2 时,核心逻辑是一样的,差异主要在 tl.Tile 的参数配置和 thread binding 上。
待优化点
上面这个 FlashAttention 实现能跑通,但性能还没调到最优。主要卡在:
- tile size 的选择:Br=128, Bc=128 不一定是最优配置,需要根据 seq_len 和 head_dim 动态调整
- FP16 精度:用 FP16 存中间结果能省显存,但长序列下 softmax 的数值精度会出问题
- 和 TVM schedule 的交互:TileLang 的
tl.range循环,编译器默认的调度策略可能不是最优,需要手动调 split、reorder