Flash Attention: N × N 메모리 폭증 없는 정확한 어텐션 (Exact Attention)
요약
Flash Attention은 트랜스포머 모델의 어텐션 연산 시 발생하는 N×N 메모리 병목 현상을 해결하는 기술입니다. 타일링과 온라인 소프트맥스 기법을 통해 거대한 어텐션 행렬을 HBM에 쓰지 않고 SRAM 내에서 처리하여 연산 속도를 높이고 메모리 사용량을 줄입니다.
핵심 포인트
- 어텐션 연산은 계산량보다 메모리 대역폭에 제한을 받는 메모리 바운드 특성을 가짐
- N×N 크기의 어텐션 행렬을 HBM에 저장하지 않고 작은 블록 단위로 처리
- 타일링(Tiling) 기법을 통해 데이터를 SRAM 내에서 효율적으로 관리
- 온라인 소프트맥스(Online Softmax)를 사용하여 수치적 안정성을 유지하며 스트리밍 연산 수행
만약 당신이 트랜스포머 (Transformer)를 더 긴 컨텍스트로 밀어붙이려다 GPU 메모리 부족(Out of Memory) 현상을 목격한 적이 있다면, 당신은 어텐션 (Attention)의 진정한 병목 지점을 마주한 것입니다. 그것은 곱셈의 횟수가 아닙니다. 어텐션이 메모리에 쓰고자 하는 거대한 행렬입니다. Flash Attention은 정확히 동일한 답을 계산하면서도 그 행렬을 사라지게 만드는 기술입니다.
메모리가 어디로 가는가
셀프 어텐션 (Self-attention)은 한 줄로 요약됩니다: 쿼리 (Queries) Q, 키 (Keys) K, 값 (Values) V에 대해, 출력은 softmax(QKᵀ / √d) · V입니다. 모든 쿼리는 모든 키에 대해 점수를 매기고, 이 점수들은 softmax를 통해 가중치 (Weights)가 되며, 이 가중치들이 값들을 혼합합니다.
문제는 QKᵀ에 숨어 있습니다. 이는 모든 토큰 쌍에 대해 하나의 항목을 가지므로, 길이가 N인 시퀀스에 대해 그 형태는 N × N이 됩니다. 이는 이차적 (Quadratic)입니다. 토큰이 1,000개에서 2,000개로 늘어나면 점수 행렬은 두 배가 되는 것이 아니라 네 배가 됩니다. 8,192개 토큰의 경우 약 6,700만 개의 숫자, 즉 약 128 MB를 보게 되는데, 이는 어텐션 헤드 (Attention head)당, 레이어 (Layer)당 발생하는 수치입니다. 점수들은 메모리에 쓰여지고, softmax를 위해 다시 읽히고, 가중치로서 다시 쓰여지며, V와 곱해지기 위해 한 번 더 읽힙니다. 거대 모델의 경우 이러한 트래픽은 엄청나며, 모델 자체보다 더 빠르게 증가합니다.
데이터 이동이 실제 비용인 이유
여기가 대부분의 사람들이 건너뛰는 부분입니다. GPU에는 두 종류의 메모리가 있습니다. 온칩 SRAM (On-chip SRAM)은 단위당 수십 KB로 매우 작지만 믿을 수 없을 정도로 빠릅니다. HBM은 수십 기가바이트에 달하는 거대한 "GPU RAM" 풀이지만, 접근 속도는 대략 10배 정도 느립니다. 어텐션은 데이터를 옮길 때마다 수행하는 산술 연산이 상대적으로 적으며, 이는 어텐션이 수학 계산을 하기보다 HBM을 기다리는 데 대부분의 시간을 소비한다는 것을 의미합니다. 시스템 언어로 표현하자면, 어텐션은 메모리 바운드 (Memory-bound)입니다.
이 단 하나의 사실이 최적화 목표 전체를 바꿉니다. 만약 메모리 바운드라면, 곱셈을 줄이는 것은 거의 도움이 되지 않습니다. 도움이 되는 것은 데이터를 덜 이동하는 것입니다. 따라서 목표는 다음과 같습니다: 그 N × N 행렬을 느린 메모리에 절대 쓰지 않는 것.
타일링 (Tiling)과 스트리밍 softmax (Streaming softmax)
Flash Attention은 Q, K, V를 빠른 SRAM 내에 들어갈 수 있는 작은 블록(block)들로 나눕니다. 바깥쪽 루프에서는 쿼리(query) 블록들을 순회하고, 각 쿼리 블록에 대해 안쪽 루프에서 키/값(key/value) 블록들을 스트리밍(streaming)합니다. 어느 순간에도 칩 위에는 하나의 작은 쿼리 타일(tile), 하나의 작은 키/값 타일, 그리고 하나의 스코어(score) 블록만이 존재합니다. 전체 행렬은 블록 단위로 소비되고 버려집니다. 전체 행렬은 결코 조립되지 않습니다.
장애물은 소프트맥스 (softmax)입니다. 일반적으로 소프트맥스는 수치적 안정성 (numerical safety)을 위해 행의 최댓값을 빼고, 지수 함수를 취한 뒤, 전체 합으로 나누어야 하므로 한 번에 행 전체를 필요로 합니다. 만약 스코어가 한 번에 하나의 블록씩 도착한다면, 행 전체를 볼 수 없습니다. 해결책은 온라인 소프트맥스 (online softmax)입니다. 실행 중인 최댓값 $m$과 실행 중인 합 $l$을 유지하며, 각 블록이 도착할 때마다 이를 업데이트하는 방식입니다.
새로운 블록이 이전에 보았던 것보다 더 큰 최댓값을 나타내면, 이전의 모든 지수 값들은 더 작은 최댓값을 기준으로 계산되었으므로 현재 스케일이 잘못된 상태가 됩니다. 따라서 새로운 블록의 기여분을 더하기 전에, 실행 중인 합과 실행 중인 출력 누산기 (output accumulator)에 하나의 보정 계수인 $ ext{exp}(m_{ ext{old}} - m_{ ext{new}})$를 곱해줍니다. 분자와 분모에 동일한 보정이 적용되기 때문에, 최종 비율은 정확합니다. 블록당 한 번씩 적용되는 이 저렴한 재스케일링 (rescale)이 전체 알고리즘의 핵심입니다.
역전파 (backward pass) 과정도 동일한 처리를 거칩니다. 그래디언트 (gradient)를 위해 $N imes N$ 행렬을 저장하는 대신, Flash Attention은 아주 작은 행별 통계치만을 저장하고 역전파 중에 각 스코어 타일을 즉석에서 재계산 (recompute)합니다. 재계산에는 몇 번의 추가적인 곱셈이 소요되지만, 이 연산은 메모리 대역폭에 제한 (memory-bound)을 받으며 해당 곱셈 비용은 거의 무시할 수 있는 수준이기에 명백한 이득입니다. 학습 메모리는 선형 (linear) 상태를 유지합니다.
근사치가 아닌 정확한 방식 (Exact, not approximate)
이 점은 반복할 가치가 있습니다. Flash Attention은 희소 어텐션 (sparse attention)도 아니고, 저차원 근사 (low-rank shortcut)도 아니며, 손실이 발생하는 근사치 (lossy approximation)도 아닙니다. 이는 표준 어텐션 (standard attention)과 동일한 수학적 함수를 계산하되, 느린 메모리에 접근하는 순서만 다를 뿐입니다. 출력값은 단순한 방식 (naive version)과 부동 소수점 오차 수준까지 일치합니다. 저는 인터랙티브 페이지에 작은 수치 검증 기능을 넣어 두었습니다. 동일한 무작위 입력에 대해 단순한 전체 행렬 어텐션 (naive full-matrix attention)과 타일링된 온라인 소프트맥스 어텐션 (tiled online-softmax attention)을 비교했을 때, 가장 큰 차이는 약 1e-16 수준입니다.
따라서 메모리는 $O(N^2)$에서 $O(N)$으로 감소하고, HBM에 대한 읽기 및 쓰기 작업이 급감하며, 애초에 이 작업이 메모리 대역폭에 제한 (memory-bound)되어 있었기 때문에 실제 실행 시간 (wall-clock time)도 크게 단축됩니다. 연산량 (FLOPs)은 거의 변하지 않습니다.
이것이 중요한 이유
어텐션의 메모리 사용량을 선형 (linear)으로 만든 것은 컨텍스트 윈도우 (context window)가 수천 토큰에서 수만, 수십만 토큰으로 급증할 수 있었던 큰 이유입니다. 동일한 하드웨어에서 학습 비용이 저렴해졌고, 긴 프롬프트 추론 (long-prompt inference)이 실용화되었으며, 이 변화가 정확하고 즉시 적용 가능한 (drop-in) 방식이었기 때문에 거의 모든 곳에서 채택되었습니다. FlashAttention-2는 병렬성을 개선하여 처리량 (throughput)을 대략 두 배로 높였으며, FlashAttention-3는 비동기 실행 (asynchronous execution)과 저정밀도 경로 (low-precision paths)를 통해 최신 하드웨어를 목표로 하면서도 동일한 정확성 보장을 유지합니다.
실무에서 커널 (kernel)을 직접 작성하는 일은 거의 없습니다. PyTorch에서는 F.scaled_dot_product_attention이 이를 자동으로 호출합니다. Hugging Face에서는 attn_implementation="flash_attention_2"를 전달하기만 하면 됩니다. 플래그 하나로, 결과는 같으면서, 메모리는 훨씬 적게 사용합니다.
타일링 (tiling), 온라인 소프트맥스 (online softmax), 그리고 정확성 검증을 여기에서 직접 확인해 보세요: https://dev48v.infy.uk/ai/days/day25-flash-attention.html
AI 자동 생성 콘텐츠
본 콘텐츠는 Dev.to AI tag의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.
원문 바로가기