
Blackwell의 NVIDIA NVFP4, MaxText에서 JAX 학습 속도 1.8배 향상
요약
NVIDIA Blackwell GPU의 NVFP4 4비트 포맷이 JAX 기반 MaxText 환경에서 FP8 대비 최대 1.8배의 학습 속도 향상을 제공합니다. 이 기술은 전용 FP4 텐서 코어를 통해 메모리 점유율을 낮추면서도 70B 파라미터 규모까지 정확도 손실 없이 효율적인 학습을 지원합니다.
핵심 포인트
- NVFP4 도입으로 FP8 대비 학습 처리량 1.8배 향상
- Blackwell 전용 FP4 텐서 코어로 메모리 대역폭 병목 완화
- 70B 파라미터 모델까지 정확도 저하 없이 사용 가능
- Google MaxText 라이브러리에서 네이티브 지원 시작
NVIDIA의 Blackwell GPU에 탑재된 NVFP4 4비트 포맷은 JAX/MaxText 환경에서 FP8 대비 최대 1.8배의 학습 속도 향상을 제공합니다. NVIDIA는 최대 70B(700억) 파라미터 규모의 모델에 대해 FP8 대비 정확도 손실이 없다고 주장합니다.
주요 사실
- NVFP4는 Blackwell에서 FP8 대비 1.8배의 학습 속도 향상을 제공합니다.
- 이 포맷은 두 개의 4비트 값을 단일 8비트 레지스터에 패킹합니다.
- 최대 70B 파라미터 모델에 대해 정확도 손실이 없다고 주장되었습니다.
- MaxText는 이제 네이티브 FP4 지원을 포함합니다.
- Blackwell에는 H100에는 없는 전용 FP4 텐서 코어 (Tensor Cores)가 있습니다.
NVIDIA는 Blackwell GPU를 위한 4비트 부동 소수점 정밀도 포맷인 NVFP4를 발표했으며, 이는 JAX를 기반으로 구축된 Google의 MaxText LLM 학습 라이브러리에 통합되었습니다. NVIDIA 기술 블로그에 따르면, 이 포맷은 두 개의 4비트 값을 단일 8비트 레지스터에 패킹하여, 공유 지수 (Shared Exponent) 방식을 통해 동적 범위 (Dynamic Range)를 유지하면서 FP8 대비 산술 밀도 (Arithmetic Density)를 효과적으로 두 배로 높입니다. NVIDIA는 GPT-3 175B 모델 학습 실행에서 NVFP4를 벤치마킹하여 1.8배의 처리량 (Throughput) 개선을 달성했으며, 최대 70B 파라미터 모델까지 정확도 저하가 보고되지 않았습니다. 회사는 더 큰 규모의 모델에 대한 결과나 전체 어블레이션 테이블 (Ablation Tables)은 공개하지 않았습니다.
왜 지금 FP4가 중요한가
모델 크기가 1조(trillion) 파라미터 임계값을 넘어서면서, NVIDIA가 저정밀도 학습 (lower-precision training)을 광범위하게 추진하는 시점과 맞물려 있습니다. Blackwell 아키텍처에는 Hopper (H100) GPU에는 없는 전용 FP4 텐서 코어 (tensor cores)가 포함되어 있습니다. 이는 메모리 대역폭 (memory bandwidth)이 병목 현상이 되는 사전 학습 (pre-training) 및 미세 조정 (fine-tuning) 워크로드에서 Blackwell에 구체적인 이점을 제공합니다. 즉, 파라미터당 메모리 점유율 (memory footprint)을 FP16 대비 2배, FP8 대비 1.5배 줄여줍니다. 175B 모델의 경우, FP4를 사용하면 FP16 대비 약 87 GB를 절약할 수 있으며, 이는 잠재적으로 더 큰 배치 크기 (batch sizes)를 가능하게 하거나 파이프라인 병렬성 (pipeline parallelism)을 줄일 수 있음을 의미합니다.
JAX 생태계 관점
Google의 오픈 소스 LLM 학습 라이브러리인 MaxText가 이제 NVFP4를 네이티브로 지원합니다. 이는 MaxText가 Google DeepMind의 Gemini 모델을 위한 주요 학습 프레임워크라는 점에서 주목할 만합니다. [Google의 관계 그래프에 따르면] Google은 NVIDIA의 주요 고객인 동시에 TPU를 통해 AI 하드웨어 분야에서 경쟁자이기도 합니다. NVIDIA는 NVFP4를 MaxText에 내장함으로써, Google의 내부 학습 스택과 MaxText를 사용하는 모든 외부 사용자가 별도의 커스텀 커널 (custom kernel) 개발 없이도 Blackwell의 저정밀도를 즉시 활용할 수 있도록 보장합니다. 블로그에 따르면 이 통합은 순전파 (forward pass)와 역전파 (backward pass)를 모두 포함합니다.
정밀도 주장에 대한 검증
"FP8 대비 정확도 손실 없음"이라는 NVIDIA의 주장은 회의적인 시각이 필요합니다. 이 회사는 최대 70B 파라미터 규모의 모델로 테스트를 수행했지만, 퍼플렉시티 (perplexity) 점수, 다운스트림 태스크 (downstream task) 평가 또는 수렴 곡선 (convergence curves)을 공개하지 않았습니다. 비교를 하자면, FP8 학습은 안정성을 유지하기 위해 종종 손실 스케일링 (loss scaling)과 그래디언트 클리핑 (gradient clipping)을 필요로 하며, FP4는 양자화 노이즈 (quantization noise)를 가중시킵니다. 특히 100B 이상의 모델에 대한 독립적인 재현 없이는, FP4가 특정 워크로드(예: 미세 조정)에는 허용 가능한 수준일 수 있지만 다른 워크로드(예: 처음부터 시작하는 사전 학습)에는 적합하지 않을 수 있는 어느 정도의 성능 저하를 초래할 것이라고 가정하는 것이 안전합니다.
주목해야 할 점
175B 규모에서 FP4 정확도(accuracy)를 독립적으로 재현하는 것이 중요하며, 이상적으로는 Google DeepMind가 Blackwell 클러스터에서 MaxText를 사용하여 이를 수행해야 합니다. 또한 PyTorch에서의 FP4 지원 여부와 AMD의 MI400 시리즈가 자체적인 4비트 포맷으로 대응할지도 지켜봐야 합니다.
출처: news.google.com
원문 게시지: gentic.news
AI 자동 생성 콘텐츠
본 콘텐츠는 Dev.to AI tag의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.
원문 바로가기