5.1 CUDA Softmax 朴素实现优化
本文从朴素实现出发,逐步引入 Safe Softmax、Block 级并行、Warp Shuffle、向量化访存等优化手段,带你写出既正确又高效的 Softmax Kernel。
Softmax 是 Transformer 中 Attention 计算的核心组件,看似简单的公式背后隐藏着数值溢出陷阱。本文从朴素实现暴露的数值问题出发,逐步引入 Safe Softmax、Block 级并行、Warp Shuffle、向量化访存等优化手段,每一步都有原理分析和性能对比,帮你写出既正确又高效的 Softmax Kernel。
📑 目录
- 1. Softmax 基础与数值问题
- 2. 版本 V0:朴素实现(单线程处理一行)
- 3. 版本 V1:Safe Softmax 三遍扫描
- 4. 版本 V2:Warp Shuffle 优化规约
- 5. 版本 V3:向量化加载 float4
- 6. 版本 V4:两遍融合 Kernel
- 7. 性能对比与工程建议
- 总结
- 自我检验清单
- 参考资料
1. Softmax 基础与数值问题
1.1 什么是 Softmax
假设你要从一堆分数中挑选出”每个分数相对于总体的重要程度”——高分者占比大、低分者占比小、且所有占比加起来恰好为 1。Softmax 就是干这件事的函数:它把一组任意实数”压缩”成一个概率分布。
数学定义:对于输入向量 ,Softmax 输出为:
在 Transformer 的 Self-Attention 中,Softmax 作用于 矩阵的每一行,将原始的注意力分数转化为归一化权重。其输入维度通常是序列长度 (如 2048、4096 甚至 128K),这使得 Softmax 的高效实现对模型整体性能有直接影响。
1.2 数值溢出问题
Softmax 的公式涉及指数运算 ,这在浮点数表示中极易溢出:
- FP32:最大可表示值约为 ,对应 。当 时, 就变成
+inf - FP16:最大值约为 ,对应 。当 时即溢出
- 下溢同样危险:当 很小(如 )时,,分母可能下溢为 0,导致
NaN
在实际的 Attention 计算中, 的值可以轻松超过 88(尤其是在大 且未做缩放的情况下),朴素实现必然崩溃。
1.3 解决方案:减最大值技巧
数学上可以证明,对输入向量的每个元素减去同一个常数 ,不改变 Softmax 的输出:
选 后,所有指数的参数都 ,结果落在 之间,彻底消除上溢风险。这就是 Safe Softmax(数值稳定 Softmax)的核心。
1.4 计算流程与性能瓶颈
Safe Softmax 的计算自然分为三遍扫描(Three-Pass):
| 遍次 | 操作 | 数学表达 |
|---|---|---|
| 第 1 遍 | 求行最大值 | |
| 第 2 遍 | 指数求和 | |
| 第 3 遍 | 归一化 |
每一遍都需要读取整行数据,对于长度为 的向量,总共需要 次全局内存读取。Softmax 的计算强度很低(每次读取对应 1-2 次浮点运算),属于典型的 Memory-Bound 操作,优化目标是减少内存访问次数和提升带宽利用率。
1.5 测试环境
本文所有代码使用 CUDA 12.x 编写,测试在 A100 80GB SXM4 上进行:
| 指标 | 数值 |
|---|---|
| 理论内存带宽 | 2 TB/s |
| SM 数量 | 108 |
| 每 SM Shared Memory | 164 KB |
测试配置:矩阵形状为 ,即 4096 行、每行 4096 个元素(模拟典型 Attention 场景)。
2. 版本 V0:朴素实现(单线程处理一行)
2.1 算法思路
最简单的并行化策略:每个线程独立处理输入矩阵的一行,串行完成该行的 max、exp-sum、normalize 三步。行与行之间完全独立,天然并行。
2.2 Kernel 实现
// V0: 朴素实现,一个线程处理一行
__global__ void softmax_v0(float* input, float* output, int M, int N) {
int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= M) return;
float* x = input + row * N;
float* y = output + row * N;
// Pass 1: 求最大值
float max_val = -INFINITY;
for (int i = 0; i < N; i++) {
max_val = fmaxf(max_val, x[i]);
}
// Pass 2: 指数求和
float sum = 0.0f;
for (int i = 0; i < N; i++) {
sum += expf(x[i] - max_val);
}
// Pass 3: 归一化
float inv_sum = 1.0f / sum;
for (int i = 0; i < N; i++) {
y[i] = expf(x[i] - max_val) * inv_sum;
}
}
2.3 性能问题分析
⚠️ 注意:V0 虽然正确,但性能极差,原因有三:
- 并行度不足:每行只有 1 个线程, 个元素被串行处理。GPU 的数千个 CUDA Core 绝大部分在空闲
- 非合并访存:同一 Warp 内的 32 个线程处理 32 个不同行,它们在内循环中访问的地址相差 字节(一整行的跨度),造成严重的非合并访存
- 冗余计算:
expf(x[i] - max_val)在 Pass 2 和 Pass 3 中重复计算了两次
V0 实测带宽利用率:约 5%(~102 GB/s)——基本是”可以跑通”但不可用的状态。
3. 版本 V1:Safe Softmax 三遍扫描
3.1 改进思路
V0 的核心问题是并行度不够——一行 个元素只用 1 个线程处理。解决方案是让一个 Block 内的多个线程协作处理同一行。每个线程负责一行中的一段数据,通过 Shared Memory 做并行规约求出 max 和 sum。
这等于把 Softmax 分解为两个 Reduce 操作(max-reduce + sum-reduce)加一个 element-wise 操作(normalize),每个 Reduce 都使用我们在 Reduce 优化文章中学到的技术。
3.2 Kernel 实现
// V1: 一个 Block 处理一行,Block 内并行规约
__global__ void softmax_v1(float* input, float* output, int M, int N) {
extern __shared__ float smem[]; // 用于规约
int row = blockIdx.x;
int tid = threadIdx.x;
float* x = input + row * N;
float* y = output + row * N;
// Pass 1: 并行求最大值
float max_val = -INFINITY;
for (int i = tid; i < N; i += blockDim.x) {
max_val = fmaxf(max_val, x[i]);
}
smem[tid] = max_val;
__syncthreads();
// Shared Memory 规约求全局最大值
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
smem[tid] = fmaxf(smem[tid], smem[tid + s]);
}
__syncthreads();
}
max_val = smem[0];
__syncthreads();
// Pass 2: 并行求指数和
float sum = 0.0f;
for (int i = tid; i < N; i += blockDim.x) {
sum += expf(x[i] - max_val);
}
smem[tid] = sum;
__syncthreads();
// Shared Memory 规约求总和
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
smem[tid] += smem[tid + s];
}
__syncthreads();
}
sum = smem[0];
__syncthreads();
// Pass 3: 归一化写出
float inv_sum = 1.0f / sum;
for (int i = tid; i < N; i += blockDim.x) {
y[i] = expf(x[i] - max_val) * inv_sum;
}
}
调用方式:
int block_size = 256;
int smem_size = block_size * sizeof(float);
softmax_v1<<<M, block_size, smem_size>>>(d_input, d_output, M, N);
3.3 性能分析
相比 V0 的提升来源:
- 并行度大幅提升:256 个线程协作处理一行,每线程只需处理 个元素
- 合并访存:同一 Block 内的相邻线程访问行内的相邻地址(stride=1),形成完美的合并访存
仍存在的问题:
- Shared Memory 规约需要多轮
__syncthreads() - 三遍扫描意味着输入数据被读取了 3 次
V1 实测带宽利用率:约 35%(~714 GB/s),相比 V0 提升约 7 倍。
4. 版本 V2:Warp Shuffle 优化规约
4.1 改进思路
V1 的规约完全依赖 Shared Memory,存在两个问题:一是数据需要”寄存器 → SMEM → 寄存器”往返,路径偏长;二是每一轮规约都需要 __syncthreads(),在线程数已经较少的后期阶段同步开销占比偏高。
从 Reduce 优化的经验来看,Warp Shuffle(__shfl_down_sync)可以在寄存器间直接交换数据,延迟更低、无 Bank Conflict、Warp 内天然同步。我们将规约改为”Warp 内 Shuffle + Warp 间 Shared Memory”的两级结构:先在每个 Warp 内用 Shuffle 规约出局部结果,再把每个 Warp 的结果写到一个 32 元素的 SMEM 数组,由第一个 Warp 做最终规约。
4.2 Kernel 实现
// Warp 内规约:Max
__device__ float warpReduceMax(float val) {
for (int offset = 16; offset > 0; offset >>= 1) {
val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
}
return val;
}
// Warp 内规约:Sum
__device__ float warpReduceSum(float val) {
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return val;
}
// 两级规约:Block 内 Max
__device__ float blockReduceMaxShuffle(float val) {
__shared__ float warp_max[32];
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;
val = warpReduceMax(val);
if (lane == 0) warp_max[wid] = val;
__syncthreads();
int num_warps = blockDim.x / 32;
val = (lane < num_warps) ? warp_max[lane] : -INFINITY;
if (wid == 0) val = warpReduceMax(val);
// 广播结果给所有线程
__shared__ float block_result;
if (threadIdx.x == 0) block_result = val;
__syncthreads();
return block_result;
}
// 两级规约:Block 内 Sum
__device__ float blockReduceSumShuffle(float val) {
__shared__ float warp_sum[32];
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;
val = warpReduceSum(val);
if (lane == 0) warp_sum[wid] = val;
__syncthreads();
int num_warps = blockDim.x / 32;
val = (lane < num_warps) ? warp_sum[lane] : 0.0f;
if (wid == 0) val = warpReduceSum(val);
__shared__ float block_result;
if (threadIdx.x == 0) block_result = val;
__syncthreads();
return block_result;
}
// V2: Warp Shuffle 两级规约
__global__ void softmax_v2(float* input, float* output, int M, int N) {
int row = blockIdx.x;
int tid = threadIdx.x;
float* x = input + row * N;
float* y = output + row * N;
// Pass 1: 求行最大值
float local_max = -INFINITY;
for (int i = tid; i < N; i += blockDim.x) {
local_max = fmaxf(local_max, x[i]);
}
float max_val = blockReduceMaxShuffle(local_max);
// Pass 2: 求指数和
float local_sum = 0.0f;
for (int i = tid; i < N; i += blockDim.x) {
local_sum += expf(x[i] - max_val);
}
float sum = blockReduceSumShuffle(local_sum);
// Pass 3: 归一化
float inv_sum = 1.0f / sum;
for (int i = tid; i < N; i += blockDim.x) {
y[i] = expf(x[i] - max_val) * inv_sum;
}
}
4.3 为什么 Warp Shuffle 更快
| 对比维度 | Shared Memory 规约 | Warp Shuffle 规约 |
|---|---|---|
| 数据路径 | 寄存器 → SMEM → 寄存器 | 寄存器 → 寄存器 |
| Bank Conflict | 可能存在 | 不存在 |
| 同步开销 | 每轮需要 __syncthreads() | Warp 内天然同步 |
| 延迟 | ~20-30 cycles | ~5 cycles |
💡 提示:Warp Shuffle 的延迟约为 Shared Memory 的 1/4 到 1/6(~5 cycles vs ~20-30 cycles),且无需 __syncthreads() 同步。在规约阶段线程数已经很少的情况下,这种延迟差距和同步开销的节省对总体性能的影响尤为显著。
V2 实测带宽利用率:约 52%(~1060 GB/s),相比 V1 提升约 1.5 倍。
5. 版本 V3:向量化加载 float4
5.1 改进思路
V2 的瓶颈已经转移到全局内存加载阶段。每次循环中,每个线程只加载 1 个 float(4 字节),指令调度开销相对于数据吞吐比例偏高。使用 float4 向量化加载,每条指令搬运 16 字节(4 个 float),可以:
- 减少加载指令总数(减少 4 倍的循环迭代)
- 提升内存事务的利用效率
- 增加指令级并行(ILP)
5.2 Kernel 实现
// V3: float4 向量化加载
__global__ void softmax_v3(float* input, float* output, int M, int N) {
int row = blockIdx.x;
int tid = threadIdx.x;
float* x = input + row * N;
float* y = output + row * N;
// 向量化指针
float4* x4 = reinterpret_cast<float4*>(x);
float4* y4 = reinterpret_cast<float4*>(y);
int N4 = N / 4; // float4 元素数
// Pass 1: 向量化求最大值
float local_max = -INFINITY;
for (int i = tid; i < N4; i += blockDim.x) {
float4 data = x4[i];
local_max = fmaxf(local_max, fmaxf(fmaxf(data.x, data.y),
fmaxf(data.z, data.w)));
}
// 处理尾部元素(N 不是 4 的倍数时)
for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
local_max = fmaxf(local_max, x[i]);
}
float max_val = blockReduceMaxShuffle(local_max);
// Pass 2: 向量化求指数和
float local_sum = 0.0f;
for (int i = tid; i < N4; i += blockDim.x) {
float4 data = x4[i];
local_sum += expf(data.x - max_val) + expf(data.y - max_val)
+ expf(data.z - max_val) + expf(data.w - max_val);
}
for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
local_sum += expf(x[i] - max_val);
}
float sum = blockReduceSumShuffle(local_sum);
// Pass 3: 向量化归一化
float inv_sum = 1.0f / sum;
for (int i = tid; i < N4; i += blockDim.x) {
float4 data = x4[i];
float4 result;
result.x = expf(data.x - max_val) * inv_sum;
result.y = expf(data.y - max_val) * inv_sum;
result.z = expf(data.z - max_val) * inv_sum;
result.w = expf(data.w - max_val) * inv_sum;
y4[i] = result;
}
for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
y[i] = expf(x[i] - max_val) * inv_sum;
}
}
5.3 性能改进
每个线程的循环次数从 降到 ,减少了 4 倍的循环迭代开销和指令调度次数。同时 float4 加载能更好地隐藏内存延迟。
⚠️ 注意:使用 float4 要求输入数组 16 字节对齐。在实际工程中,通常在内存分配时使用 cudaMalloc(保证 256 字节对齐),并确保 是 4 的倍数(不满足时用尾部处理兜底)。
V3 实测带宽利用率:约 65%(~1325 GB/s),相比 V2 提升约 1.25 倍。
6. 版本 V4:两遍融合 Kernel
6.1 三遍扫描的冗余
回顾 V1-V3 的计算流程:三遍扫描读取输入 3 次。对于 的矩阵,总读取量为 MB。如果能把 Pass 2(exp-sum)和 Pass 3(normalize)合并为一遍,就能减少到 2 次读取(128 MB),读取量减少 1/3,理论加速比为 1.5 倍。
合并的关键观察:Pass 3 计算 时需要用到 (指数和),而 在 Pass 2 结束后才确定。如果我们不存储 Pass 2 的中间结果,只保存 和 ,然后在第二遍同时计算 和除以 ,就实现了融合。
事实上,纯粹的两遍方案需要 Online Softmax 算法(下篇文章的主题)。但我们可以做一个折中——将 Pass 2 和 Pass 3 部分融合:在第 2 遍中,每个线程将自己负责的 暂存到 Shared Memory,然后求出全局 后立即从 SMEM 读取并写出归一化结果,避免第 3 次全局内存读取。
6.2 Kernel 实现
// V4: 两遍 Kernel(Pass 1 求 max,Pass 2 融合 exp-sum + normalize)
// 需要额外 Shared Memory 暂存 exp 值
__global__ void softmax_v4(float* input, float* output, int M, int N) {
extern __shared__ float smem[]; // 大小 = N(暂存 exp 值)
int row = blockIdx.x;
int tid = threadIdx.x;
float* x = input + row * N;
float* y = output + row * N;
float4* x4 = reinterpret_cast<float4*>(x);
int N4 = N / 4;
// Pass 1: 求行最大值(向量化 + Warp Shuffle)
float local_max = -INFINITY;
for (int i = tid; i < N4; i += blockDim.x) {
float4 data = x4[i];
local_max = fmaxf(local_max, fmaxf(fmaxf(data.x, data.y),
fmaxf(data.z, data.w)));
}
for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
local_max = fmaxf(local_max, x[i]);
}
float max_val = blockReduceMaxShuffle(local_max);
// Pass 2: 计算 exp 并暂存,同时累加求 sum
float local_sum = 0.0f;
for (int i = tid; i < N; i += blockDim.x) {
float exp_val = expf(x[i] - max_val);
smem[i] = exp_val; // 暂存到 Shared Memory
local_sum += exp_val;
}
float sum = blockReduceSumShuffle(local_sum);
// 直接从 Shared Memory 读取 exp 值并归一化写出
float inv_sum = 1.0f / sum;
for (int i = tid; i < N; i += blockDim.x) {
y[i] = smem[i] * inv_sum;
}
}
6.3 局限性与适用场景
⚠️ 注意:V4 需要 字节的 Shared Memory 来暂存 exp 值。当 时需要 16KB,在 A100(每 SM 164KB Shared Memory)上完全够用。但当 很大(如 32768 以上)时,Shared Memory 装不下整行数据,此方案就不适用了。
对于大 的场景,真正的解决方案是 Online Softmax——只需一次扫描(或两次,但不需要暂存整行),将在下一篇文章详细介绍。
| 适用场景 | 推荐方案 |
|---|---|
| (SMEM 放得下) | V4 两遍融合 |
| (SMEM 不够) | V3 三遍扫描,或 Online Softmax |
V4 实测带宽利用率:约 75%(~1529 GB/s)(),相比 V3 提升约 1.15 倍。
7. 性能对比与工程建议
7.1 各版本性能汇总
| 版本 | 核心优化点 | 带宽利用率 | 相对加速 |
|---|---|---|---|
| V0 单线程/行 | 无 | ~5% | 1.0x |
| V1 Block 并行 | 多线程协作 + 合并访存 + SMEM 规约 | ~35% | 7.0x |
| V2 Warp Shuffle | 寄存器级两级规约 | ~52% | 10.4x |
| V3 向量化加载 | float4 减少指令数 | ~65% | 13.0x |
| V4 两遍融合 | 减少一次全局读取 | ~75% | 15.0x |
📌 关键点:V4 在 时达到 ~75% 带宽利用率。进一步的优化需要引入 Online Softmax 算法,从算法层面减少扫描次数。
7.2 实际工程选择
| 场景 | 推荐方案 |
|---|---|
| 学习/教学 | V1 或 V2,逻辑清晰 |
| 的生产环境 | V4(两遍融合 + Warp Shuffle + float4) |
| 的生产环境 | Online Softmax(下篇文章) |
| 追求极致性能 | 使用 cuDNN 或 FlashAttention 中的融合实现 |
💡 提示:在实际的 LLM 推理框架(如 vLLM、TensorRT-LLM)中,Softmax 通常和 Attention 的其他部分(Scale、Mask、MatMul)融合成一个大 Kernel,避免中间结果落地到全局内存。单独优化 Softmax 主要用于理解原理和特定独立场景(如分类层的 Softmax)。
📝 总结
从 V0 到 V4,Softmax 优化的核心思路可以归纳为三条主线:
- 提升并行度:从单线程处理整行 → Block 内多线程协作,利用规约实现 max 和 sum 的并行计算(V0 → V1)
- 加速规约:从 Shared Memory 朴素循环 → Warp Shuffle 两级规约,在寄存器层面完成数据交换,大幅降低同步与访存延迟(V1 → V2)
- 减少内存访问:从逐元素加载 → float4 向量化 → 两遍融合减少重复读取,逼近 Memory-Bound 的理论极限(V2 → V3 → V4)
而更进一步的突破——从三遍扫描到一遍扫描——需要从算法层面重新设计,这就是下一篇 Online Softmax 要解决的问题。
🎯 自我检验清单
- 能解释朴素 Softmax 在 FP32/FP16 下为什么会数值溢出
- 能证明”减最大值”不改变 Softmax 输出结果的数学原理
- 能说明 Safe Softmax 三遍扫描各自的作用及为什么需要三遍
- 能写出 Block 级并行规约求 max 和 sum 的 CUDA 代码
- 能使用
__shfl_down_sync实现 Warp 内 max-reduce 和 sum-reduce - 能解释两级规约(Warp 内 + Warp 间)的完整流程
- 能说明 float4 向量化加载的对齐要求和性能收益来源
- 能分析两遍融合方案的 Shared Memory 容量限制
- 能根据 的大小选择合适的 Softmax 实现方案
- 能用 Nsight Compute 测量 Softmax Kernel 的带宽利用率并定位瓶颈
📚 参考资料
- NVIDIA CUDA C++ Programming Guide - Mathematical Functions
- Online normalizer calculation for softmax - Milakov & Gimelshein, 2018
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- CUDA C++ Best Practices Guide - Memory Optimizations
- 深入浅出GPU优化系列:softmax优化
- NVIDIA cuDNN Documentation - Softmax
- From Online Softmax to FlashAttention