Flash Attention: 작동 방식과 중요성
요약
Flash Attention은 트랜스포머 모델의 어텐션 연산 시 발생하는 메모리 대역폭 병목 현상을 해결하기 위한 기술입니다. HBM 대신 빠른 SRAM을 활용한 타일링(Tiling) 기법을 통해 불필요한 데이터 이동을 최소화하고 연산 속도를 2-4배 향상시킵니다.
핵심 포인트
- 어텐션 연산의 병목은 연산량이 아닌 메모리 대역폭(IO-bound) 문제임
- SRAM 내에서 연산을 수행하는 타일링(Tiling) 기법을 통해 HBM 트래픽 제거
- 모델 변경이나 정밀도 손실 없이 2-4배의 속도 향상 달성
- N x N 크기의 거대한 어텐션 행렬을 메모리에 직접 생성하지 않음
Flash Attention: 작동 방식과 중요성
당신의 학습 작업은 시간당 3달러의 비용으로 A100을 사용하고 있습니다. 손실(Loss)은 감소하고 있고, 그래디언트(Gradients)는 흐르고 있으며, 모델의 손실 곡선은 교과서적인 로그(Logarithmic) 형태를 보입니다. 하지만 스텝 시간(Step time)을 프로파일링하고 GPU가 실제로 무엇을 하고 있는지 살펴보면 놀라운 사실을 발견하게 될 것입니다. GPU 연산 유닛(Compute units)이 시간의 40-60% 동안 유휴(Idle) 상태라는 점입니다. 병목 현상(Bottleneck)은 산술 연산이 아니라 메모리 대역폭(Memory bandwidth)입니다. GPU의 HBM (High-Bandwidth Memory, A100 기준 1.5-2 TB/s)이 연산 유닛이 데이터를 소비하고자 하는 속도를 따라가지 못하고 있습니다. 그리고 모든 트랜스포머(Transformer) 학습 또는 추론 실행에서 메모리 트래픽의 가장 큰 비중을 차지하는 것은 어텐션(Attention) 연산이며, 이는 단순한 방식(Naive)으로 매 포워드 패스(Forward pass)마다 전체 N x N 어텐션 행렬을 HBM에 읽고 씁니다.
Flash Attention은 바로 그 문제를 해결하기 위해 존재합니다. 어텐션 연산을 GPU의 SRAM (빠른 온칩 메모리, A100 기준 약 20 MB) 내부에 완전히 머무르는 타일(Tiles)로 융합(Fusing)함으로써 불필요한 HBM 트래픽을 제거합니다. 그 결과, 모델 변경 없이 정밀도(Precision) 손실 없이 어텐션에 제한된(Attention-bound) 워크로드에서 2-4배의 엔드 투 엔드(End-to-end) 속도 향상을 얻을 수 있습니다.
어텐션 메모리 비용이 중요한 이유
단일 헤드(Single head)에서의 표준 셀프 어텐션(Self-attention) 레이어는 각각 (N, d) 형태의 세 가지 행렬 Q, K, V와 함께 작동하며, 여기서 N은 시퀀스 길이(Sequence length)이고 d는 헤드 차원(Head dimension)입니다. 단순한 연산 방식은 다음과 같습니다:
- S = Q @ K^T 계산 -- 형태 (N, N)
- P = softmax(S, dim=-1) 계산 -- 형태 (N, N)
- O = P @ V 계산 -- 형태 (N, d)
결정적인 비용은 S와 P가 각각 N x N 개의 항목을 가진다는 점입니다. d=128인 4096 토큰 시퀀스의 경우, 헤드당 1,600만 개의 항목이 됩니다. FP16 기준, 이는 헤드당 32 MB입니다. 32개의 헤드가 있다면 모든 헤드에 걸친 전체 N x N 행렬은 1 GB가 되며, 이는 단일 A100 GPU의 SRAM 약 20 MB보다 훨씬 큽니다. 표준 구현 방식은 이 1 GB를 HBM에 쓰고(느림), softmax를 위해 다시 읽고(HBM read), 결과를 다시 쓰고(HBM write), 그런 다음 V 곱셈을 위해 다시 읽습니다.
Flash Attention은 softmax 연산을 SRAM에 들어갈 만큼 충분히 작은 블록(block) 단위로 타일링(tiling)함으로써, 이 $N \times N$ 행렬을 완전히 생성(materializing)하는 것을 방지합니다.
Flash Attention이 실제로 하는 일
Tri Dao와 Stanford 그룹(2022)의 핵심 통찰은 attention 연산이 연산량 제한(compute-bound)이 아니라 입출력 제한(IO-bound) 문제이며, 지배적인 비용은 HBM과 SRAM 사이의 데이터 이동이라는 점이었습니다. A100에서 SRAM 대역폭은 약 20 TB/s(연산 유닛에서 SRAM까지)인 반면, HBM 대역폭은 약 2 TB/s입니다. 10배의 차이가 납니다. 만약 연산을 SRAM 내에 머물도록 구조화할 수 있다면, 그것이 승리합니다.
그 메커니즘은 알고리즘적으로 명확합니다:
- **Q, K, V 행렬을 블록화(Block)**하여 SRAM에 들어갈 수 있을 만큼 작은 타일(tile)로 나눕니다.
- 온라인 softmax 알고리즘(점진적으로 업데이트 가능한 안전한 softmax)을 사용하여 각 블록에 대한 부분 softmax(partial softmax)를 계산합니다.
- 레지스터(register)에 블록별 재스케일링(rescaling) 통계치를 유지하면서, **부분 결과들을 출력값에 누적(accumulate)**합니다.
- 헤드(head)당 여러 번 읽고 쓰는 대신, 레이어당 한 번씩 최종 출력을 HBM에 기록합니다.
이는 전형적인 타일링(tiling) 기법이지만, softmax가 전역 정규화(global normalization) 역할을 하는 attention 특유의 문제에 적용된 것입니다. softmax는 전체 행에 대한 분모를 필요로 하기 때문에 타일 단위로 단순히 합산할 수는 없습니다. 이 논문의 핵심적인 알고리즘적 기여는 각 타일이 로컬 softmax를 계산한 다음, 새로운 타일이 도착함에 따라 실행 중인 출력값을 수정할 수 있게 해주는 온라인 안전 softmax(online-safe softmax)입니다.
# Flash Attention 순전파(forward pass) 블록 하나에 대한 의사코드(Pseudocode)
def flash_attention_block(Q_block, K_block, V_block):
# Q_block: (B_r, d), K_block: (B_c, d), V_block: (B_c, d)
...
이 알고리즘은 HBM에서 Q, K, V를 한 번 읽고, SRAM에서 타일 단위로 처리한 다음, O를 HBM에 한 번 씁니다. 단순한(naive) 방식과 비교해 보십시오. 길이가 $N$인 시퀀스의 경우, 표준 구현은 $N \times N$ attention 행렬을 HBM에 읽고 쓰며, 이는 $O(N^2 d)$의 HBM 트래픽을 발생시킵니다. Flash Attention은 이를 $O(N^2 d / M)$으로 줄이며, 여기서 $M$은 SRAM 크기입니다. 즉, SRAM 용량에 비례하여 트래픽이 감소합니다.
다음 다이어그램은 타일링 (Tiling)이 전체 어텐션 행렬 (Attention matrix)의 실체화 (Materialization)를 어떻게 건너뛰는지 보여줍니다:
flowchart TB
subgraph SRAM["GPU SRAM (~20 MB)"]
QB[Q tile<br/>(B_r x d)]
...
HBM에서 SRAM으로 향하는 각 화살표는 느린 DMA 전송 (DMA transfer)입니다. 단순한 구현 (Naive implementation)은 행(row)당, 그리고 헤드(head)당 $O(N)$번의 이러한 전송을 수행합니다. Flash Attention은 K와 V에 대해 정확히 두 번의 패스 (Pass) (읽기 및 타일 단위 처리)를 수행한 후, O를 한 번 기록합니다.
Flash Attention v1 vs v2 vs v3
| 버전 | 연도 | 주요 개선 사항 | 단순 구현 대비 속도 향상 | GPU 타겟 |
|---|---|---|---|---|
| v1 | 2022 | 타일링 (Tiling) + 온라인 소프트맥스 (Online softmax), $O(N^2)$ 회피 | 2x | A100 (Ampere) |
| ... |
Flash Attention v2는 마스크 (Mask) 생성 및 스케일링 (Scaling)에 필요한 상당수의 비-행렬 곱셈 (Non-matrix-multiply) 명령어를 제거했습니다. 이는 Tensor Core가 워크로드(Workload)가 순수 행렬 곱셈 (Matrix multiplication)일 때 가장 효율적이며, 추가적인 요소별 연산 (Elementwise operations)은 활용도 (Utilization)를 떨어뜨리기 때문에 중요합니다. v2 논문에 따르면 65M 파라미터 모델의 단일 순전파 (Forward pass) 시간이 6.5ms (PyTorch 표준)에서 2.6ms (Flash Attention v2)로 단축되었습니다.
2024년에 발표된 Flash Attention v3는 H100의 Hopper 아키텍처를 타겟으로 합니다. 이는 WGMMA 명령어 (Warp-group MMA)를 사용하여, 타일링된 소프트맥스 (Tiled softmax) 패스 동안 GPU가 데이터 이동 (Data movement)과 연산을 중첩 (Overlap)할 수 있게 합니다. v1/v2의 동기식 (Synchronous) SRAM 읽기는 지연 시간 (Latency)을 숨기는 비동기식 (Asynchronous) 복사로 대체되었습니다. 또한, v3는 점수 계산 (Score computation)을 위한 데이터 이동을 다시 절반으로 줄이는 FP8 지원을 도입했습니다.
오늘날 Flash Attention이 사용되는 곳
Flash Attention은 사실상 거의 모든 주요 LLM 프레임워크에 통합되어 있습니다. 가장 일반적인 경로는 PyTorch 2.0부터 Flash Attention 백엔드를 탑재한 PyTorch의 scaled_dot_product_attention (SDPA)을 통하는 것입니다:
import torch.nn.functional as F
# 조건이 충족되면 자동으로 Flash Attention을 사용합니다:
...
대부분의 경우 flash_attn을 직접 임포트(import)할 필요는 없습니다. PyTorch의 SDPA (Scaled Dot Product Attention)는 사용 가능한 최적의 백엔드로 자동 분기(dispatch)합니다. 즉, 사용 가능하다면 Flash Attention을 사용하고, 그렇지 않으면 memory-efficient attention을 사용하며, 둘 다 안 될 경우 naive implementation (기본 구현)으로 폴백(fallback)합니다.
직접 접근하려면 PyPI의 flash-attn 패키지가 FlashAttention 모듈을 제공합니다:
pip install flash-attn
이 명령은 사용자의 CUDA 및 PyTorch 조합에 맞는 사전 빌드된 휠(wheel)을 설치합니다 (PyPI 휠은 v2.8.x부터 사용 가능합니다). 만약 사용자의 구성에 맞는 휠이 존재하지 않는다면, 소스에서 빌드하는 데 약 15분이 소요되며 CUDA 컴파일러가 필요합니다.
from flash_attn import flash_attn_func
output = flash_attn_func(
...
flash_attn_func API를 사용하면 백엔드 파라미터(parameters)를 직접 제어할 수 있으며, 이는 vLLM, Hugging Face transformers, 그리고 torch.compile 경로에서 사용하는 방식입니다.
일반적인 실수 (Common pitfalls)
is_causal / padding 상호작용. 만약 인과적 마스크 (causal mask)와 별도의 패딩 마스크 (padding mask, 서로 다른 길이의 배치 시퀀스를 위한 것)를 함께 사용한다면, 이들 사이의 상호작용은 간단하지 않습니다. Flash Attention이 이를 처리해야 하지만, 인과적 마스크와 개별 패딩이 모두 포함된 attn_mask를 전달할 때는 주의 깊은 구성이 필요합니다. 가장 안전한 방법은 causal=True로 두고 동일한 길이로 패딩하거나, 적절한 위치에 -inf가 포함된 전체 N x N 크기의 배치별 마스크를 사용하는 것입니다.
헤드 차원 (Head dimension) 제한. Flash Attention은 역사적으로 헤드 차원에 제약이 있었습니다. v1은 head_dim <= 128을 요구했습니다. v2는 이를 head_dim <= 256으로 늘렸습니다. v3는 최대 256까지 지원합니다. 모델이 head_dim=96 또는 head_dim=64를 사용한다면 문제없습니다. 만약 head_dim=512를 실험 중이라면 (드물지만 일부 Vision Transformer에서 볼 수 있음), Flash Attention은 해당 어텐션 연산을 가속화할 수 없습니다.
CUDA graph 호환성. Flash Attention은 타일 크기(tile size)에 따라 가변적인 양의 공유 메모리 (shared memory)를 사용하며, 이는 CUDA graph 캡처 (capture) 시 문제를 일으킬 수 있습니다. 만약 mode="reduce-overhead" 옵션과 함께 torch.compile을 사용 중이라면, Flash Attention 커널이 그래프 캡처를 방해하지 않는지 테스트하십시오. v2.8.x 버전에서 이 부분이 개선되었으나, 모든 PyTorch 버전에서 상호작용이 보장되는 것은 아닙니다.
AMD GPU 및 비-CUDA 백엔드. Flash Attention은 CUDA 커널입니다. 따라서 별도의 설정 없이 AMD ROCm에서 실행되지 않습니다. ROCm 생태계에는 triton 기반의 Flash Attention이라는 대안적인 구현체가 있지만, 이는 성능 특성이 다르며 즉시 교체 가능한 (drop-in replacement) 방식이 아닙니다. AMD GPU를 사용 중이라면 성능이 동일하다고 가정하기 전에 벤치마크를 수행하십시오.
SDPA의 자동 폴백 (fallback)이 문제를 숨길 수 있음. PyTorch의 SDPA는 Flash Attention 조건을 충족하지 못할 경우 조용히 기본 구현 (naive implementation)으로 폴백하기 때문에, 서로 다른 GPU 유형에서 의도치 않게 다른 커널이 사용되어도 이를 알아차리지 못할 수 있습니다. 재현 가능한 성능이 중요하다면 어떤 SDPA 백엔드가 선택되었는지 항상 로그를 남기십시오.
사용하지 말아야 할 때
다음과 같은 경우 Flash Attention은 잘못된 최적화입니다:
-
병목 현상이 Attention이 아닌 MLP 레이어에 있는 경우. 배치 크기(Batch size)가 1이고 시퀀스 길이(Sequence length)가 짧은(512 토큰 미만) 추론(Inference) 워크로드의 경우, Attention 연산은 전체 시간에서 차지하는 비중이 매우 작습니다. MLP 프로젝션(Projections)이 지배적입니다. 이 경우 Attention을 최적화하면 2
4배가 아닌 510%의 속도 향상만을 얻게 됩니다. 먼저 프로파일링(Profile)을 수행하십시오. -
CPU 추론을 사용하는 경우. Flash Attention은 CUDA 지원 GPU를 필요로 합니다. CPU는 완전히 다른 Attention 경로를 사용합니다.
-
정수 전용 Attention이 필요한 경우 (예: CPU/엣지 디바이스에서의 양자화된 KV 캐시). Flash Attention은 CUDA로 구현되어 있으며 FP16/BF16 데이터를 기대합니다. 양자화된 Attention 커널(MatMul-free LLM 등)은 다른 알고리즘을 사용합니다.
-
빠른 반복을 위해 작은 모델을 학습시키는 경우. 모델의 에포크(Epoch)당 소요 시간이 30초라면, Attention을 최적화해도 병목 현상을 해결할 수 없습니다. Flash Attention을 임포트하고 설정하는 오버헤드(크지는 않지만 0은 아님)가 낭비되는 노력이 될 수 있습니다.
-
시퀀스 길이가 매우 긴 경우 (100K+ 토큰). 매우 긴 시퀀스의 경우, SDPA(일반적인 길이에서의 Flash Attention)의 메모리 효율적 Attention은 여전히 타일링(Tiling)의 효과를 떨어뜨리는 HBM 패스(HBM pass)를 요구할 수 있습니다. 100K 토큰 이상에서는 단일 GPU의 SRAM 내부가 아닌 GPU 간에 샤딩(Shard)하는 Ring Attention / DeepSpeed Ulysses / Stripe Attention 방식이 더 적합합니다.
요약 (TL;DR)
- Flash Attention은 Q, K, V 행렬을 GPU SRAM에 들어갈 수 있는 블록 단위로 타일링(tiling)하며, HBM(High Bandwidth Memory)에 전체 N x N 어텐션 행렬을 생성하지 않고도 소프트맥스(softmax)를 온라인(online) 방식으로 계산합니다.
- v2.8.3.post1은 현재 안정적인 릴리스 버전입니다 (2026년 6월 기준). v2는 병렬성(parallelism)을 개선하고 길이 제한을 제거했습니다. v3는 H100 전용 WGMMA 명령어와 FP8 지원을 추가했습니다.
- 모델 아키텍처의 변경 없이도 정밀도 손실(precision loss) 없이 A100급 GPU에서는 2-4배, H100에서는 3-7배의 속도 향상을 얻을 수 있습니다.
- PyTorch의
F.scaled_dot_product_attention을 통해 자동으로 사용하거나,flash_attn패키지를 통해 직접 사용할 수 있습니다. - head_dim 제한(v2/v3에서 최대 256), CUDA graph 호환성, 그리고 성능 저하를 숨길 수 있는 SDPA 백엔드의 암묵적인 폴백(fallback) 현상을 주의해야 합니다.
- 병목 지점이 어텐션이 아니거나, CPU/AMD 환경을 사용 중이거나, GPU 간 샤딩(sharding)이 필요한 극단적인 시퀀스 길이를 다루는 경우에는 Flash Attention을 사용하지 마십시오.
다음 포스트: 샘플링 전략(sampling strategies)에 대한 실질적인 비교 — temperature, top-p, top-k, min-p, 그리고 실제 운영 시스템에서 어떤 것이 더 나은 출력 품질을 만들어내는지에 대해 다룹니다.
AI 자동 생성 콘텐츠
본 콘텐츠는 Dev.to AI tag의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.
원문 바로가기