在进一步钻研 INT-FlashAttention 的源码与论文后,可以把整套设计拆分为 离线量化 → 带尺度因子的数据流 → INT8-aware Online-Softmax → 双 INT8 GEMM 融合内核 四个互相咬合的层次。下面从 量化数学、内核执行流、数值稳定性、硬件映射 四条主线,把论文中没展开的关键细节补齐。

一段话总览

INT-FlashAttention 把 FlashAttention-2 的「块化 + 在线 softmax」算法完全保持不变,但把所有矩阵运算改成 INT8×INT8→INT32 Tensor Core IMMA,并在块循环里 显式管理 Q/K 的逐 token 尺度向量、V 的全局尺度常量,以及 softmax 的行归一化量。借助整数乘法的线性可交换性,它把缩放操作“吸收”进整数域运算,做到 中间权重矩阵直接以 INT8 存进 SRAM,再次作为 INT8 输入参与第二次 GEMM;最终只在输出阶段做一次矢量化 dequant,从而让 A100/RTX 4090 在 1 k–16 k token 上普遍比 FP16 FlashAttention 提速 1.7–1.9 ×,同时把 Flash-FP8 的 MRE 误差降到原来的 ≈¼ (arxiv.org, arxiv.org)。


1 离线量化策略

1.1 逐 token 对称线性量化

对每个 token 的向量 qi∈Rdh\mathbf{q}_i\in\mathbb{R}^{d_h} ,用单独的尺度

sq[i]=max⁡(∣qi∣)127,q^i=round⁡ ⁣(qi/sq[i])∈[−127,127]dhs_q[i]=\frac{\max(|\mathbf{q}_i|)}{127},\qquad \hat{\mathbf q}_i=\operatorname{round}\!\bigl(\mathbf{q}_i/s_q[i]\bigr)\in[-127,127]^{d_h}

K 同理得到 k^j,sk[j]\hat{\mathbf k}_j,s_k[j];V 目前用 tensor-level 尺度 svs_v(作者计划改成 block 级)(arxiv.org)。这样 显存只多两条 float32 向量 (|Q| + |K|) × 4 B,却显著降低 outlier 主导的量化误差,相比张量级 FP8 误差下降 46–82 %(marktechpost.com)。

1.2 尺度缓存布局

  • sq,sks_q,s_k 被分块复制到每个 CTA 的 shared_memory,避免在 block 循环里重复 HBM 访存。

  • svs_v 仅在 第二次 GEMM 完成后 作为 warp-wide常量参与一次乘法,带宽开销极小(arxiv.org)。


2 INT8-aware Online-Softmax 细节

下面用论文 Algorithm 1 的行号标注核心步骤(arxiv.org)。

2.1 第一次 GEMM (Q Kᵀ)

  1. INT8 IMMA\hat Q_{blk} (INT8) × \hat K_{blk}^T (INT8) → S_int32 (行 10)。

  2. 矢量放缩S_fp32 = S_int32 × s_q[i] × s_k[j] — 这里先乘 INT8,再一次性放缩到 fp32,保留精度而不溢出 INT32 范围(|INT8|·|INT8|·dim_k ≤ 127²·256 ≪ 2³¹)。

  3. 行最大值 r 与指数和 m 用 fp32 更新(行 10–11),保持数值稳定。

2.2 权重矩阵直接量化

Softmax 局部输出

pij=eSij−ri/mi∈(0,1)p_{ij}=e^{S_{ij}-r_i}/m_i \in(0,1)

天然落在 [0,1][0,1],故作者直接用

p^ij=round⁡(127⋅pij)\hat p_{ij}= \operatorname{round}(127·p_{ij})

把 PP 写成 INT8(行 12)——这是关键一招:

  • 省掉写回 fp32 → 再量化的带宽;

  • 让第二次 GEMM 也能吃 INT8,彻底摆脱 FP16 Tensor Core(arxiv.org)。

2.3 第二次 GEMM (P V)

  1. \hat P_{blk} (INT8) × \hat V_{blk} (INT8) → O_int32 (行 13)。

  2. 行归一化 & dequant:O_fp32 = (O_int32 × s_v) / (m_i /127)(行 16)。

    • 注意把 127 提前合入 dequant 避免除法。

    • 归一化向量 1/mi1/m_i 也缓存在 shared_memory

整条数据流只产生一次 fp32-V 反量化,softmax 与尺度重分配全部在片上完成。


3 核函数执行流

3.1 CTA-tile 组织

  • 输入序列分为 Br×BcB_r × B_c 行×列块;Ampere 上经验 tile = 128 × 128 可同时满足 warp 匹配与 shared_memory(≈48 KB)约束(arxiv.org)。

  • 双缓冲:一边 GEMM,一边下一块数据 DMA 到 smem,隐藏 HBM 延迟。

3.2 Triton Autotune 策略

GitHub 实现通过 16 条 configs.py 列举 (BLOCK_M, BLOCK_N, BLOCK_K),Triton autotuner 在启动时扫遍 INT8 IMMA 性能曲线,比 CUTLASS 手写 kernel 少 100+ 行代码,同时可在 4090 与 A100 上共用内核(github.com)。

3.3 指令级并行

  • IMMA 1688 指令一次完成 4×4×4 INT8 dot-product,router micro-pipe 能在 8 cycle 内吐出 256 INT32 MAC。

  • 行最大值 r、指数和 m 用 warp-wide

    __shfl_xor_sync 做行规约,平均两条指令更新一次(对 128 token 行只需 7 step),摊薄 fp32 ALU 消耗(medium.com)。


4 数值稳定性与误差来源

环节

潜在误差

缓解手段

Q/K 量化

INT8 取整

token-level sq,sks_q,s_k 限定 Δ≤0.4 % LSB (arxiv.org)

S_int32 溢出

dot prod 累加

dim_k ≤256 时 INT32 容量足够;大维度可分块 K

softmax 归一化

指数下溢

行最大值 r 保留 fp32;同策略已在 Flash-2 证明数值等价(arxiv.org)

P 量化

7-bit 精度

放缩常数 127 与 m_i 同时调整,误差≤0.4 %

V tensor-scale

长序列误差累积

未来将改为 per-block,作者已在结论中说明(arxiv.org)

实验结果显示全 INT8 方案在 normal/uniform 输入下的 MRE 分别为 4.0 % / 1.7 %,远低于 Flash-FP8 的 7.5 % / 9.0 % (arxiv.org)。


5 硬件收益与对比

GPU 架构

Flash-FP16

Flash-FP8 (H100)

INT-FA (A100)

核心指令

HMMA 16816

WGMMA FP8

IMMA 1688

理论吞吐

312 TFLOPS

989 TFLOPS

1 248 TOPS (INT8) (images.nvidia.com)

实测 16 k token

42.7 ms

23.5 ms

24.8 ms (<1 % 差距) (arxiv.org)

MRE(normal)

0 %

7.5 %

4.1 % (arxiv.org)

因此在 无 FP8 支持的 Ampere 客户端,INT-FA 把 Flash-3 的优势完整复刻并提升精度。


6 落地提示 & 实践踩坑

  1. 安装

    pip install triton==2.1.0  # 保证>=2.1 内置 INT8 IMMA
    pip install int-flashattention
    

    from flash_attn import flash_attn_func
    y = flash_attn_func(q, k, v)
    

    改为

    from int_flashattention import int_flash_attn_func
    y = int_flash_attn_func(q, k, v)
    

    其余接口保持一致。

  2. 维度限制 :当前内核 hard-code head_dim≤256;更高维需要拆 K 分块或切换 FP16。

  3. INT8 GEMM 反而慢? 若遇到 int8 gemm slower than fp16,大多是 WARP_K 未调优;调大 configs.py.WARP_K 到 16 可恢复 IMMA 吞吐(github.com)。

  4. 梯度/训练 :作者未实现反向,若要蒸馏或 LoRA,可在前向用 INT-FA,反向自动回退 FP16;或参考 Triton issue #1187 的 INT8 dgrad 讨论(github.com)。


7 局限与未来工作

  • V 仅张量级尺度,长上下文误差随 token 累积;论文承诺按 block(token) 级改进(arxiv.org)。

  • 缺少训练级支持:INT8 softmax 梯度需重新推导数值稳定形式。

  • Hadamard-INT8 规划:作者计划把 QK 乘法替换为 Hadamard 转换 + INT8 点积,可望再降 40 % FLOPs (arxiv.org)。


一句话建议:如果你当前在 Ampere GPU 上部署 Llama-2/3,且已采用权重量化(GPTQ/SpQR 等),直接把 FlashAttention-2 换成 INT-FlashAttention,可在 不改权重文件的前提下 再免费获得 ~1.8 × 端到端推理加速,并进一步压低激活误差。

Appendix

在 INT-FlashAttention 里作者提到的 “LUT 方案需要 10 – 12 bit 地址才能把 softmax 的 exp 近似误差压到 ≤ 1 × 10-3 RMSE”,指的是最朴素、硬件界广泛使用的 “整值查表 +(可选)线性/二次内插” 方法来近似

f(Δ)=e−Δ,Δ=Sij−mi∈[0,Δmax⁡]f(\Delta)=e^{-\Delta},\qquad \Delta = S_{ij}-m_i\in[0,\Delta_{\max}]

(在线 softmax 里我们先减去行最大 mim_i,所以 Δ\Delta 非负并有上界)。下面把它的 原理、为什么要 10–12 bit、以及 GPU/FPGA 上如何具体落地 分开讲。


1 经典 LUT-exp 方案长什么样?

1.1 定点化 → 查表

  1. 裁剪 & 量化:把 Δ\Delta 裁剪到 [0,Δmax⁡][0,\Delta_{\max}](经验上 Δmax⁡≈8\Delta_{\max}\approx8 足够保证 exp⁡(−Δ)\exp(-\Delta) 不下溢),然后按

    k=round ⁣(Δ⋅2B/Δmax⁡)k=\mathrm{round}\!\bigl(\Delta\cdot 2^{B}/\Delta_{\max}\bigr)

    把浮点 Δ\Delta 映射到 BB 位无符号整数索引。

  2. 查表:在片上常量表里直接读
    e−Δ^=LUT[k]\widehat{e^{-{\Delta}}}=LUT[k]。

  3. (可选)内插:如果把索引拆成高 bhib_{\text{hi}} 位 + 低 blob_{\text{lo}} 位,可在 2 个相邻表值之间做一次线性(或二次)插值,以用更小表换同样精度。 (aclanthology.org)

1.2 表规模与精度的经验公式

对单调光滑函数,查表误差大约与步长的 2 阶导数成正比:

ϵmax⁡≈f′′(ξ)8⋅(Δmax⁡2B)2,ξ∈[0,Δmax⁡]\epsilon_{\max}\approx\frac{f''(\xi)}{8}\cdot\left(\frac{\Delta_{\max}}{2^B}\right)^2,\qquad \xi\in[0,\Delta_{\max}]

把 f(Δ)=e−Δf(\Delta)=e^{-\Delta} 的 2 阶导数代进去,并要求 ϵRMSE≤10−3\epsilon_{\text{RMSE}}\le10^{-3},可推得

2B≳Δmax⁡8ϵmax⁡≈210–122^{B}\gtrsim\frac{\Delta_{\max}}{\sqrt{8\epsilon_{\max}}}\approx 2^{10\text{–}12}

也就是 需要 1024 – 4096 个表项,对应 10–12 位地址。多篇软硬件评测都给出了类似量级的实测结论:

  • 线性插值 64 段 (6 + 6 bit) RMSE ≈ 2.3×10-7

  • 直接查表 128 段 (7 bit) 在 COCO/Faster-RCNN 上已明显掉点,需要 256 或 512 段 才 <0.1 % mAP 损失

  • FPGA 实验显示 8-bit 表(256 段)RMSE≈3×10-3,而 10-bit (1024 段) 能降到 7×10-4 级 (oar.a-star.edu.sg)

这就是作者说 “INT8(8 bit)不够用,要 10–12 bit” 的由来:如果你只有 8 bit 索引,最多 256 段,误差往往在 2–5×10-3 量级,在 LLM 上会带来 0.5–1 % 以上的 log-prob 漂移。


2 GPU/Triton 上怎么把 LUT-exp 做出来?

步骤

说明

典型实现位置

Clamp & scale

Δ←clip(Δ,0,Δmax⁡)\Delta\gets\mathrm{clip}(\Delta,0,\Delta_{\max});<fp32>→<uint16> 乘常数

Tensor Core 前的 warp-wide ALU

取高 bhib_{\text{hi}} 位做 表索引

LUT 存到 __constant__ 或 warp-private shared_memory(≤4 KB)

每 thread 一次读

取低 blob_{\text{lo}} 位做 线性权重

fma 计算 (1−λ)v0+λv1(1-\lambda)v_0+\lambda v_1

不超过 4 ALU 指令

结果乘 2−scale2^{-scale} 合回 softmax 管线

与在线 softmax 的行归一化一起完成

已在 Flash-2 原始 kernel 里有此步骤

经验值:在 A100 上 1024 × fp16 LUT (2 KB)加载到 shared 后,每查一次 exp 只需 1 cycle;线性内插总开销 <5 % Kernel 时间。(arxiv.org)


3 为什么 INT-FlashAttention 最终没选择 LUT?

  1. 指数仍要 FP32 → INT32 转换
    即便有 LUT,你仍需把整数点积 Sint32S_{\text{int32}} 变成 Δ\Delta 的浮点差值才能索引表,这一步和作者当前做的 “一次性 INT32→FP32” 成本相同。

  2. 共享内存压力 :FlashAttention 的 tile 内本就塞满 Q,K,VQ,K,V block,若再多挂一个 4 KB LUT,tile 尺寸得减小,反而吞吐下降。

  3. 8-bit 精度不足:Ampere 的 IMMA 只有 8-bit 整乘;如果表本身就 8-bit,则误差偏大;如果表存 16-bit 又抬高带宽。

  4. 作者实测:两次 INT8 GEMM + 少量 FP32 expf < 一次 LUT + INT8 GEMM,而且误差更低

因此论文才把重点放在 “把 exp 继续留在 FP32,但让 P 直接量化成 INT8” 的路径,而不是整型-LUT 路径。


4 如果你真的想自己做 LUT-Softmax

  1. 确定 Δmax⁡\Delta_{\max}:对 LLM 常见的 head_dim ≤ 256,Δmax⁡=8\Delta_{\max}=8 足够 (arxiv.org)。

  2. 选表结构

    • 直接查表:B=10B=10(1024 项)即可 RMSE≤1e-3;存 fp16 值 → 2 KB。

    • 线性内插:高 8 bit 索引 + 4 bit 权重 (≈16 KB LUT) 可把表减半。

    • 二次内插:进一步减到 256 基点,但每次多 5 乘法,GPU 上得不偿失。

  3. Triton 代码片段

    # delta_fp32: [Br, Bc]  
    scale = (1<<B) / delta_max  # const  
    k = tl.reshape(tl.round(delta_fp32 * scale), delta_fp32.shape).to(tl.uint16)  
    v0 = tl.load(lut_base + k)          # fp16  
    v1 = tl.load(lut_base + k + 1)      # fp16  
    lam = (delta_fp32 * scale) - tl.reshape(k.to(tl.float32), delta_fp32.shape)  
    exp_approx = v0 + lam * (v1 - v0)   # 线性插值  
    

    lut_base 预先 tl.program_id(0) 时 copy 到 shared_memory

  4. 误差验证:跑 100 万随机 Δ\Delta 取 RMSE,确保 <1e-3;若不够就加位宽或改二次插值。


5 推荐阅读/实现细节来源

  • Milakov & Gimelshein “Online normalizer calculation for softmax” (arxiv.org)

  • Chen et al. “Hardware-aware exponential approximation” (ICLR’18 workshop)

  • Leiva-Valverde et al. “A Quantitative Evaluation of Approximate Softmax Functions” (2025)

  • Xilinx FPGA piece-wise LUT softmax实现 (aclanthology.org)

  • Chen & Lombardi “Approximate softmax for energy-efficient DNNs” (IEEE VLSI ’23) (pdfs.semanticscholar.org)

  • Zhang et al. “Hardware-aware softmax linear approximation” (2021) (oar.a-star.edu.sg)

  • Wang et al. “RISC-V ISA extension for accelerated softmax” (2025) (arxiv.org)

  • Shah et al. “FlashAttention-3” (2024) 对比的 FP8 block 量化

  • Blog post “Writing a faster exp() with multi-level LUT” (deathandthepenguinblog.wordpress.com)

这些资料里都有关于 LUT 大小、位宽与误差之间 trade-off 的实测或推导,可作为你自行实现时的设计依据。