跳到主要内容
CUDA编程与算子优化

5.2 CUDA Online Softmax 实现优化

本文从算法推导出发,逐步实现并优化 Online Softmax 的 CUDA Kernel,这也是理解 FlashAttention 的核心前置知识。

CUDA GPU Softmax Online Softmax FlashAttention 算子优化

传统 Safe Softmax 需要三遍扫描输入数据,而 Online Softmax 通过巧妙的数学递推,在一遍扫描中同时完成 max 和 sum 的计算,将内存读取次数从 3 次降到 2 次。本文从算法推导出发,逐步实现并优化 Online Softmax 的 CUDA Kernel,这也是理解 FlashAttention 的核心前置知识。

📑 目录


1. 为什么需要 Online Softmax

1.1 三遍扫描的代价

上一篇文章中,Safe Softmax 采用三遍扫描:

m=maxj(xj)第1遍d=jexjm第2遍yi=eximd第3遍\underbrace{m = \max_j(x_j)}_{\text{第1遍}} \quad \rightarrow \quad \underbrace{d = \sum_j e^{x_j - m}}_{\text{第2遍}} \quad \rightarrow \quad \underbrace{y_i = \frac{e^{x_i - m}}{d}}_{\text{第3遍}}

这三遍各读一次输入数据,总内存读取量为 3N3N。对于 Memory-Bound 的操作来说,减少一遍扫描就相当于性能提升 33%。问题是:第 2 遍(求 sum)必须知道第 1 遍的结果(max),看似没法合并。

打个比方:你要给全班同学的成绩做归一化。传统方法是:先翻一遍花名册找最高分,再翻一遍算调整后的总分,最后翻一遍写出归一化分数——翻了三遍花名册。能不能只翻两遍甚至一遍就搞定?

1.2 核心洞察

Online Softmax 的关键洞察是:max 和 sum 可以在一遍扫描中同时维护,代价只是在发现新的最大值时”修正”之前累积的 sum。

想象你一边翻花名册一边累加”调整分”。翻到第 50 个人时发现他的分比之前的最高分还高——那你之前算的”调整分”全都偏了!但偏多少是确切知道的:之前每个人的 eximolde^{x_i - m_{\text{old}}} 应该变成 eximnewe^{x_i - m_{\text{new}}},总和只需乘一个修正因子 emoldmnewe^{m_{\text{old}} - m_{\text{new}}}

这就是 Online Softmax 的精髓——用乘法修正代替重新计算


2. Online Softmax 算法推导

2.1 递推关系推导

设已经处理了前 kk 个元素,当前维护的状态为:

  • mk=max(x0,x1,,xk1)m_k = \max(x_0, x_1, \ldots, x_{k-1})(前 kk 个元素的最大值)
  • dk=j=0k1exjmkd_k = \sum_{j=0}^{k-1} e^{x_j - m_k}(基于当前 max 的指数和)

当处理第 k+1k+1 个元素 xkx_k 时:

步骤 1:更新最大值

mk+1=max(mk,xk)m_{k+1} = \max(m_k, x_k)

步骤 2:修正已有的 sum 并加入新元素

dk+1=dkemkmk+1+exkmk+1d_{k+1} = d_k \cdot e^{m_k - m_{k+1}} + e^{x_k - m_{k+1}}

推导过程:

dk+1=j=0kexjmk+1\begin{aligned} d_{k+1} &= \sum_{j=0}^{k} e^{x_j - m_{k+1}} \end{aligned} =j=0k1exjmk+1+exkmk+1\begin{aligned} &= \sum_{j=0}^{k-1} e^{x_j - m_{k+1}} + e^{x_k - m_{k+1}} \end{aligned} =j=0k1exjmkemkmk+1+exkmk+1\begin{aligned} &= \sum_{j=0}^{k-1} e^{x_j - m_k} \cdot e^{m_k - m_{k+1}} + e^{x_k - m_{k+1}} \end{aligned} =dkemkmk+1+exkmk+1\begin{aligned} &= d_k \cdot e^{m_k - m_{k+1}} + e^{x_k - m_{k+1}} \end{aligned}

💡 提示:关键在于 emkmk+1e^{m_k - m_{k+1}} 这个修正因子。如果新元素不是新的最大值(mk+1=mkm_{k+1} = m_k),则 emkmk+1=e0=1e^{m_k - m_{k+1}} = e^0 = 1,修正退化为简单的累加——没有任何额外开销。

2.2 完整算法伪代码

// 第 1 遍:Online 计算 max 和 sum(同时进行)
m = -∞
d = 0
for j = 0 to N-1:
    m_new = max(m, x[j])
    d = d * exp(m - m_new) + exp(x[j] - m_new)
    m = m_new

// 第 2 遍:归一化输出
for j = 0 to N-1:
    y[j] = exp(x[j] - m) / d

从三遍降到两遍:第 1 遍同时得到 mmdd,第 2 遍完成归一化。内存读取次数从 3N3N 降到 2N2N,理论性能提升 33%。

2.3 数值稳定性保证

Online Softmax 的数值稳定性和 Safe Softmax 完全一致:

  • mm 始终是已见元素的最大值,所有指数参数 xjm0x_j - m \leq 0,不会上溢
  • 修正因子 emkmk+11e^{m_k - m_{k+1}} \leq 1(因为 mk+1mkm_{k+1} \geq m_k),不会上溢
  • 最终 dd 和 Safe Softmax 的三遍结果在数学上完全等价

2.4 测试环境

与上一篇相同:A100 80GB SXM4,测试矩阵 (M,N)=(4096,4096)(M, N) = (4096, 4096)


3. 版本 V0:单线程 Online Softmax

3.1 算法思路

最直接的实现:每个线程处理一行,串行地维护 (m,d)(m, d) 状态,最后再做一遍归一化。

3.2 Kernel 实现

// V0: 单线程 Online Softmax(一个线程处理一行)
__global__ void online_softmax_v0(float* input, float* output, int M, int N) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M) return;

    float* x = input  + row * N;
    float* y = output + row * N;

    // 第 1 遍:Online 同时计算 max 和 sum
    float m = -INFINITY;
    float d = 0.0f;

    for (int i = 0; i < N; i++) {
        float xi = x[i];
        float m_new = fmaxf(m, xi);
        d = d * expf(m - m_new) + expf(xi - m_new);
        m = m_new;
    }

    // 第 2 遍:归一化
    float inv_d = 1.0f / d;
    for (int i = 0; i < N; i++) {
        y[i] = expf(x[i] - m) * inv_d;
    }
}

3.3 性能分析

和 Safe Softmax V0 类似,单线程处理整行存在严重的并行度不足和非合并访存问题。但相比三遍扫描的 V0,这里只读了 2 遍输入数据。

⚠️ 注意:V0 的 Online 递推中每步都需要 expf 调用(用于修正因子),这增加了计算负担。但由于 Softmax 本身是 Memory-Bound 的,这些额外计算在高性能版本中会被内存访问延迟完全隐藏。

V0 实测带宽利用率:约 7%(~143 GB/s)——仅作为正确性验证的基线。


4. 版本 V1:Block 级并行 + Warp Shuffle 合并规约

4.1 并行化挑战

Online Softmax 的递推是串行的——每一步的 (m,d)(m, d) 依赖上一步。如何并行化?

关键观察:可以先让每个线程独立处理一部分元素,各自维护局部的 (mlocal,dlocal)(m_{\text{local}}, d_{\text{local}}),最后将这些局部结果合并

两个局部结果 (m1,d1)(m_1, d_1)(m2,d2)(m_2, d_2) 的合并公式:

mmerged=max(m1,m2)\begin{aligned} m_{\text{merged}} &= \max(m_1, m_2) \end{aligned} dmerged=d1em1mmerged+d2em2mmerged\begin{aligned} d_{\text{merged}} &= d_1 \cdot e^{m_1 - m_{\text{merged}}} + d_2 \cdot e^{m_2 - m_{\text{merged}}} \end{aligned}

这个合并操作满足结合律(可以用树形规约),因此可以高效并行化。

4.2 用 Warp Shuffle 实现合并规约

参考上一篇 Safe Softmax 的优化经验,规约最好直接做在寄存器层面:Warp Shuffle 延迟低、无 Bank Conflict、Warp 内天然同步,比 Shared Memory 规约更高效。

与普通 Reduce 的区别在于:Online Softmax 的规约不是简单的 max 或 sum,而是同时合并 (m, d) 两个值。好在 __shfl_down_sync 可以分别传输这两个值,每轮结束后用上面的合并公式更新即可。整体结构沿用”Warp 内 Shuffle + Warp 间 Shared Memory”的两级规约。

4.3 Kernel 实现

// Warp 内 Online 合并规约
__device__ void warpReduceOnline(float& m, float& d) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        float m2 = __shfl_down_sync(0xffffffff, m, offset);
        float d2 = __shfl_down_sync(0xffffffff, d, offset);

        float m_new = fmaxf(m, m2);
        d = d * expf(m - m_new) + d2 * expf(m2 - m_new);
        m = m_new;
    }
}

// V1: Block 级并行 + Warp Shuffle 两级合并规约
__global__ void online_softmax_v1(float* input, float* output, int M, int N) {
    int row = blockIdx.x;
    int tid = threadIdx.x;
    int lane = tid % 32;
    int wid  = tid / 32;

    float* x = input  + row * N;
    float* y = output + row * N;

    // 第 1 遍:每线程 Online 处理一段
    float local_m = -INFINITY;
    float local_d = 0.0f;

    for (int i = tid; i < N; i += blockDim.x) {
        float xi = x[i];
        float m_new = fmaxf(local_m, xi);
        local_d = local_d * expf(local_m - m_new) + expf(xi - m_new);
        local_m = m_new;
    }

    // 第一级:Warp 内合并
    warpReduceOnline(local_m, local_d);

    // Warp 间通过 Shared Memory 交换
    __shared__ float warp_m[32];
    __shared__ float warp_d[32];

    if (lane == 0) {
        warp_m[wid] = local_m;
        warp_d[wid] = local_d;
    }
    __syncthreads();

    // 第二级:Warp 0 做最终合并
    int num_warps = blockDim.x / 32;
    if (wid == 0) {
        local_m = (lane < num_warps) ? warp_m[lane] : -INFINITY;
        local_d = (lane < num_warps) ? warp_d[lane] : 0.0f;
        warpReduceOnline(local_m, local_d);
    }

    // 广播最终结果
    __shared__ float final_m, final_d;
    if (tid == 0) {
        final_m = local_m;
        final_d = local_d;
    }
    __syncthreads();

    float row_max = final_m;
    float inv_sum = 1.0f / final_d;

    // 第 2 遍:归一化输出
    for (int i = tid; i < N; i += blockDim.x) {
        y[i] = expf(x[i] - row_max) * inv_sum;
    }
}

4.4 与三遍 Safe Softmax 的对比

对比项Safe Softmax(三遍)Online Softmax(两遍)
全局内存读取次数3N2N
规约次数2 次(max + sum 分开)1 次(合并规约)
规约复杂度简单加法/取 max带修正因子的合并
计算量较少稍多(修正因子的 exp)

⚠️ 注意:合并顺序会影响浮点精度(加法不严格满足结合律),但对于 Softmax 这种对微小误差不敏感的场景,树形规约的精度完全足够。

💡 提示:Online 版本的规约操作比普通 sum-reduce 复杂(每步需要 2 次 expf),但节省的内存带宽远大于多出的计算量。对于 Memory-Bound 操作,减少一次全局读取的收益是决定性的。

V1 实测带宽利用率:约 58%(~1183 GB/s),相比单线程 V0 提升约 8.3 倍,相比三遍 Safe Softmax 中的对应版本提升约 1.4 倍。


5. 版本 V2:向量化加载 float4

5.1 改进思路

和 Safe Softmax 的优化路径一致,使用 float4 向量化加载减少指令数。Online 递推阶段,每次加载 4 个元素,逐个更新 (m,d)(m, d) 状态。

5.2 Kernel 实现

// V2: float4 向量化加载 + Online Softmax
__global__ void online_softmax_v2(float* input, float* output, int M, int N) {
    int row = blockIdx.x;
    int tid = threadIdx.x;
    int lane = tid % 32;
    int wid  = tid / 32;

    float* x = input  + row * N;
    float* y = output + row * N;

    float4* x4 = reinterpret_cast<float4*>(x);
    float4* y4 = reinterpret_cast<float4*>(y);
    int N4 = N / 4;

    // 第 1 遍:向量化 Online 递推
    float local_m = -INFINITY;
    float local_d = 0.0f;

    for (int i = tid; i < N4; i += blockDim.x) {
        float4 data = x4[i];

        // 对 4 个元素依次更新 (m, d)
        float vals[4] = {data.x, data.y, data.z, data.w};
        for (int k = 0; k < 4; k++) {
            float m_new = fmaxf(local_m, vals[k]);
            local_d = local_d * expf(local_m - m_new) + expf(vals[k] - m_new);
            local_m = m_new;
        }
    }
    // 尾部处理
    for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
        float xi = x[i];
        float m_new = fmaxf(local_m, xi);
        local_d = local_d * expf(local_m - m_new) + expf(xi - m_new);
        local_m = m_new;
    }

    // 两级 Warp Shuffle 合并(与 V1 相同)
    warpReduceOnline(local_m, local_d);

    __shared__ float warp_m[32];
    __shared__ float warp_d[32];
    if (lane == 0) {
        warp_m[wid] = local_m;
        warp_d[wid] = local_d;
    }
    __syncthreads();

    int num_warps = blockDim.x / 32;
    if (wid == 0) {
        local_m = (lane < num_warps) ? warp_m[lane] : -INFINITY;
        local_d = (lane < num_warps) ? warp_d[lane] : 0.0f;
        warpReduceOnline(local_m, local_d);
    }

    __shared__ float final_m, final_d;
    if (tid == 0) {
        final_m = local_m;
        final_d = local_d;
    }
    __syncthreads();

    float row_max = final_m;
    float inv_sum = 1.0f / final_d;

    // 第 2 遍:向量化归一化
    for (int i = tid; i < N4; i += blockDim.x) {
        float4 data = x4[i];
        float4 result;
        result.x = expf(data.x - row_max) * inv_sum;
        result.y = expf(data.y - row_max) * inv_sum;
        result.z = expf(data.z - row_max) * inv_sum;
        result.w = expf(data.w - row_max) * inv_sum;
        y4[i] = result;
    }
    for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
        y[i] = expf(x[i] - row_max) * inv_sum;
    }
}

5.3 float4 与 Online 递推的配合

虽然 float4 一次加载 4 个元素,但 Online 递推仍然逐元素更新 (m,d)(m, d)。这看似没有收益,但实际上:

  1. 内存层面float4 加载用更少的指令传输相同的数据量,减少了指令调度的瓶颈
  2. 计算层面:4 个 expf 调用可以流水线化,编译器能更好地排布指令
  3. 整体效果:循环迭代减少 4 倍,循环开销(分支判断、地址计算)大幅降低

V2 实测带宽利用率:约 70%(~1427 GB/s),相比 V1 提升约 1.21 倍。


6. 版本 V3:寄存器缓存消除第二遍读取

6.1 终极优化思路

两遍扫描意味着输入数据仍然被读了两次。能不能只读一次?

答案是:如果每行元素不多(N/blockDimN / \text{blockDim} \leq 几十),每个线程需要处理的元素数量有限,可以将第 1 遍读入的数据暂存在寄存器中,第 2 遍直接从寄存器取值做归一化,避免第二次全局内存读取。

这就是真正的一遍(One-Pass)Softmax——全局内存只读一次、写一次。

6.2 Kernel 实现

// V3: 寄存器缓存实现 One-Pass Softmax
// 适用于 N / blockDim.x 较小的场景(每线程处理元素数不超过 MAX_ELEMS)
#define MAX_ELEMS_PER_THREAD 32

__global__ void online_softmax_v3(float* input, float* output, int M, int N) {
    int row = blockIdx.x;
    int tid = threadIdx.x;
    int lane = tid % 32;
    int wid  = tid / 32;

    float* x = input  + row * N;
    float* y = output + row * N;

    // 寄存器数组缓存输入数据
    float reg_cache[MAX_ELEMS_PER_THREAD];

    // 第 1 遍(也是唯一一遍全局读取):读数据 + Online 递推
    float local_m = -INFINITY;
    float local_d = 0.0f;
    int count = 0;

    for (int i = tid; i < N; i += blockDim.x) {
        float xi = x[i];
        reg_cache[count] = xi;  // 缓存到寄存器
        count++;

        float m_new = fmaxf(local_m, xi);
        local_d = local_d * expf(local_m - m_new) + expf(xi - m_new);
        local_m = m_new;
    }

    // 两级 Warp Shuffle 合并
    warpReduceOnline(local_m, local_d);

    __shared__ float warp_m[32];
    __shared__ float warp_d[32];
    if (lane == 0) {
        warp_m[wid] = local_m;
        warp_d[wid] = local_d;
    }
    __syncthreads();

    int num_warps = blockDim.x / 32;
    if (wid == 0) {
        local_m = (lane < num_warps) ? warp_m[lane] : -INFINITY;
        local_d = (lane < num_warps) ? warp_d[lane] : 0.0f;
        warpReduceOnline(local_m, local_d);
    }

    __shared__ float final_m, final_d;
    if (tid == 0) {
        final_m = local_m;
        final_d = local_d;
    }
    __syncthreads();

    float row_max = final_m;
    float inv_sum = 1.0f / final_d;

    // 从寄存器缓存中读取数据并写出(无需再读全局内存)
    int idx = 0;
    for (int i = tid; i < N; i += blockDim.x) {
        y[i] = expf(reg_cache[idx] - row_max) * inv_sum;
        idx++;
    }
}

6.3 寄存器压力分析

⚠️ 注意:这个方案的核心限制是寄存器容量。每个线程使用 MAX_ELEMS_PER_THREAD 个 float 寄存器来缓存数据:

blockDimNN 上限每线程寄存器(仅缓存部分)
256256 × 32 = 819232 × 4B = 128B
256256 × 16 = 409616 × 4B = 64B
128128 × 32 = 409632 × 4B = 128B

A100 每个 SM 有 65536 个 32 位寄存器,每线程 255 个寄存器上限。当 MAX_ELEMS_PER_THREAD 过大时,寄存器溢出到 Local Memory(本质是全局内存),反而变慢。

📌 关键点:V3 适合 N8192N \leq 8192 且 blockDim=256 的场景。对于更大的 NN,应退回 V2(两遍方案)。

V3 实测带宽利用率:约 82%(~1672 GB/s)N=4096N=4096, blockDim=256),相比 V2 提升约 1.17 倍。


7. 版本 V4:多行并行 + Grid Stride

7.1 适用场景

前面的版本都是”一个 Block 处理一行”、“启动 MM 个 Block”。当 MM 非常大(如几万行)时,会启动同等数量的 Block,带来两个问题:一是 Grid 调度开销线性增加;二是每个 Block 只跑一行就退出,Block 内的 Shared Memory 缓冲、寄存器状态都无法跨行复用。

解决方案是用 Grid Stride 让固定数量的 Block(通常按 SM 数的若干倍设置)循环覆盖所有行,并继续叠加前面的所有优化(向量化加载 + Warp Shuffle 合并规约)。V4 在算法层面与 V2 完全相同(两遍扫描),区别只在于 Block 与行的映射关系——V2 是 1:1,V4 是 1:多。

⚠️ 注意:Grid Stride 并不能解决”MM 很小(如 batch=1)、NN 很大”导致 Block 数不足的问题——那种场景需要让多个 Block 协作处理同一行(行内分段规约 + 两段式 Kernel 或原子合并)。

7.2 Kernel 实现

// V4: Grid Stride 多行处理 + 所有优化集成
__global__ void online_softmax_v4(float* input, float* output, int M, int N) {
    int tid = threadIdx.x;
    int lane = tid % 32;
    int wid  = tid / 32;

    // Grid Stride 遍历所有行
    for (int row = blockIdx.x; row < M; row += gridDim.x) {
        float* x = input  + row * N;
        float* y = output + row * N;

        float4* x4 = reinterpret_cast<float4*>(x);
        float4* y4 = reinterpret_cast<float4*>(y);
        int N4 = N / 4;

        // 第 1 遍:向量化 Online 递推
        float local_m = -INFINITY;
        float local_d = 0.0f;

        for (int i = tid; i < N4; i += blockDim.x) {
            float4 data = x4[i];
            float vals[4] = {data.x, data.y, data.z, data.w};
            for (int k = 0; k < 4; k++) {
                float m_new = fmaxf(local_m, vals[k]);
                local_d = local_d * expf(local_m - m_new) + expf(vals[k] - m_new);
                local_m = m_new;
            }
        }
        for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
            float xi = x[i];
            float m_new = fmaxf(local_m, xi);
            local_d = local_d * expf(local_m - m_new) + expf(xi - m_new);
            local_m = m_new;
        }

        // 两级 Warp Shuffle 合并
        warpReduceOnline(local_m, local_d);

        __shared__ float warp_m[32];
        __shared__ float warp_d[32];
        if (lane == 0) {
            warp_m[wid] = local_m;
            warp_d[wid] = local_d;
        }
        __syncthreads();

        int num_warps = blockDim.x / 32;
        if (wid == 0) {
            local_m = (lane < num_warps) ? warp_m[lane] : -INFINITY;
            local_d = (lane < num_warps) ? warp_d[lane] : 0.0f;
            warpReduceOnline(local_m, local_d);
        }

        __shared__ float final_m, final_d;
        if (tid == 0) {
            final_m = local_m;
            final_d = local_d;
        }
        __syncthreads();

        float row_max = final_m;
        float inv_sum = 1.0f / final_d;

        // 第 2 遍:向量化归一化
        for (int i = tid; i < N4; i += blockDim.x) {
            float4 data = x4[i];
            float4 result;
            result.x = expf(data.x - row_max) * inv_sum;
            result.y = expf(data.y - row_max) * inv_sum;
            result.z = expf(data.z - row_max) * inv_sum;
            result.w = expf(data.w - row_max) * inv_sum;
            y4[i] = result;
        }
        for (int i = N4 * 4 + tid; i < N; i += blockDim.x) {
            y[i] = expf(x[i] - row_max) * inv_sum;
        }

        __syncthreads();  // 确保 Shared Memory 在下一行之前被重置
    }
}

调用方式:

int num_sms;
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, 0);
int grid_size = min(M, num_sms * 4);  // 最大化 GPU 利用率
int block_size = 256;
online_softmax_v4<<<grid_size, block_size>>>(d_input, d_output, M, N);

7.3 Grid Stride 的优势

MM 很大时,Grid Stride 让 Block 循环处理多行,好处是:

  • 固定 Grid 大小避免过大的 Grid 调度开销
  • Block 内的 Shared Memory 和寄存器状态可以跨行复用
  • GPU 始终保持满载

MM 较小时,grid_size = M,退化为 V2 的行为——每个 Block 处理一行。

V4 实测带宽利用率:约 72%(~1468 GB/s)M=4096,N=4096M=4096, N=4096),与 V2 基本持平(两遍方案在大 NN 下的理论上限)。Grid Stride 的收益主要体现在 MM 极大时的调度开销节省,对于本文的测试规模(M=4096M=4096)提升有限。


8. 性能对比与 FlashAttention 的联系

8.1 各版本性能汇总

版本核心优化点扫描次数带宽利用率相对加速
V0 单线程/行2~7%1.0x
V1 Block 并行 + Warp Shuffle寄存器级合并规约2~58%8.3x
V2 向量化加载float4 减少指令数2~70%10.0x
V3 寄存器缓存One-Pass(仅适合小 NN1~82%11.7x
V4 Grid Stride多行并行 + 适配任意 MM2~72%10.3x

📌 关键点:V3 的 One-Pass 方案在 N8192N \leq 8192 时效果最好,但受限于寄存器容量。V2/V4 的 Two-Pass 方案通用性更强,在各种 NN 下都能达到 70%+ 的带宽利用率。

8.2 与 FlashAttention 的联系

FlashAttention 的核心思想正是 Online Softmax 的推广。在 Self-Attention 中,完整的计算是:

Attention(Q,K,V)=Softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V

传统实现需要先计算完整的 QKTQK^T 矩阵(N×NN \times N),再做 Softmax,再乘 VV。这要求将整个 N×NN \times N 的 Attention 矩阵存放在 HBM 中——当 N=128KN=128K 时这是 64GB 的内存。

FlashAttention 利用 Online Softmax 的递推思想,将 KKVV 分块加载到 SRAM(Shared Memory),每次只计算部分 QKTQK^T 并增量更新 Softmax 状态 (m,d)(m, d),最后修正输出。整个过程不需要存储完整的 Attention 矩阵

graph LR
    A["Online Softmax<br>递推公式"] --> B["分块计算 QK^T<br>增量更新 (m, d)"]
    B --> C["修正输出 O<br>O = O * correction + new_block"]
    C --> D["FlashAttention<br>IO-aware Attention"]

Online Softmax 中的修正因子 emoldmnewe^{m_{\text{old}} - m_{\text{new}}} 在 FlashAttention 中用来修正之前块的输出贡献

Onew=Oolddolddnewemoldmnew+emblockmnewdnewPblockVblockO_{\text{new}} = O_{\text{old}} \cdot \frac{d_{\text{old}}}{d_{\text{new}}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + \frac{e^{m_{\text{block}} - m_{\text{new}}}}{d_{\text{new}}} \cdot P_{\text{block}} \cdot V_{\text{block}}

💡 提示:理解了本文的 Online Softmax 递推和合并公式,FlashAttention 的分块策略就只是在此基础上多了一个矩阵乘法的增量更新——核心数学完全一致。

8.3 工程选择建议

场景推荐方案
独立 Softmax(分类层等)V2 或 V3(根据 NN 大小选择)
Attention 中的 Softmax使用 FlashAttention 融合实现
学习/理解原理V1
自定义 Attention 变体基于 V2 的模式进行扩展

📝 总结

Online Softmax 的核心贡献是用一个简洁的递推公式,将 Safe Softmax 的三遍扫描合并为两遍:

  1. 数学基础:利用 dk+1=dkemkmk+1+exkmk+1d_{k+1} = d_k \cdot e^{m_k - m_{k+1}} + e^{x_k - m_{k+1}} 在一遍中同时维护 max 和 sum
  2. 并行化关键(m,d)(m, d) 合并满足结合律,可以通过 Warp Shuffle 两级规约高效并行(V1)
  3. 逼近带宽极限:向量化加载(V2)减少指令数,寄存器缓存(V3)更进一步实现 One-Pass,Grid Stride(V4)适配任意 batch
  4. 通往 FlashAttention:Online Softmax 的递推思想是 FlashAttention 分块计算的数学根基

从工程角度看,Online Softmax 将 Memory-Bound 操作的内存访问次数降到理论最低,配合 Warp Shuffle 和向量化加载,在 A100 上可达到 70-82% 的带宽利用率。


🎯 自我检验清单

  • 能解释 Online Softmax 相比三遍 Safe Softmax 节省了哪一遍扫描及其原理
  • 能手动推导 dk+1=dkemkmk+1+exkmk+1d_{k+1} = d_k \cdot e^{m_k - m_{k+1}} + e^{x_k - m_{k+1}} 的递推公式
  • 能写出两个局部 (m1,d1)(m_1, d_1)(m2,d2)(m_2, d_2) 合并为全局 (m,d)(m, d) 的代码
  • 能解释为什么 (m,d)(m, d) 合并操作满足结合律从而可以并行化
  • 能实现使用 __shfl_down_syncwarpReduceOnline 函数
  • 能说明 V3 寄存器缓存方案的适用条件和寄存器压力限制
  • 能解释 Online Softmax 与 FlashAttention 分块策略的数学联系
  • 能根据 MMNN 的大小选择合适的 Softmax 实现版本
  • 能使用 Nsight Compute 对比 Two-Pass 和 Three-Pass 方案的内存吞吐差异
  • 能将 Online Softmax 的 (m,d)(m, d) 递推思路扩展到其他需要增量计算的场景

📚 参考资料