DASH: 단일 GPU에서 몇 분 만에 수행하는 하이브리드 어텐션(Hybrid Attention)을 위한 빠른 미분 가능한 아키텍처 탐색
요약
DASH는 하이브리드 어텐션 아키텍처 설계를 위해 제안된 빠르고 미분 가능한 탐색 프레임워크입니다. 이산적인 레이어별 연산자 배치를 연속적인 아키텍처 로짓으로 완화하여 탐색 효율성을 극대화했으며, 단일 GPU에서 단 20분 만에 최적의 아키텍처를 찾아낼 수 있습니다. 기존 Jet-Nemotron 방식 대비 훨씬 적은 토큰을 사용하면서도 더 강력한 성능을 입증했습니다.
핵심 포인트
- 이산적 연산자 배치를 연속적인 아키텍처 로짓으로 완화하여 미분 가능한 탐색 가능
- 모델 및 연산자 가중치를 고정한 상태에서 아키텍처 전용 탐색을 수행하여 효율성 증대
- Jet-Nemotron 대비 약 0.006% 수준의 토큰만 사용하여 탐색 시간 및 비용 획기적 단축
- Qwen2.5-3B-Instruct 기반 실험에서 기존 베이스라인 및 Jet-Nemotron보다 우수한 성능 달성
하이브리드 어텐션 (Hybrid attention) 아키텍처는 모델 품질을 유지하면서 LLM 추론 효율성을 향상시키기 위한 점점 더 중요한 패러다임이 되고 있으며, 이에 따라 하이브리드 아키텍처 설계가 핵심적인 문제로 부상하고 있습니다. 기존의 설계 방식은 레이어별 연산자 할당을 위해 수동적인 경험적 규칙이나 프록시 기반의 선택기 신호 (proxy-based selector signals)에 의존하는 경우가 많습니다. Jet-Nemotron과 같은 최근의 NAS 스타일 시스템은 자동화된 하이브리드 아키텍처 탐색의 가능성을 보여주었습니다. 그러나 Jet-Nemotron의 PostNAS 탐색 단계만으로도 200B 개의 토큰을 사용하며, 이는 이러한 탐색 파이프라인을 하이브리드 아키텍처 설계를 위한 일상적인 방법으로 사용하기 어렵게 만듭니다. 우리는 하이브리드 어텐션 아키텍처 설계를 위한 빠른 미분 가능한 탐색 프레임워크인 DASH를 소개합니다. DASH는 이산적인 레이어별 어텐션 연산자 배치를 연속적인 아키텍처 로짓 (architecture logits)으로 완화하고, 재사용 가능한 교사 정렬 선형 후보군 (teacher-aligned linear candidates)을 준비하며, 모델 및 연산자 가중치를 고정한 상태에서 아키텍처 전용 탐색을 수행하여 탐색 효율성을 크게 향상시킵니다. Qwen2.5-3B-Instruct에서 DASH는 기존의 포괄적인 선택기 스타일 하이브리드 어텐션 설계 베이스라인들을 지속적으로 능가하며, 직접적인 미분 가능 탐색이 더 강력한 하이브리드 아키텍처를 발견할 수 있음을 보여줍니다. 또한, DASH는 공개된 Jet-Nemotron 모델보다 더 강력한 RULER 성능을 달성하는 동시에, 중첩되는 짧은 컨텍스트 및 일반 벤치마크에서도 경쟁력을 유지합니다. 특히, 각 DASH 탐색 실행은 단 12.3M 개의 토큰만을 사용하며 단일 RTX Pro 6000 GPU에서 약 20분이 소요되는데, 이는 Jet-Nemotron에서 보고된 PostNAS 탐색 토큰의 단 0.006%에 해당합니다. 이러한 결과는 고품질의 하이브리드 어텐션 아키텍처를 몇 분 단위의 미분 가능한 탐색을 통해 얻을 수 있음을 시사하며, 하이브리드 아키텍처 설계에 있어 유망한 방향을 제시합니다.
AI 자동 생성 콘텐츠
본 콘텐츠는 arXiv cs.LG의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.
원문 바로가기