본문으로 건너뛰기

© 2026 Molayo

Dev.to헤드라인2026. 05. 23. 22:26

Multi-Head Latent Attention (MLA)

요약

DeepSeek-V2/V3 및 Kimi K2.x 모델에 적용된 Multi-Head Latent Attention(MLA) 메커니즘을 설명합니다. 저차원 잠재 공간 투영을 통해 KV 캐시를 획기적으로 압축하여 추론 효율성을 극대화하는 기술적 원리를 다룹니다.

핵심 포인트

  • 저차원 잠재 투영을 통한 KV 캐시 압축 기술
  • 표준 MHA 대비 5~10배 이상의 캐시 압축 달성
  • DeepSeek-V3 기준 최대 64배의 압축률 구현
  • 프리픽스 캐싱 및 페이지드 어텐션 구현 방식의 변화 유도

저차원 투영 (Low-rank projections)을 통한 KV 캐시 (KV cache) 압축 — DeepSeek-V2/V3 및 Kimi K2.x의 기반이 되는 어텐션 메커니즘.

왜 이것이 중요한가
Multi-Head Latent Attention (MLA)는 DeepSeek-V2, DeepSeek-V3, 그리고 Kimi K2.x 모델에서 표준 Multi-Head Attention (MHA)을 대체하는 어텐션 변형 방식입니다. 헤드당 전체 KV 쌍을 캐싱하는 대신, MLA는 이를 저차원 잠재 공간 (low-dimensional latent space)으로 투영하여 품질 저하를 최소화하면서도 5~10배의 KV 캐시 (KV cache) 압축을 달성합니다. MLA는 프리픽스 캐싱 (prefix caching), 청크드 프리필 (chunked prefill), 그리고 페이지드 어텐션 (paged attention)이 구현되어야 하는 방식을 변화시킵니다.

공식 정의

표준 Multi-Head Attention (MHA)
입력 $\mathbf{X} \in \mathbb{R}^{n \times d}$ 에 대해, MHA는 헤드별 투영을 계산합니다:
$\mathbf{Q}_h = \mathbf{X} \mathbf{W}_Q^{(h)}, \quad \mathbf{K}_h = \mathbf{X} \mathbf{W}_K^{(h)}, \quad \mathbf{V}_h = \mathbf{X} \mathbf{W}_V^{(h)}$
여기서 $\mathbf{W}_Q^{(h)} \in \mathbb{R}^{d \times d_k}, \quad \mathbf{W}_K^{(h)} \in \mathbb{R}^{d \times d_k}, \quad \mathbf{W}_V^{(h)} \in \mathbb{R}^{d \times d_v}$ 입니다.
토큰당 KV 캐시 (KV cache) 크기: $2 \times n_h \times d_k$ 개의 요소.

MLA: 저차원 잠재 투영 (Low-Rank Latent Projection)
MLA는 헤드별 KV 투영을 공유된 저차원 잠재 압축 (shared low-rank latent compression)으로 대체합니다:
압축 (KV → Latent): $\mathbf{c}^{KV} = \mathbf{X} \mathbf{W}{DKV} \in \mathbb{R}^{n \times d_c}$
여기서 $\mathbf{W}
{DKV} \in \mathbb{R}^{d \times d_c}$ 는 다운 투영 행렬 (down-projection matrix)이며, $d_c \ll n_h \times d_k$ 입니다.

압축 해제 (Decompression) (Latent → KV): $\mathbf{K}h = \mathbf{c}^{KV} \mathbf{W}{UK}^{(h)}, \quad \mathbf{V}h = \mathbf{c}^{KV} \mathbf{W}{UV}^{(h)}$ 여기서 $\mathbf{W}{UK}^{(h)} \in \mathbb{R}^{d_c \times d_k}$ 및 $\mathbf{W}{UV}^{(h)} \in \mathbb{R}^{d_c \times d_v}$는 업 프로젝션 행렬 (up-projection matrices)입니다.

토큰당 KV 캐시 (KV cache per token): 오직 $\mathbf{c}^{KV} \in \mathbb{R}^{d_c}$만이 저장됩니다 — 즉, 차원이 $d_c$인 단일 벡터입니다.

압축률 (Compression Ratio) $n_h$개의 헤드와 헤드 차원 $d_k$를 가진 모델의 경우:
$\text{Compression Ratio} = \frac{2 \cdot n_h \cdot d_k}{d_c}$

DeepSeek-V3의 경우: $n_h = 128, d_k = 128, d_c = 512$ 이므로:
$\frac{2 \times 128 \times 128}{512} = 64 \times$ 압축

쿼리 압축 (Query Compression, 선택 사항) MLA는 학습 효율성을 위해 쿼리 (queries) 또한 압축합니다:
$\mathbf{c}^Q = \mathbf{X} \mathbf{W}_{DQ} \in \mathbb{R}^{n \times d_c'}$
$\mathbf{Q}h = \mathbf{c}^Q \mathbf{W}{UQ}^{(h)}$
이는 KV 캐시에 영향을 주지 않지만, 학습 중 활성화 메모리 (activation memory)를 줄여줍니다.

회전 위치 임베딩 (Rotary Position Embedding, RoPE) 처리
RoPE는 압축 해제된 쿼리 (queries)와 키 (keys)에 적용됩니다. KV 캐시를 작게 유지하기 위해, MLA는 별도의 "흡수된" (absorbed) 키 프로젝션에 RoPE를 적용합니다:
$\hat{\mathbf{K}}h = \text{RoPE}(\mathbf{c}^{KV} \mathbf{W}{KR}^{(h)})$
여기서 $\mathbf{W}_{KR}^{(h)} \in \mathbb{R}^{d_c \times d_r}$ (단, $d_r \ll d_k$)는 위치 정보를 전달하는 좁은 프로젝션 (narrow projection)입니다.

캐시된 표현(cached representation)은 $\mathbf{c}^{KV}$ (위치 불가지론적 (position-agnostic))로 유지되며, RoPE 키 $\hat{\mathbf{K}}_h$는 어텐션 시점에 캐시된 잠재 표현(latent)으로부터 재계산됩니다.

핵심 개념

  1. 가중치 흡수 (Weight Absorption, 핵심 트릭)
    MLA의 결정적인 통찰은 업-프로젝션 행렬(up-projection matrices) $\mathbf{W}_{UK}^{(h)}$가 어텐션 계산 중에 쿼리 프로젝션(query projection)으로 흡수될 수 있다는 점입니다:

$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{d_k}}\right) \mathbf{V}_h$

압축 해제된 형태(decompressed forms)를 대입하면:

$\mathbf{Q}h \mathbf{K}h^T = (\mathbf{c}^Q \mathbf{W}{UQ}^{(h)})(\mathbf{c}^{KV} \mathbf{W}{UK}^{(h)})^T = \mathbf{c}^Q (\mathbf{W}{UQ}^{(h)} {\mathbf{W}{UK}^{(h)}}^T) {\mathbf{c}^{KV}}^T$

만약 $\mathbf{W}{absorbed}^{(h)} = \mathbf{W}{UQ}^{(h)} {\mathbf{W}_{UK}^{(h)}}^T \in \mathbb{R}^{d_c' \times d_c}$라고 정의한다면:

$\mathbf{Q}_h \mathbf{K}h^T = \mathbf{c}^Q \mathbf{W}{absorbed}^{(h)} {\mathbf{c}^{KV}}^T$

이는 어텐션 스코어(attention score)를 잠재 표현(latent representations)으로부터 직접 계산할 수 있음을 의미하며, 스코어 계산을 위해 $\mathbf{K}$와 $\mathbf{V}$를 명시적으로 압축 해제(decompression)할 필요가 없음을 뜻합니다. 다만, 출력을 위해서는 여전히 $\mathbf{V}$의 압축 해제가 필요합니다.

실질적인 시사점: 디코딩(decoding) 중에 전체 $\mathbf{K}$ 행렬을 실체화(materializing)하지 않고도 어텐션 스코어를 계산할 수 있습니다. 소프트맥스(softmax) 이후에는 $\mathbf{V}$만 압축 해제하면 됩니다.

  1. 분리된 RoPE 전략 (Decoupled RoPE Strategy)
    RoPE는 위치 의존적인 키(position-dependent keys)를 요구하는데, 이는 위치 불가지론적인 잠재 표현을 캐싱하는 것과 충돌합니다.

MLA는 이를 분리된 키(decoupled key)를 통해 해결합니다:

  • 콘텐츠 키 (Content key): $\mathbf{K}h^{content} = \mathbf{c}^{KV} \mathbf{W}{UK}^{(h)}$ — 잠재 형태(latent form)로 캐싱됨
  • 위치 키 (Position key): $\mathbf{K}h^{rope} = \text{RoPE}(\mathbf{c}^{KV} \mathbf{W}{KR}^{(h)})$ — 크기가 작고 위치를 인식하며, 별도로 캐싱되어야 함

어텐션 점수(attention score)는 다음과 같이 계산됩니다:
$$\text{score}(q, k) = \frac{\mathbf{Q}_h^{content} \cdot {\mathbf{K}_h^{content}}^T}{\sqrt{d_k}} + \frac{\mathbf{Q}_h^{rope} \cdot {\mathbf{K}_h^{rope}}^T}{\sqrt{d_r}}$$

실질적인 영향: KV 캐시(KV cache)는 $\mathbf{c}^{KV}$ (잠재 표현)와 $\mathbf{K}_h^{rope}$ (분리된 RoPE 키)를 모두 저장합니다. 토큰당 총 캐시 크기: $d_c + n_h \times d_r$.

  1. MLA vs GQA vs MHA
속성MHAGQAMLA
KV 그룹 수$n_h$$n_h / g$1 (잠재 표현)
토큰당 캐시$2n_h d_k$$2(n_h/g)d_k$$d_c + n_h d_r$
품질기준점 (Baseline)약간의 하락유사함
어텐션 점수$\mathbf{Q} \mathbf{K}^T$$\mathbf{Q} \mathbf{K}^T$ (공유 K)$\mathbf{Q} \mathbf{K}^T$ (잠재 표현)
RoPE 호환성네이티브 (Native)네이티브 (Native)분리됨 (Decoupled)

GQA는 쿼리 그룹 간에 KV 헤드를 공유함으로써 캐시를 줄입니다. MLA는 공유된 잠재 표현으로 투영함으로써 캐시를 더욱 공격적으로 줄입니다. 품질 차이는 미미한데, 이는 업-프로젝션(up-projection) 행렬이 학습되어 헤드별 정보를 재구성할 수 있기 때문입니다.

  1. 배치 서빙(Batched Serving)에 미치는 영향

MLA는 서빙 시 메모리 대 연산(memory-vs-compute) 트레이드오프를 극적으로 변화시킵니다:

  • 메모리 제한적 디코딩 단계 (Memory-bound decoding phase): MHA의 경우, 긴 컨텍스트는 KV 캐시로 인해 GPU HBM을 고갈시킵니다.

MLA의 압축을 통해 다음과 같은 이점을 얻을 수 있습니다:

  • 더 긴 컨텍스트 윈도우 (동일 메모리 내에서 10배 더 많은 토큰 처리)
  • 더 큰 배치 크기 (더 많은 동시 요청 처리)
  • 더 나은 프리픽스 캐싱 (Prefix Caching) 히트율 (더 작은 캐시 엔트리)

연산 집약적 프리필 단계 (Compute-bound prefill phase): MLA는 압축 해제 (Decompression) 오버헤드를 추가하지만, 이는 다음과 같이 분산(Amortized)됩니다:

  • 프리필은 이미 연산 집약적입니다 ($O(n^2)$ attention)
  • 업-프로젝션 (Up-projection)을 위한 추가적인 행렬 곱셈 (Matmuls)은 레이어당 $O(n \times d_c \times d_v)$입니다.
  • 최종 효과: 미미한 프리필 속도 저하, 압도적인 디코딩 속도 향상
  1. MLA + 투기적 디코딩 (Speculative Decoding)
    이 부분은 Siraj의 EAGLE-3 작업과 관련하여 매우 흥미로운 지점입니다:
  • 초안 모델 (Draft model) 제약 사항: 초안 모델은 타겟 모델의 MLA 프로젝션 (Projections)과 호환되는 잠재 KV 상태 (Latent KV states)를 생성해야 합니다.
  • 단순히 더 작은 MHA 모델을 초안 모델로 사용하는 것은 KV 형식 불일치를 발생시킵니다.
  • EAGLE-3의 트리 기반 투기 (Tree-based speculation)는 잠재(Latent) $\rightarrow$ 압축 해제(Decompressed) $\rightarrow$ 검증(Verify) $\rightarrow$ 잠재(Latent)로 이어지는 라운드트립을 처리해야 합니다.

MLA를 이용한 검증:

  • 초안 토큰은 초안 모델에 의해 생성됩니다.
  • 타겟 모델은 전체 MLA 어텐션(잠재 벡터 압축 해제, 어텐션 계산)을 실행하여 검증합니다.
  • 수락된 토큰의 KV 엔트리는 전체 KV 캐시가 아닌 잠재 캐시 ($c^{KV}$)에 추가되어야 합니다.
  • 이는 초안 모델이 다음 중 하나를 수행해야 함을 의미합니다: (a) 잠재 공간 (Latent space)에서 예측하거나, (b) KV 출력을 잠재 공간으로 프로젝션해야 합니다.

vLLM 구현 과제: vLLM의 PagedAttention은 MHA/GQA를 위해 설계되었습니다. MLA는 다음을 요구합니다:

  • KV 쌍 대신 잠재 벡터 ($d_c$)를 저장하는 수정된 페이지 테이블 (Page table)
  • 흡수된 (Absorbed) + 분리된 (Decoupled) RoPE 계산을 위한 커스텀 어텐션 커널 (Custom attention kernels)
  • 압축 해제 경로를 위한 CUDA Graph 캡처와의 통합

구현

import torch
import torch.nn as nn
import math

class MultiHeadLatentAttention(nn.Module):
    """ DeepSeek-V2/V3 및 Kimi K2.x 아키텍처와 일치하는 MLA 어텐션 레이어. """

주요 특징:
- 저차원 KV 압축 (c_KV 잠재 벡터(latent vector)만 캐시로 저장)
- 위치 인지 어텐션(position-aware attention)을 위한 분리된 RoPE (Decoupled RoPE)
- 효율적인 점수 계산을 위한 가중치 흡수 (Weight absorption)
""" def __init__ ( self , d_model : int = 4096 , n_heads : int = 128 , d_k : int = 128 , d_v : int = 128 , d_c : int = 512 , # KV 잠재 차원 (압축 대상) d_c_prime : int = 1536 , # Query 잠재 차원 d_r : int = 64 , # 헤드당 분리된 RoPE 키 차원 max_seq_len : int = 8192 , rope_base : float = 10000.0 , ): super (). __init__ () self . d_model = d_model self . n_heads = n_heads self . d_k = d_k self . d_v = d_v self . d_c = d_c self . d_c_prime = d_c_prime self . d_r = d_r # === 다운 프로젝션 (Down-projections, 압축) === self . w_dkv = nn . Linear ( d_model , d_c , bias = False ) # KV 잠재 벡터 self . w_dq = nn . Linear ( d_model , d_c_prime , bias = False ) # Q 잠재 벡터 # === 업 프로젝션 (Up-projections, 압축 해제) === # KV 업 프로젝션: 잠재 벡터 -> 헤드별 K 및 V self . w_uk = nn . Linear ( d_c , n_heads * d_k , bias = False ) self . w_uv = nn . Linear ( d_c , n_heads * d_v , bias = False ) # Q 업 프로젝션: 잠재 벡터 -> 헤드별 Q self . w_uq = nn . Linear ( d_c_prime , n_heads * d_k , bias = False ) # === 분리된 RoPE 프로젝션 (Decoupled RoPE projections) === self . w_kr = nn . Linear ( d_c , n_heads * d_r , bias = False ) # 잠재 벡터로부터의 RoPE 키 self . w_qr = nn . Linear ( d_c_prime , n_heads * d_r , bias = False ) # 잠재 벡터로부터의 RoPE 쿼리 # === 출력 프로젝션 (Output projection) === self . w_o = nn . Linear ( n_heads * d_v , d_model , bias = False ) # RoPE 주파수 inv_freq = 1.0 / ( rope_base ** ( torch . arange ( 0 , d_r , 2 ). float () / d_r )) self . register_buffer ( ' inv_freq ' , inv_freq ) def _apply_rope ( self , x : torch . Tensor , seq_len : int ) -> torch . Tensor : """ [batch, seq, n_heads, d_r] 형상의 텐서에 회전 위치 임베딩 (Rotary Position Embedding)을 적용합니다. """ t = torch . arange ( seq_len , device = x . device , dtype = self . inv_freq . dtype ) freqs = torch . outer ( t , self . inv_freq ) # [seq, d_r//2] cos = freqs .

cos().unsqueeze(0).unsqueeze(2) # [1, seq, 1, d_r//2] sin = freqs.sin().unsqueeze(0).unsqueeze(2)
x1, x2 = x[..., ::2], x[..., 1::2]
rotated = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).
flatten(-2)
return rotated
def forward(self, x: torch.Tensor, kv_cache: torch.Tensor = None, start_pos: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, d_model]
kv_cache: Cached c_KV from previous tokens [batch, cache_len, d_c]
start_pos: Position offset for RoPE
Returns:
output: [batch, seq_len, d_model]
new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c]
"""
B, S, _ = x.shape # === Step 1: Compress to latent space ===
c_kv = self.w_dkv(x) # [B, S, d_c] — THIS is what gets cached
c_q = self.w_dq(x) # [B, S, d_c'] # === Step 2: Decompress for attention computation === # K, V up-projection from latent
k_content = self.w_uk(c_kv) # [B, S, n_heads * d_k]
v = self.w_uv(c_kv) # [B, S, n_heads * d_v]
q_content = self.w_uq(c_q) # [B, S, n_heads * d_k] # Reshape to multi-head format
q_content = q_content.view(B, S, self.n_heads, self.d_k)
k_content = k_content.view(B, S, self.n_heads, self.d_k)
v = v.view(B, S, self.n_heads, self.d_v)
# === Step 3: Decoupled RoPE === # Project to rope-specific dimensions and ap

AI 자동 생성 콘텐츠

본 콘텐츠는 Dev.to AI tag의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.

원문 바로가기
0

댓글

0