Multi-head Attention
你需要写一个运行在 GPU 上的程序,实现多头注意力机制。
HPC
你需要写一个运行在 GPU 上的程序,实现多头注意力机制。具体来说,给出 Q (Query), K (Key), V (Value) 三个形状为 N * d_model 的矩阵,你需要计算:
同时:
Example :
N = 2, d_model = 4, h = 2
数据限制:
1 ≤ N ≤ 10002 ≤ d_model ≤ 10241 ≤ h ≤ d_modeld_model % h == 0-10.0 ≤ values ≤ 10.0
要求:
- 不能使用外部库
- 不允许修改
Solve函数
Answer:
__device__ inline float warpReduceMax(float val) {
for (int step = 32 / 2; step > 0; step /= 2) {
val = fmaxf(val, __shfl_down_sync(0xffffffff, val, step));
}
return __shfl_sync(0xffffffff, val, 0);
}
__device__ inline float warpReduceSum(float val) {
for (int step = 32 / 2; step > 0; step /= 2) {
val += __shfl_down_sync(0xffffffff, val, step);
}
return __shfl_sync(0xffffffff, val, 0);
}
__global__ void attention_fused_kernel(
const float* Q, const float* K, const float* V,
float* output,
int N, int d_model, int h
) {
int head = blockIdx.z;
int row = blockIdx.y;
if (head >= h || row >= N) return;
const int w = d_model / h;
const float scale = 1.0f / sqrtf((float)w);
int col_off = head * w;
extern __shared__ float smem[];
float* attn_scores = smem;
// 预留足够空间给 warp reduce (最多 32 个 warp) float* warp_data = smem + N;
int tid = threadIdx.x;
int lane_id = tid % 32;
int warp_id = tid / 32;
int num_warps = (blockDim.x + 31) / 32;
// === 步骤 1: Q * K^T === for (int j = tid; j < N; j += blockDim.x) {
float acc = 0.0f;
const float* qrow = Q + row * d_model + col_off;
const float* krow = K + j * d_model + col_off;
for (int k = 0; k < w; ++k) {
acc += qrow[k] * krow[k];
}
attn_scores[j] = acc * scale;
}
__syncthreads();
// === 步骤 2: Softmax === // 2.1 找最大值
float thread_max = -INFINITY;
for (int j = tid; j < N; j += blockDim.x) {
thread_max = fmaxf(thread_max, attn_scores[j]);
}
float warp_max = warpReduceMax(thread_max);
if (lane_id == 0) {
warp_data[warp_id] = warp_max;
}
__syncthreads();
// 第一个 warp 完成最终 reduce float block_max = -INFINITY;
if (warp_id == 0) {
float val = (lane_id < num_warps) ? warp_data[lane_id] : -INFINITY;
block_max = warpReduceMax(val);
// 广播给整个 block if (lane_id == 0) {
warp_data[0] = block_max;
}
}
__syncthreads();
block_max = warp_data[0];
// 2.2 计算 exp 并求和
float thread_sum = 0.0f;
for (int j = tid; j < N; j += blockDim.x) {
float exp_val = expf(attn_scores[j] - block_max);
attn_scores[j] = exp_val;
thread_sum += exp_val;
}
float warp_sum = warpReduceSum(thread_sum);
if (lane_id == 0) {
warp_data[warp_id] = warp_sum;
}
__syncthreads();
// 最终 sum reduce float block_sum = 0.0f;
if (warp_id == 0) {
float val = (lane_id < num_warps) ? warp_data[lane_id] : 0.0f;
block_sum = warpReduceSum(val);
// 广播给整个 block if (lane_id == 0) {
warp_data[0] = block_sum;
}
}
__syncthreads();
block_sum = warp_data[0];
// 防止除零
if (block_sum == 0.0f) {
block_sum = 1e-10f;
}
// 2.3 归一化
for (int j = tid; j < N; j += blockDim.x) {
attn_scores[j] /= block_sum;
}
__syncthreads();
// === 步骤 3: Attention * V === for (int col = tid; col < w; col += blockDim.x) {
float acc = 0.0f;
for (int j = 0; j < N; ++j) {
acc += attn_scores[j] * V[j * d_model + col_off + col];
}
output[row * d_model + col_off + col] = acc;
}
}
其实与之前的题相比,这里其实主要的难度在维度划分或者说数据依赖分析,矩阵乘法,softmax都是比较典型的划分,根据数据格式就能方便的划分清楚,但是attention的数据划分到每个头上是交错的,所以其实是一种偏向逻辑维度的划分。 那我们直接拉成(1,N,h)的griddim,转化为h个头的N个softmax计算,Q * K^T 部分我们直接从主存拿跨行的数据(flash_attn),那么就可以比较简单完成划分,每个block 负责一行的softmax,那其实就非常方便了,直接reduce两次::: code-group
//防数据规模大于blockdim,reduce到blockdim的数据规模
float thread_max = -INFINITY;
for (int j = tid; j < N; j += blockDim.x) { // [!code focus]
thread_max = fmaxf(thread_max, attn_scores[j]);
}
//cuda执行模型每个thread 都会单独执行
//然后通过shuffle进行warp协作,每个warp都拿到warp内最大值
//float的声明每个thread都进行过,所以当数据不是32整数倍的时候行为也是正确的
float warp_max = warpReduceMax(thread_max); // [!code focus]
//reduce的结果是每个warp的第一个线程持有最终的结果
//写回shared memory间接广播到整个block
if (lane_id == 0) { // [!code focus]
warp_data[warp_id] = warp_max;
}
__syncthreads(); // [!code warning]
//block有一个约束:不超过1024,所以两次warp的32个数据的reduce 必然能拿到我们想要的数据
// 第一个 warp 完成最终 reduce
float block_max = -INFINITY;
if (warp_id == 0) { // [!code focus]
float val = (lane_id < num_warps) ? warp_data[lane_id] : -INFINITY;
block_max = warpReduceMax(val); // [!code focus]
}
// ... from previous step ...
if (warp_id == 0) {
// ...
// 广播给整个 block
if (lane_id == 0) { // [!code focus]
warp_data[0] = block_max;
}
}
__syncthreads(); // [!code warning]
block_max = warp_data[0]; // [!code focus]
warpReduceMax的逻辑是这样,reduce sum的时候一些值初始化成0就行了,不多赘述 然后拿到最终的reduce结果softmax一下就可以了,具体实现的难点我个人感觉是上面提到的维度划分和偏移量计算,因为前面提到的问题导致数据不连续就有比较复杂的偏移量计算,然后经常大脑过载取不到数据 tips:一些常量可以打上const,gpu有专有的内存空间用来存放常量
GPU 常量内存优化 一些在 Kernel 执行期间不变的变量可以标记为 const可以优化到常量内存里
逻辑维度的划分与映射
// Kernel 启动配置
dim3 gridDim(1, N, h); // [!code focus]
// ...
attention_fused_kernel<<<gridDim, blockDim, smem_size, stream>>>(...);
__global__ void attention_fused_kernel(...) {
// blockIdx.z 直接映射到 Head 的索引
int head = blockIdx.z; // [!code focus]
// blockIdx.y 直接映射到 Query Token (行) 的索引
int row = blockIdx.y; // [!code focus]
// ...
}
偏移量计算
// head_dimension, 即 d_k
const int w = d_model / h;
// ...
// Head offset within a row
int col_off = head * w; // [!code focus]
// Q + row * d_model -> 移动到第 row 行的行首
// + col_off -> 再从行首偏移到当前 head 的数据起点
const float* qrow = Q + row * d_model + col_off;
// 在 Q * K^T 循环中
const float* krow = K + j * d_model + col_off;
// 在 Attention * V 循环中
// V[j * d_model + col_off + col] 的基地址部分
V + j * d_model + col_off
线程工作映射
// tid 负责计算第 j 个注意力分数
for (int j = tid; j < N; j += blockDim.x) { // [!code focus]
// ... 内积计算 ...
attn_scores[j] = acc * scale;
}
// tid 负责计算输出向量的第 col 个分量
for (int col = tid; col < w; col += blockDim.x) { // [!code focus]
// ... 加权求和计算 ...
output[row * d_model + col_off + col] = acc;
}
差不多就是这些了