type
status
date
slug
summary
tags
category
icon
password

背景

Transformers核心组件self-attention的空间复杂度为 ,为序列长度。从显存层面限制了模型长上下文upper bound。对于标准的self-attention计算而言,需要频繁进行HBM(high bandwidth memory, HBM)和SRAM的内存读写,存在IO瓶颈。
FlashAttention的核心创新点是通过online-softmax和tiling技巧来将self-attention的空间复杂度降至,并减少HBM与SRAM的IO通信。
image crop from flash-attention paper
image crop from flash-attention paper

Self Attention的定义

假定当前层attention的输入为为sequence的长度。通过 3个线性层得到query,key,value。
经过Self Attention层后,输出
  • 表示对每一行进行 softmax(行归一化)
notion image
在实际计算中序列长度远远大于hidden size 。因此Self attention的空间复杂度是,限制了LLM的long context的能力。

Self Attention的递归形式

为视角,将式(2)展开计算
notion image
定义
可以写成:
其中
简单对式(5)做一下变形,得到递推式
初值条件
根据上述递推式和初值条件,可以用递推的方式计算self-attention

考虑数值稳定性的递归形式

由于指数函数时,随着的增大指数增加,在计算中很容易出现数值上溢。实践中一般会对softmax的指数部分减去一个最大值,将指数部分的定义域控制在,以此来保证计算稳定性。
显然存在以下关系
在递推场景下,我们无法得到,我们该如何利用这个技巧呢
递推场景下,在时刻,虽然无法知道,但我们能够知道
将式(5)两边同乘
根据式(8)的结论,可以得到下面的递归式
不妨令
则上式可以简写为
初值条件
完整阐述一下算法流程如下:
用上述递归方案计算self-attention虽然复杂度仍为,但空间复杂度降至

考虑共享内存减少IO通信

上面的实现中,第二个Loop共用一个。在实践中,由于context的存在,kv的序列长度往往大于q。因此工程实现上通常是将kv放到第一个Loop,可以复用kv,此时的算法流程如下

Memory Efficient与Compute Efficient的Trade-off

虽然上面的实现能显著降低计算self-attention的显存占用,但没有利用GPU的并行能力,在实践中,会通过Tiling来提升并行度。
假定将分为组,将分为具体的算法流程如下:
该形式也是flash-attention的核心思想。注意,算法4和flash-atention论文中的形式有一点区别,但二者是等价的,感兴趣的读者可以自行推导验证。

小结

本文相对系统的介绍了flash-attention算法的核心思想。
局限性:文中的推导未考虑causal mask,多头并行,backward的梯度计算,更为深入的理解请参考flash-attention的官方实现。
若有疏漏之处,敬请指出。

Reference

[1] Online normalizer calculation for softmax
[2] Self-attention Does Not Need O(n2) Memory
[3] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

附录

注意⚠️: 本文代码仅用于展示算法推导过程,未包含工程实现所需的优化(如 causal mask、多头并行、dropout、backward kernel、半精度稳定性等)。实际工程中请参考 FlashAttention 官方实现。
 
 
相关文章
BLIP系列文章小结(BLIP, BLIP-2, InstructBLIP)
Lazy loaded image
BLIP-2小结
Lazy loaded image
BLIP 小结
Lazy loaded image
BLIP3技术小结(xGen-MM (BLIP-3): A Family of Open Large Multimodal Models)
Lazy loaded image
MM1技术小结(MM1: Methods, Analysis & Insights from Multimodal LLM Pre-training)
Lazy loaded image
🔥Lit: 进一步提升多模态模型Zero-Shot迁移学习的能力
Lazy loaded image
KV-Cache技术小结(MHA,GQA,MQA,MLA)匈牙利算法小结
Loading...
莫叶何竹🍀
莫叶何竹🍀
非淡泊无以明志,非宁静无以致远
最新发布
Step by Step Understanding Flash-Attention
2025-9-28
RL学习小结 (002): 策略梯度理论
2025-9-1
RL学习小结 (001): 基本概念、贝尔曼方程
2025-9-1
diffusion model(十九) :SDE视角下的扩散模型
2025-8-15
阅读顺序还原技术剖析——LayoutReader
2025-7-24
多模态模型如何处理任意分辨率输入——Tiling与Packing技术详解
2025-5-24