Sealessland logo Sealessland
CNSS 2025

Multi-head Attention

你需要写一个运行在 GPU 上的程序,实现多头注意力机制。

HPC

你需要写一个运行在 GPU 上的程序,实现多头注意力机制。具体来说,给出 Q (Query), K (Key), V (Value) 三个形状为 N * d_model 的矩阵,你需要计算:

MultiHeadi=Concat(head1,head2...headN)MultiHead_i = Concat(head_1, head_2...head_N)

同时:

headi=softmax(QiKiTdk)Vihead_i = softmax(\frac{Q_i K_i^T}{\sqrt{d_k}})V_i

Example :

N = 2, d_model = 4, h = 2

Q=[1.00.02.03.04.05.06.07.0]Q = \begin{bmatrix} 1.0 & 0.0 & 2.0 & 3.0 \\ 4.0 & 5.0 & 6.0 & 7.0 \end{bmatrix} K=[1.02.03.04.05.06.07.08.0]K = \begin{bmatrix} 1.0 & 2.0 & 3.0 & 4.0 \\ 5.0 & 6.0 & 7.0 & 8.0 \end{bmatrix} V=[0.51.01.52.02.53.03.54.0]V = \begin{bmatrix} 0.5 & 1.0 & 1.5 & 2.0 \\ 2.5 & 3.0 & 3.5 & 4.0 \end{bmatrix} Output=[2.392.893.504.002.503.003.504.00]Output = \begin{bmatrix} 2.39 & 2.89 & 3.50 & 4.00 \\ 2.50 & 3.00 & 3.50 & 4.00 \end{bmatrix}

数据限制:

  • 1 ≤ N ≤ 1000
  • 2 ≤ d_model ≤ 1024
  • 1 ≤ h ≤ d_model
  • d_model % h == 0
  • -10.0 ≤ values ≤ 10.0

要求:

  • 不能使用外部库
  • 不允许修改 Solve 函数

Answer:

__device__ inline float warpReduceMax(float val) {  
    for (int step = 32 / 2; step > 0; step /= 2) {  
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, step));  
    }  
    return __shfl_sync(0xffffffff, val, 0);  
}  
  
__device__ inline float warpReduceSum(float val) {  
    for (int step = 32 / 2; step > 0; step /= 2) {  
        val += __shfl_down_sync(0xffffffff, val, step);  
    }  
    return __shfl_sync(0xffffffff, val, 0);  
}

__global__ void attention_fused_kernel(  
    const float* Q, const float* K, const float* V,  
    float* output,  
    int N, int d_model, int h  
) {  
    int head = blockIdx.z;  
    int row = blockIdx.y;  
  
    if (head >= h || row >= N) return;  
  
    const int w = d_model / h;  
    const float scale = 1.0f / sqrtf((float)w);  
    int col_off = head * w;  
  
    extern __shared__ float smem[];  
    float* attn_scores = smem;  
    // 预留足够空间给 warp reduce (最多 32 个 warp)    float* warp_data = smem + N;  
  
    int tid = threadIdx.x;  
    int lane_id = tid % 32;  
    int warp_id = tid / 32;  
    int num_warps = (blockDim.x + 31) / 32;  
  
    // === 步骤 1: Q * K^T ===    for (int j = tid; j < N; j += blockDim.x) {  
        float acc = 0.0f;  
        const float* qrow = Q + row * d_model + col_off;  
        const float* krow = K + j * d_model + col_off;  
  
        for (int k = 0; k < w; ++k) {  
            acc += qrow[k] * krow[k];  
        }  
        attn_scores[j] = acc * scale;  
    }  
    __syncthreads();  
  
    // === 步骤 2: Softmax ===    // 2.1 找最大值  
    float thread_max = -INFINITY;  
    for (int j = tid; j < N; j += blockDim.x) {  
        thread_max = fmaxf(thread_max, attn_scores[j]);  
    }  
  
    float warp_max = warpReduceMax(thread_max);  
  
    if (lane_id == 0) {  
        warp_data[warp_id] = warp_max;  
    }  
    __syncthreads();  
  
    // 第一个 warp 完成最终 reduce    float block_max = -INFINITY;  
    if (warp_id == 0) {  
        float val = (lane_id < num_warps) ? warp_data[lane_id] : -INFINITY;  
        block_max = warpReduceMax(val);  
        // 广播给整个 block        if (lane_id == 0) {  
            warp_data[0] = block_max;  
        }  
    }  
    __syncthreads();  
    block_max = warp_data[0];  
  
    // 2.2 计算 exp 并求和  
    float thread_sum = 0.0f;  
    for (int j = tid; j < N; j += blockDim.x) {  
        float exp_val = expf(attn_scores[j] - block_max);  
        attn_scores[j] = exp_val;  
        thread_sum += exp_val;  
    }  
  
    float warp_sum = warpReduceSum(thread_sum);  
  
    if (lane_id == 0) {  
        warp_data[warp_id] = warp_sum;  
    }  
    __syncthreads();  
  
    // 最终 sum reduce    float block_sum = 0.0f;  
    if (warp_id == 0) {  
        float val = (lane_id < num_warps) ? warp_data[lane_id] : 0.0f;  
        block_sum = warpReduceSum(val);  
        // 广播给整个 block        if (lane_id == 0) {  
            warp_data[0] = block_sum;  
        }  
    }  
    __syncthreads();  
    block_sum = warp_data[0];  
  
    // 防止除零  
    if (block_sum == 0.0f) {  
        block_sum = 1e-10f;  
    }  
  
    // 2.3 归一化  
    for (int j = tid; j < N; j += blockDim.x) {  
        attn_scores[j] /= block_sum;  
    }  
    __syncthreads();  
  
    // === 步骤 3: Attention * V ===    for (int col = tid; col < w; col += blockDim.x) {  
        float acc = 0.0f;  
        for (int j = 0; j < N; ++j) {  
            acc += attn_scores[j] * V[j * d_model + col_off + col];  
        }  
        output[row * d_model + col_off + col] = acc;  
    }  
}

其实与之前的题相比,这里其实主要的难度在维度划分或者说数据依赖分析,矩阵乘法,softmax都是比较典型的划分,根据数据格式就能方便的划分清楚,但是attention的数据划分到每个头上是交错的,所以其实是一种偏向逻辑维度的划分。 那我们直接拉成(1,N,h)的griddim,转化为h个头的N个softmax计算,Q * K^T 部分我们直接从主存拿跨行的数据(flash_attn),那么就可以比较简单完成划分,每个block 负责一行的softmax,那其实就非常方便了,直接reduce两次::: code-group

//防数据规模大于blockdim,reduce到blockdim的数据规模
float thread_max = -INFINITY;
for (int j = tid; j < N; j += blockDim.x) { // [!code focus]
    thread_max = fmaxf(thread_max, attn_scores[j]);
}
//cuda执行模型每个thread 都会单独执行
//然后通过shuffle进行warp协作,每个warp都拿到warp内最大值
//float的声明每个thread都进行过,所以当数据不是32整数倍的时候行为也是正确的
float warp_max = warpReduceMax(thread_max); // [!code focus]
//reduce的结果是每个warp的第一个线程持有最终的结果
//写回shared memory间接广播到整个block
if (lane_id == 0) { // [!code focus]
    warp_data[warp_id] = warp_max;
}
__syncthreads(); // [!code warning]
//block有一个约束:不超过1024,所以两次warp的32个数据的reduce 必然能拿到我们想要的数据
// 第一个 warp 完成最终 reduce
float block_max = -INFINITY;
if (warp_id == 0) { // [!code focus]
    float val = (lane_id < num_warps) ? warp_data[lane_id] : -INFINITY;
    block_max = warpReduceMax(val); // [!code focus]
}
// ... from previous step ...
if (warp_id == 0) {
    // ...
    // 广播给整个 block
    if (lane_id == 0) { // [!code focus]
        warp_data[0] = block_max;
    }
}
__syncthreads(); // [!code warning]
block_max = warp_data[0]; // [!code focus]

warpReduceMax的逻辑是这样,reduce sum的时候一些值初始化成0就行了,不多赘述 然后拿到最终的reduce结果softmax一下就可以了,具体实现的难点我个人感觉是上面提到的维度划分和偏移量计算,因为前面提到的问题导致数据不连续就有比较复杂的偏移量计算,然后经常大脑过载取不到数据 tips:一些常量可以打上const,gpu有专有的内存空间用来存放常量

GPU 常量内存优化 一些在 Kernel 执行期间不变的变量可以标记为 const可以优化到常量内存里

逻辑维度的划分与映射

// Kernel 启动配置
dim3 gridDim(1, N, h); // [!code focus]
// ...
attention_fused_kernel<<<gridDim, blockDim, smem_size, stream>>>(...);
__global__ void attention_fused_kernel(...) {  
    // blockIdx.z 直接映射到 Head 的索引
    int head = blockIdx.z; // [!code focus]
    // blockIdx.y 直接映射到 Query Token (行) 的索引
    int row = blockIdx.y;  // [!code focus]
    // ...
}

偏移量计算

// head_dimension, 即 d_k
const int w = d_model / h;
// ...
// Head offset within a row
int col_off = head * w; // [!code focus]
// Q + row * d_model            -> 移动到第 row 行的行首
//             + col_off        -> 再从行首偏移到当前 head 的数据起点
const float* qrow = Q + row * d_model + col_off;

// 在 Q * K^T 循环中
const float* krow = K + j * d_model + col_off;

// 在 Attention * V 循环中
// V[j * d_model + col_off + col] 的基地址部分
V + j * d_model + col_off

线程工作映射

// tid 负责计算第 j 个注意力分数
for (int j = tid; j < N; j += blockDim.x) { // [!code focus]
    // ... 内积计算 ...
    attn_scores[j] = acc * scale;
}
// tid 负责计算输出向量的第 col 个分量
for (int col = tid; col < w; col += blockDim.x) { // [!code focus]
    // ... 加权求和计算 ...
    output[row * d_model + col_off + col] = acc;
}

差不多就是这些了