跳到主要内容
CUDA编程与算子优化

5.1 CUDA Softmax 朴素实现优化

本文从朴素实现出发,逐步引入 Safe Softmax、Block 级并行、Warp Shuffle、向量化访存等优化手段,带你写出既正确又高效的 Softmax Kernel。

CUDA GPU Softmax 数值稳定性 算子优化

Softmax 是 Transformer 中 Attention 计算的核心组件,看似简单的公式背后隐藏着数值溢出陷阱。本文从朴素实现暴露的数值问题出发,逐步引入 Safe Softmax、Block 级并行、Warp Shuffle、向量化访存等优化手段,每一步都有原理分析和性能对比,帮你写出既正确又高效的 Softmax Kernel。

📑 目录


1. Softmax 基础与数值问题

1.1 什么是 Softmax

假设你要从一堆分数中挑选出”每个分数相对于总体的重要程度”——高分者占比大、低分者占比小、且所有占比加起来恰好为 1。Softmax 就是干这件事的函数:它把一组任意实数”压缩”成一个概率分布。

数学定义:对于输入向量 x=(x0,x1,,xN1)\mathbf{x} = (x_0, x_1, \ldots, x_{N-1}),Softmax 输出为:

Softmax(xi)=exij=0N1exj\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=0}^{N-1} e^{x_j}}

在 Transformer 的 Self-Attention 中,Softmax 作用于 QKTQK^T 矩阵的每一行,将原始的注意力分数转化为归一化权重。其输入维度通常是序列长度 NN(如 2048、4096 甚至 128K),这使得 Softmax 的高效实现对模型整体性能有直接影响。

1.2 数值溢出问题

Softmax 的公式涉及指数运算 exie^{x_i},这在浮点数表示中极易溢出:

  • FP32:最大可表示值约为 3.4×10383.4 \times 10^{38},对应 e88.7e^{88.7}。当 xi>88.7x_i > 88.7 时,exie^{x_i} 就变成 +inf
  • FP16:最大值约为 6.55×1046.55 \times 10^4,对应 e11.09e^{11.09}。当 xi>11x_i > 11 时即溢出
  • 下溢同样危险:当 xix_i 很小(如 100-100)时,exi0e^{x_i} \approx 0,分母可能下溢为 0,导致 NaN

在实际的 Attention 计算中,QKTQK^T 的值可以轻松超过 88(尤其是在大 dkd_k 且未做缩放的情况下),朴素实现必然崩溃。

1.3 解决方案:减最大值技巧

数学上可以证明,对输入向量的每个元素减去同一个常数 cc,不改变 Softmax 的输出:

exicjexjc=exiecjexjec=exijexj\frac{e^{x_i - c}}{\sum_{j} e^{x_j - c}} = \frac{e^{x_i} \cdot e^{-c}}{\sum_{j} e^{x_j} \cdot e^{-c}} = \frac{e^{x_i}}{\sum_{j} e^{x_j}}

c=max(x)c = \max(\mathbf{x}) 后,所有指数的参数都 0\leq 0,结果落在 (0,1](0, 1] 之间,彻底消除上溢风险。这就是 Safe Softmax(数值稳定 Softmax)的核心。

1.4 计算流程与性能瓶颈

Safe Softmax 的计算自然分为三遍扫描(Three-Pass):

遍次操作数学表达
第 1 遍求行最大值m=maxj(xj)m = \max_j(x_j)
第 2 遍指数求和d=jexjmd = \sum_j e^{x_j - m}
第 3 遍归一化yi=exim/dy_i = e^{x_i - m} / d

每一遍都需要读取整行数据,对于长度为 NN 的向量,总共需要 3N3N 次全局内存读取。Softmax 的计算强度很低(每次读取对应 1-2 次浮点运算),属于典型的 Memory-Bound 操作,优化目标是减少内存访问次数提升带宽利用率

1.5 测试环境

本文所有代码使用 CUDA 12.x 编写,测试在 A100 80GB SXM4 上进行:

指标数值
理论内存带宽2 TB/s
SM 数量108
每 SM Shared Memory164 KB

测试配置:矩阵形状为 (M,N)=(4096,4096)(M, N) = (4096, 4096),即 4096 行、每行 4096 个元素(模拟典型 Attention 场景)。


2. 版本 V0:朴素实现(单线程处理一行)

2.1 算法思路

最简单的并行化策略:每个线程独立处理输入矩阵的一行,串行完成该行的 max、exp-sum、normalize 三步。行与行之间完全独立,天然并行。

2.2 Kernel 实现

// V0: 朴素实现,一个线程处理一行
__global__ void softmax_v0(float* input, float* output, int M, int N) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M) return;

    float* x = input  + row * N;
    float* y = output + row * N;

    // Pass 1: 求最大值
    float max_val = -INFINITY;
    for (int i = 0; i < N; i++) {
        max_val = fmaxf(max_val, x[i]);
    }

    // Pass 2: 指数求和
    float sum = 0.0f;
    for (int i = 0; i < N; i++) {
        sum += expf(x[i] - max_val);
    }

    // Pass 3: 归一化
    float inv_sum = 1.0f / sum;
    for (int i = 0; i < N; i++) {
        y[i] = expf(x[i] - max_val) * inv_sum;
    }
}

2.3 性能问题分析

⚠️ 注意:V0 虽然正确,但性能极差,原因有三:

  1. 并行度不足:每行只有 1 个线程,N=4096N=4096 个元素被串行处理。GPU 的数千个 CUDA Core 绝大部分在空闲
  2. 非合并访存:同一 Warp 内的 32 个线程处理 32 个不同行,它们在内循环中访问的地址相差 N×4N \times 4 字节(一整行的跨度),造成严重的非合并访存
  3. 冗余计算expf(x[i] - max_val) 在 Pass 2 和 Pass 3 中重复计算了两次

V0 实测带宽利用率:约 5%(~102 GB/s)——基本是”可以跑通”但不可用的状态。


3. 版本 V1:Safe Softmax 三遍扫描

3.1 改进思路

V0 的核心问题是并行度不够——一行 NN 个元素只用 1 个线程处理。解决方案是让一个 Block 内的多个线程协作处理同一行。每个线程负责一行中的一段数据,通过 Shared Memory 做并行规约求出 max 和 sum。

这等于把 Softmax 分解为两个 Reduce 操作(max-reduce + sum-reduce)加一个 element-wise 操作(normalize),每个 Reduce 都使用我们在 Reduce 优化文章中学到的技术。

3.2 Kernel 实现

// V1: 一个 Block 处理一行,Block 内并行规约
__global__ void softmax_v1(float* input, float* output, int M, int N) {
    extern __shared__ float smem[];  // 用于规约

    int row = blockIdx.x;
    int tid = threadIdx.x;

    float* x = input  + row * N;
    float* y = output + row * N;

    // Pass 1: 并行求最大值
    float max_val = -INFINITY;
    for (int i = tid; i < N; i += blockDim.x) {
        max_val = fmaxf(max_val, x[i]);
    }
    smem[tid] = max_val;
    __syncthreads();

    // Shared Memory 规约求全局最大值
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] = fmaxf(smem[tid], smem[tid + s]);
        }
        __syncthreads();
    }
    max_val = smem[0];
    __syncthreads();

    // Pass 2: 并行求指数和
    float sum = 0.0f;
    for (int i = tid; i < N; i += blockDim.x) {
        sum += expf(x[i] - max_val);
    }
    smem[tid] = sum;
    __syncthreads();

    // Shared Memory 规约求总和
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] += smem[tid + s];
        }
        __syncthreads();
    }
    sum = smem[0];
    __syncthreads();

    // Pass 3: 归一化写出
    float inv_sum = 1.0f / sum;
    for (int i = tid; i < N; i += blockDim.x) {
        y[i] = expf(x[i] - max_val) * inv_sum;
    }
}

调用方式:

int block_size = 256;
int smem_size = block_size * sizeof(float);
softmax_v1<<<M, block_size, smem_size>>>(d_input, d_output, M, N);

3.3 性能分析

相比 V0 的提升来源:

  • 并行度大幅提升:256 个线程协作处理一行,每线程只需处理 N/256=16N/256 = 16 个元素
  • 合并访存:同一 Block 内的相邻线程访问行内的相邻地址(stride=1),形成完美的合并访存

仍存在的问题:

  • Shared Memory 规约需要多轮 __syncthreads()
  • 三遍扫描意味着输入数据被读取了 3 次

V1 实测带宽利用率:约 35%(~714 GB/s),相比 V0 提升约 7 倍。


4. 版本 V2:Warp Shuffle 优化规约

4.1 改进思路

V1 的规约完全依赖 Shared Memory,存在两个问题:一是数据需要”寄存器 → SMEM → 寄存器”往返,路径偏长;二是每一轮规约都需要 __syncthreads(),在线程数已经较少的后期阶段同步开销占比偏高。

从 Reduce 优化的经验来看,Warp Shuffle(__shfl_down_sync)可以在寄存器间直接交换数据,延迟更低、无 Bank Conflict、Warp 内天然同步。我们将规约改为”Warp 内 Shuffle + Warp 间 Shared Memory”的两级结构:先在每个 Warp 内用 Shuffle 规约出局部结果,再把每个 Warp 的结果写到一个 32 元素的 SMEM 数组,由第一个 Warp 做最终规约。

4.2 Kernel 实现

// Warp 内规约:Max
__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    }
    return val;
}

// Warp 内规约:Sum
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

// 两级规约:Block 内 Max
__device__ float blockReduceMaxShuffle(float val) {
    __shared__ float warp_max[32];

    int lane = threadIdx.x % 32;
    int wid  = threadIdx.x / 32;

    val = warpReduceMax(val);

    if (lane == 0) warp_max[wid] = val;
    __syncthreads();

    int num_warps = blockDim.x / 32;
    val = (lane < num_warps) ? warp_max[lane] : -INFINITY;
    if (wid == 0) val = warpReduceMax(val);

    // 广播结果给所有线程
    __shared__ float block_result;
    if (threadIdx.x == 0) block_result = val;
    __syncthreads();
    return block_result;
}

// 两级规约:Block 内 Sum
__device__ float blockReduceSumShuffle(float val) {
    __shared__ float warp_sum[32];

    int lane = threadIdx.x % 32;
    int wid  = threadIdx.x / 32;

    val = warpReduceSum(val);

    if (lane == 0) warp_sum[wid] = val;
    __syncthreads();

    int num_warps = blockDim.x / 32;
    val = (lane < num_warps) ? warp_sum[lane] : 0.0f;
    if (wid == 0) val = warpReduceSum(val);

    __shared__ float block_result;
    if (threadIdx.x == 0) block_result = val;
    __syncthreads();
    return block_result;
}

// V2: Warp Shuffle 两级规约
__global__ void softmax_v2(float* input, float* output, int M, int N) {
    int row = blockIdx.x;
    int tid = threadIdx.x;

    float* x = input  + row * N;
    float* y = output + row * N;

    // Pass 1: 求行最大值
    float local_max = -INFINITY;
    for (int i = tid; i < N; i += blockDim.x) {
        local_max = fmaxf(local_max, x[i]);
    }
    float max_val = blockReduceMaxShuffle(local_max);

    // Pass 2: 求指数和
    float local_sum = 0.0f;
    for (int i = tid; i < N; i += blockDim.x) {
        local_sum += expf(x[i] - max_val);
    }
    float sum = blockReduceSumShuffle(local_sum);

    // Pass 3: 归一化
    float inv_sum = 1.0f / sum;
    for (int i = tid; i < N; i += blockDim.x) {
        y[i] = expf(x[i] - max_val) * inv_sum;
    }
}

4.3 为什么 Warp Shuffle 更快

对比维度Shared Memory 规约Warp Shuffle 规约
数据路径寄存器 → SMEM → 寄存器寄存器 → 寄存器
Bank Conflict可能存在不存在
同步开销每轮需要 __syncthreads()Warp 内天然同步
延迟~20-30 cycles~5 cycles

💡 提示:Warp Shuffle 的延迟约为 Shared Memory 的 1/4 到 1/6(~5 cycles vs ~20-30 cycles),且无需 __syncthreads() 同步。在规约阶段线程数已经很少的情况下,这种延迟差距和同步开销的节省对总体性能的影响尤为显著。

V2 实测带宽利用率:约 52%(~1060 GB/s),相比 V1 提升约 1.5 倍。


5. 版本 V3:向量化加载 float4

5.1 改进思路

V2 的瓶颈已经转移到全局内存加载阶段。每次循环中,每个线程只加载 1 个 float(4 字节),指令调度开销相对于数据吞吐比例偏高。使用 float4 向量化加载,每条指令搬运 16 字节(4 个 float),可以:

  • 减少加载指令总数(减少 4 倍的循环迭代)
  • 提升内存事务的利用效率
  • 增加指令级并行(ILP)

5.2 Kernel 实现

// V3: float4 向量化加载
__global__ void softmax_v3(float* input, float* output, int M, int N) {
    int row = blockIdx.x;
    int tid = threadIdx.x;

    float* x = input  + row * N;
    float* y = output + row * N;

    // 向量化指针
    float4* x4 = reinterpret_cast<float4*>(x);
    float4* y4 = reinterpret_cast<float4*>(y);
    int N4 = N / 4;  // float4 元素数

    // Pass 1: 向量化求最大值
    float local_max = -INFINITY;
    for (int i = tid; i < N4; i += blockDim.x) {
        float4 data = x4[i];
        local_max = fmaxf(local_max, fmaxf(fmaxf(data.x, data.y),
                                            fmaxf(data.z, data.w)));
    }
    // 处理尾部元素(N 不是 4 的倍数时)
    for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
        local_max = fmaxf(local_max, x[i]);
    }
    float max_val = blockReduceMaxShuffle(local_max);

    // Pass 2: 向量化求指数和
    float local_sum = 0.0f;
    for (int i = tid; i < N4; i += blockDim.x) {
        float4 data = x4[i];
        local_sum += expf(data.x - max_val) + expf(data.y - max_val)
                   + expf(data.z - max_val) + expf(data.w - max_val);
    }
    for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
        local_sum += expf(x[i] - max_val);
    }
    float sum = blockReduceSumShuffle(local_sum);

    // Pass 3: 向量化归一化
    float inv_sum = 1.0f / sum;
    for (int i = tid; i < N4; i += blockDim.x) {
        float4 data = x4[i];
        float4 result;
        result.x = expf(data.x - max_val) * inv_sum;
        result.y = expf(data.y - max_val) * inv_sum;
        result.z = expf(data.z - max_val) * inv_sum;
        result.w = expf(data.w - max_val) * inv_sum;
        y4[i] = result;
    }
    for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
        y[i] = expf(x[i] - max_val) * inv_sum;
    }
}

5.3 性能改进

每个线程的循环次数从 N/blockDimN/\text{blockDim} 降到 N/(4×blockDim)N/(4 \times \text{blockDim}),减少了 4 倍的循环迭代开销和指令调度次数。同时 float4 加载能更好地隐藏内存延迟。

⚠️ 注意:使用 float4 要求输入数组 16 字节对齐。在实际工程中,通常在内存分配时使用 cudaMalloc(保证 256 字节对齐),并确保 NN 是 4 的倍数(不满足时用尾部处理兜底)。

V3 实测带宽利用率:约 65%(~1325 GB/s),相比 V2 提升约 1.25 倍。


6. 版本 V4:两遍融合 Kernel

6.1 三遍扫描的冗余

回顾 V1-V3 的计算流程:三遍扫描读取输入 3 次。对于 (4096,4096)(4096, 4096) 的矩阵,总读取量为 3×4096×4096×4=1923 \times 4096 \times 4096 \times 4 = 192 MB。如果能把 Pass 2(exp-sum)和 Pass 3(normalize)合并为一遍,就能减少到 2 次读取(128 MB),读取量减少 1/3,理论加速比为 1.5 倍。

合并的关键观察:Pass 3 计算 exim/de^{x_i - m} / d 时需要用到 dd(指数和),而 dd 在 Pass 2 结束后才确定。如果我们不存储 Pass 2 的中间结果,只保存 mmdd,然后在第二遍同时计算 exime^{x_i - m} 和除以 dd,就实现了融合。

事实上,纯粹的两遍方案需要 Online Softmax 算法(下篇文章的主题)。但我们可以做一个折中——将 Pass 2 和 Pass 3 部分融合:在第 2 遍中,每个线程将自己负责的 exime^{x_i - m} 暂存到 Shared Memory,然后求出全局 dd 后立即从 SMEM 读取并写出归一化结果,避免第 3 次全局内存读取。

6.2 Kernel 实现

// V4: 两遍 Kernel(Pass 1 求 max,Pass 2 融合 exp-sum + normalize)
// 需要额外 Shared Memory 暂存 exp 值
__global__ void softmax_v4(float* input, float* output, int M, int N) {
    extern __shared__ float smem[];  // 大小 = N(暂存 exp 值)

    int row = blockIdx.x;
    int tid = threadIdx.x;

    float* x = input  + row * N;
    float* y = output + row * N;

    float4* x4 = reinterpret_cast<float4*>(x);
    int N4 = N / 4;

    // Pass 1: 求行最大值(向量化 + Warp Shuffle)
    float local_max = -INFINITY;
    for (int i = tid; i < N4; i += blockDim.x) {
        float4 data = x4[i];
        local_max = fmaxf(local_max, fmaxf(fmaxf(data.x, data.y),
                                            fmaxf(data.z, data.w)));
    }
    for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
        local_max = fmaxf(local_max, x[i]);
    }
    float max_val = blockReduceMaxShuffle(local_max);

    // Pass 2: 计算 exp 并暂存,同时累加求 sum
    float local_sum = 0.0f;
    for (int i = tid; i < N; i += blockDim.x) {
        float exp_val = expf(x[i] - max_val);
        smem[i] = exp_val;  // 暂存到 Shared Memory
        local_sum += exp_val;
    }
    float sum = blockReduceSumShuffle(local_sum);

    // 直接从 Shared Memory 读取 exp 值并归一化写出
    float inv_sum = 1.0f / sum;
    for (int i = tid; i < N; i += blockDim.x) {
        y[i] = smem[i] * inv_sum;
    }
}

6.3 局限性与适用场景

⚠️ 注意:V4 需要 N×4N \times 4 字节的 Shared Memory 来暂存 exp 值。当 N=4096N=4096 时需要 16KB,在 A100(每 SM 164KB Shared Memory)上完全够用。但当 NN 很大(如 32768 以上)时,Shared Memory 装不下整行数据,此方案就不适用了。

对于大 NN 的场景,真正的解决方案是 Online Softmax——只需一次扫描(或两次,但不需要暂存整行),将在下一篇文章详细介绍。

适用场景推荐方案
N4096N \leq 4096(SMEM 放得下)V4 两遍融合
N>4096N > 4096(SMEM 不够)V3 三遍扫描,或 Online Softmax

V4 实测带宽利用率:约 75%(~1529 GB/s)N=4096N=4096),相比 V3 提升约 1.15 倍。


7. 性能对比与工程建议

7.1 各版本性能汇总

版本核心优化点带宽利用率相对加速
V0 单线程/行~5%1.0x
V1 Block 并行多线程协作 + 合并访存 + SMEM 规约~35%7.0x
V2 Warp Shuffle寄存器级两级规约~52%10.4x
V3 向量化加载float4 减少指令数~65%13.0x
V4 两遍融合减少一次全局读取~75%15.0x

📌 关键点:V4 在 N=4096N=4096 时达到 ~75% 带宽利用率。进一步的优化需要引入 Online Softmax 算法,从算法层面减少扫描次数。

7.2 实际工程选择

场景推荐方案
学习/教学V1 或 V2,逻辑清晰
N4096N \leq 4096 的生产环境V4(两遍融合 + Warp Shuffle + float4)
N>4096N > 4096 的生产环境Online Softmax(下篇文章)
追求极致性能使用 cuDNN 或 FlashAttention 中的融合实现

💡 提示:在实际的 LLM 推理框架(如 vLLM、TensorRT-LLM)中,Softmax 通常和 Attention 的其他部分(Scale、Mask、MatMul)融合成一个大 Kernel,避免中间结果落地到全局内存。单独优化 Softmax 主要用于理解原理和特定独立场景(如分类层的 Softmax)。


📝 总结

从 V0 到 V4,Softmax 优化的核心思路可以归纳为三条主线:

  1. 提升并行度:从单线程处理整行 → Block 内多线程协作,利用规约实现 max 和 sum 的并行计算(V0 → V1)
  2. 加速规约:从 Shared Memory 朴素循环 → Warp Shuffle 两级规约,在寄存器层面完成数据交换,大幅降低同步与访存延迟(V1 → V2)
  3. 减少内存访问:从逐元素加载 → float4 向量化 → 两遍融合减少重复读取,逼近 Memory-Bound 的理论极限(V2 → V3 → V4)

而更进一步的突破——从三遍扫描到一遍扫描——需要从算法层面重新设计,这就是下一篇 Online Softmax 要解决的问题。


🎯 自我检验清单

  • 能解释朴素 Softmax 在 FP32/FP16 下为什么会数值溢出
  • 能证明”减最大值”不改变 Softmax 输出结果的数学原理
  • 能说明 Safe Softmax 三遍扫描各自的作用及为什么需要三遍
  • 能写出 Block 级并行规约求 max 和 sum 的 CUDA 代码
  • 能使用 __shfl_down_sync 实现 Warp 内 max-reduce 和 sum-reduce
  • 能解释两级规约(Warp 内 + Warp 间)的完整流程
  • 能说明 float4 向量化加载的对齐要求和性能收益来源
  • 能分析两遍融合方案的 Shared Memory 容量限制
  • 能根据 NN 的大小选择合适的 Softmax 实现方案
  • 能用 Nsight Compute 测量 Softmax Kernel 的带宽利用率并定位瓶颈

📚 参考资料