INT-FlashAttention: Enabling Flash Attention for INT8 Quantization
在进一步钻研 INT-FlashAttention 的源码与论文后,可以把整套设计拆分为 离线量化 → 带尺度因子的数据流 → INT8-aware Online-Softmax → 双 INT8 GEMM 融合内核 四个互相咬合的层次。下面从 量化数学、内核执行流、数值稳定性、硬件映射 四条主线,把论文中没展开的关键细节补齐。
一段话总览
INT-FlashAttention 把 FlashAttention-2 的「块化 + 在线 softmax」算法完全保持不变,但把所有矩阵运算改成 INT8×INT8→INT32 Tensor Core IMMA,并在块循环里 显式管理 Q/K 的逐 token 尺度向量、V 的全局尺度常量,以及 softmax 的行归一化量。借助整数乘法的线性可交换性,它把缩放操作“吸收”进整数域运算,做到 中间权重矩阵直接以 INT8 存进 SRAM,再次作为 INT8 输入参与第二次 GEMM;最终只在输出阶段做一次矢量化 dequant,从而让 A100/RTX 4090 在 1 k–16 k token 上普遍比 FP16 FlashAttention 提速 1.7–1.9 ×,同时把 Flash-FP8 的 MRE 误差降到原来的 ≈¼ (arxiv.org, arxiv.org)。
1 离线量化策略
1.1 逐 token 对称线性量化
对每个 token 的向量 qi∈Rdh\mathbf{q}_i\in\mathbb{R}^{d_h} ,用单独的尺度
sq[i]=max(∣qi∣)127,q^i=round (qi/sq[i])∈[−127,127]dhs_q[i]=\frac{\max(|\mathbf{q}_i|)}{127},\qquad \hat{\mathbf q}_i=\operatorname{round}\!\bigl(\mathbf{q}_i/s_q[i]\bigr)\in[-127,127]^{d_h}
K 同理得到 k^j,sk[j]\hat{\mathbf k}_j,s_k[j];V 目前用 tensor-level 尺度 svs_v(作者计划改成 block 级)(arxiv.org)。这样 显存只多两条 float32 向量 (|Q| + |K|) × 4 B,却显著降低 outlier 主导的量化误差,相比张量级 FP8 误差下降 46–82 %(marktechpost.com)。
1.2 尺度缓存布局
sq,sks_q,s_k 被分块复制到每个 CTA 的
shared_memory
,避免在 block 循环里重复 HBM 访存。svs_v 仅在 第二次 GEMM 完成后 作为 warp-wide常量参与一次乘法,带宽开销极小(arxiv.org)。
2 INT8-aware Online-Softmax 细节
下面用论文 Algorithm 1 的行号标注核心步骤(arxiv.org)。
2.1 第一次 GEMM (Q Kᵀ)
INT8 IMMA:
\hat Q_{blk} (INT8) × \hat K_{blk}^T (INT8) → S_int32
(行 10)。矢量放缩:
S_fp32 = S_int32 × s_q[i] × s_k[j]
— 这里先乘 INT8,再一次性放缩到 fp32,保留精度而不溢出 INT32 范围(|INT8|·|INT8|·dim_k ≤ 127²·256 ≪ 2³¹)。行最大值 r 与指数和 m 用 fp32 更新(行 10–11),保持数值稳定。
2.2 权重矩阵直接量化
Softmax 局部输出
pij=eSij−ri/mi∈(0,1)p_{ij}=e^{S_{ij}-r_i}/m_i \in(0,1)
天然落在 [0,1][0,1],故作者直接用
p^ij=round(127⋅pij)\hat p_{ij}= \operatorname{round}(127·p_{ij})
把 PP 写成 INT8(行 12)——这是关键一招:
省掉写回 fp32 → 再量化的带宽;
让第二次 GEMM 也能吃 INT8,彻底摆脱 FP16 Tensor Core(arxiv.org)。
2.3 第二次 GEMM (P V)
\hat P_{blk} (INT8) × \hat V_{blk} (INT8) → O_int32
(行 13)。行归一化 & dequant:
O_fp32 = (O_int32 × s_v) / (m_i /127)
(行 16)。注意把 127 提前合入 dequant 避免除法。
归一化向量 1/mi1/m_i 也缓存在
shared_memory
。
整条数据流只产生一次 fp32-V 反量化,softmax 与尺度重分配全部在片上完成。
3 核函数执行流
3.1 CTA-tile 组织
输入序列分为 Br×BcB_r × B_c 行×列块;Ampere 上经验 tile = 128 × 128 可同时满足 warp 匹配与
shared_memory
(≈48 KB)约束(arxiv.org)。双缓冲:一边 GEMM,一边下一块数据 DMA 到
smem
,隐藏 HBM 延迟。
3.2 Triton Autotune 策略
GitHub 实现通过 16 条 configs.py
列举 (BLOCK_M, BLOCK_N, BLOCK_K)
,Triton autotuner 在启动时扫遍 INT8 IMMA 性能曲线,比 CUTLASS 手写 kernel 少 100+ 行代码,同时可在 4090 与 A100 上共用内核(github.com)。
3.3 指令级并行
IMMA 1688 指令一次完成 4×4×4 INT8 dot-product,router micro-pipe 能在 8 cycle 内吐出 256 INT32 MAC。
行最大值 r、指数和 m 用 warp-wide
__shfl_xor_sync
做行规约,平均两条指令更新一次(对 128 token 行只需 7 step),摊薄 fp32 ALU 消耗(medium.com)。
4 数值稳定性与误差来源
实验结果显示全 INT8 方案在 normal/uniform 输入下的 MRE 分别为 4.0 % / 1.7 %,远低于 Flash-FP8 的 7.5 % / 9.0 % (arxiv.org)。
5 硬件收益与对比
因此在 无 FP8 支持的 Ampere 客户端,INT-FA 把 Flash-3 的优势完整复刻并提升精度。
6 落地提示 & 实践踩坑
安装
pip install triton==2.1.0 # 保证>=2.1 内置 INT8 IMMA pip install int-flashattention
把
from flash_attn import flash_attn_func y = flash_attn_func(q, k, v)
改为
from int_flashattention import int_flash_attn_func y = int_flash_attn_func(q, k, v)
其余接口保持一致。
维度限制 :当前内核 hard-code
head_dim≤256
;更高维需要拆 K 分块或切换 FP16。INT8 GEMM 反而慢? 若遇到
int8 gemm slower than fp16
,大多是 WARP_K 未调优;调大configs.py.WARP_K
到 16 可恢复 IMMA 吞吐(github.com)。梯度/训练 :作者未实现反向,若要蒸馏或 LoRA,可在前向用 INT-FA,反向自动回退 FP16;或参考 Triton issue #1187 的 INT8 dgrad 讨论(github.com)。
7 局限与未来工作
V 仅张量级尺度,长上下文误差随 token 累积;论文承诺按 block(token) 级改进(arxiv.org)。
缺少训练级支持:INT8 softmax 梯度需重新推导数值稳定形式。
Hadamard-INT8 规划:作者计划把 QK 乘法替换为 Hadamard 转换 + INT8 点积,可望再降 40 % FLOPs (arxiv.org)。
一句话建议:如果你当前在 Ampere GPU 上部署 Llama-2/3,且已采用权重量化(GPTQ/SpQR 等),直接把 FlashAttention-2 换成 INT-FlashAttention,可在 不改权重文件的前提下 再免费获得 ~1.8 × 端到端推理加速,并进一步压低激活误差。
Appendix
在 INT-FlashAttention 里作者提到的 “LUT 方案需要 10 – 12 bit 地址才能把 softmax 的 exp 近似误差压到 ≤ 1 × 10-3 RMSE”,指的是最朴素、硬件界广泛使用的 “整值查表 +(可选)线性/二次内插” 方法来近似
f(Δ)=e−Δ,Δ=Sij−mi∈[0,Δmax]f(\Delta)=e^{-\Delta},\qquad \Delta = S_{ij}-m_i\in[0,\Delta_{\max}]
(在线 softmax 里我们先减去行最大 mim_i,所以 Δ\Delta 非负并有上界)。下面把它的 原理、为什么要 10–12 bit、以及 GPU/FPGA 上如何具体落地 分开讲。
1 经典 LUT-exp 方案长什么样?
1.1 定点化 → 查表
裁剪 & 量化:把 Δ\Delta 裁剪到 [0,Δmax][0,\Delta_{\max}](经验上 Δmax≈8\Delta_{\max}\approx8 足够保证 exp(−Δ)\exp(-\Delta) 不下溢),然后按
k=round (Δ⋅2B/Δmax)k=\mathrm{round}\!\bigl(\Delta\cdot 2^{B}/\Delta_{\max}\bigr)
把浮点 Δ\Delta 映射到 BB 位无符号整数索引。
查表:在片上常量表里直接读
e−Δ^=LUT[k]\widehat{e^{-{\Delta}}}=LUT[k]。(可选)内插:如果把索引拆成高 bhib_{\text{hi}} 位 + 低 blob_{\text{lo}} 位,可在 2 个相邻表值之间做一次线性(或二次)插值,以用更小表换同样精度。 (aclanthology.org)
1.2 表规模与精度的经验公式
对单调光滑函数,查表误差大约与步长的 2 阶导数成正比:
ϵmax≈f′′(ξ)8⋅(Δmax2B)2,ξ∈[0,Δmax]\epsilon_{\max}\approx\frac{f''(\xi)}{8}\cdot\left(\frac{\Delta_{\max}}{2^B}\right)^2,\qquad \xi\in[0,\Delta_{\max}]
把 f(Δ)=e−Δf(\Delta)=e^{-\Delta} 的 2 阶导数代进去,并要求 ϵRMSE≤10−3\epsilon_{\text{RMSE}}\le10^{-3},可推得
2B≳Δmax8ϵmax≈210–122^{B}\gtrsim\frac{\Delta_{\max}}{\sqrt{8\epsilon_{\max}}}\approx 2^{10\text{–}12}
也就是 需要 1024 – 4096 个表项,对应 10–12 位地址。多篇软硬件评测都给出了类似量级的实测结论:
线性插值 64 段 (6 + 6 bit) RMSE ≈ 2.3×10-7
直接查表 128 段 (7 bit) 在 COCO/Faster-RCNN 上已明显掉点,需要 256 或 512 段 才 <0.1 % mAP 损失
FPGA 实验显示 8-bit 表(256 段)RMSE≈3×10-3,而 10-bit (1024 段) 能降到 7×10-4 级 (oar.a-star.edu.sg)
这就是作者说 “INT8(8 bit)不够用,要 10–12 bit” 的由来:如果你只有 8 bit 索引,最多 256 段,误差往往在 2–5×10-3 量级,在 LLM 上会带来 0.5–1 % 以上的 log-prob 漂移。
2 GPU/Triton 上怎么把 LUT-exp 做出来?
经验值:在 A100 上 1024 × fp16 LUT (2 KB)加载到
shared
后,每查一次 exp 只需 1 cycle;线性内插总开销 <5 % Kernel 时间。(arxiv.org)
3 为什么 INT-FlashAttention 最终没选择 LUT?
指数仍要 FP32 → INT32 转换
即便有 LUT,你仍需把整数点积 Sint32S_{\text{int32}} 变成 Δ\Delta 的浮点差值才能索引表,这一步和作者当前做的 “一次性 INT32→FP32” 成本相同。共享内存压力 :FlashAttention 的 tile 内本就塞满 Q,K,VQ,K,V block,若再多挂一个 4 KB LUT,tile 尺寸得减小,反而吞吐下降。
8-bit 精度不足:Ampere 的 IMMA 只有 8-bit 整乘;如果表本身就 8-bit,则误差偏大;如果表存 16-bit 又抬高带宽。
作者实测:两次 INT8 GEMM + 少量 FP32 expf < 一次 LUT + INT8 GEMM,而且误差更低
因此论文才把重点放在 “把 exp 继续留在 FP32,但让 P 直接量化成 INT8” 的路径,而不是整型-LUT 路径。
4 如果你真的想自己做 LUT-Softmax
确定 Δmax\Delta_{\max}:对 LLM 常见的 head_dim ≤ 256,Δmax=8\Delta_{\max}=8 足够 (arxiv.org)。
选表结构
直接查表:B=10B=10(1024 项)即可 RMSE≤1e-3;存 fp16 值 → 2 KB。
线性内插:高 8 bit 索引 + 4 bit 权重 (≈16 KB LUT) 可把表减半。
二次内插:进一步减到 256 基点,但每次多 5 乘法,GPU 上得不偿失。
Triton 代码片段
# delta_fp32: [Br, Bc] scale = (1<<B) / delta_max # const k = tl.reshape(tl.round(delta_fp32 * scale), delta_fp32.shape).to(tl.uint16) v0 = tl.load(lut_base + k) # fp16 v1 = tl.load(lut_base + k + 1) # fp16 lam = (delta_fp32 * scale) - tl.reshape(k.to(tl.float32), delta_fp32.shape) exp_approx = v0 + lam * (v1 - v0) # 线性插值
(
lut_base
预先tl.program_id(0)
时 copy 到shared_memory
)误差验证:跑 100 万随机 Δ\Delta 取 RMSE,确保 <1e-3;若不够就加位宽或改二次插值。
5 推荐阅读/实现细节来源
Milakov & Gimelshein “Online normalizer calculation for softmax” (arxiv.org)
Chen et al. “Hardware-aware exponential approximation” (ICLR’18 workshop)
Leiva-Valverde et al. “A Quantitative Evaluation of Approximate Softmax Functions” (2025)
Xilinx FPGA piece-wise LUT softmax实现 (aclanthology.org)
Chen & Lombardi “Approximate softmax for energy-efficient DNNs” (IEEE VLSI ’23) (pdfs.semanticscholar.org)
Zhang et al. “Hardware-aware softmax linear approximation” (2021) (oar.a-star.edu.sg)
Wang et al. “RISC-V ISA extension for accelerated softmax” (2025) (arxiv.org)
Shah et al. “FlashAttention-3” (2024) 对比的 FP8 block 量化
Blog post “Writing a faster exp() with multi-level LUT” (deathandthepenguinblog.wordpress.com)
这些资料里都有关于 LUT 大小、位宽与误差之间 trade-off 的实测或推导,可作为你自行实现时的设计依据。