Hub Bucket을 통한 1조 개의 파라미터 전송: TRL의 Delta Weight Sync
요약
비동기 강화학습(Async RL) 시 모델 전체를 전송하는 대신, 변경된 가중치(delta)만 전송하여 대역폭을 획기적으로 절감하는 기술을 소개합니다. TRL의 Delta Weight Sync를 통해 Hugging Face Bucket을 활용한 분리형 훈련(disaggregated training)이 가능해집니다.
핵심 포인트
- 연속된 RL 단계 간 가중치 변화는 전체의 2% 미만임
- Sparse safetensors를 통해 페이로드 크기를 대폭 감소
- 공유 오브젝트 스토어를 활용한 분리형 훈련 구조 구현
- 훈련과 추론 클러스터 간 물리적 거리 제약 해소
요약(TL;DR), 여러분은 모델을 훈련시켜야 하고 저희는 그 점을 존중하기에 빠르게 전달합니다:
- 비동기 강화학습 (Async RL)에는 숨겨진 비밀이 있습니다. 매 단계마다 트레이너 (trainer)가 모델 전체를 추론 엔진 (inference engine)으로 전송해야 한다는 점입니다. bf16 형식의 7B 모델의 경우 14 GB이며, 프런티어 (frontier) 1T 모델 체크포인트의 경우 매 단계마다 테라바이트 단위에 달합니다.
- 알고 보니 그럴 필요가 없습니다. 연속된 두 번의 RL 옵티마이저 (optimizer) 단계 사이에는 bf16 가중치 (weights)의 약 99%가 비트 단위로 동일하며 (최악의 경우에도 98% 미만으로 떨어지지 않습니다), 실제 차이 (delta)는 매우 작습니다.
- 저희는 변경된 요소만을 희소 safetensors (sparse safetensors) 파일로 인코딩하여 Hugging Face Bucket에 업로드하고, vLLM이 이를 가져오도록 지시하는 TRL PR을 반영했습니다. Qwen3-0.6B 모델의 경우, 단계당 페이로드 (payload)가 1.2 GB에서 20~35 MB로 급감했습니다.
- 금상첨화로, 저희는 트레이너는 한 서버에, vLLM은 Hugging Face Space에, Wordle 환경은 다른 Space에 위치하며 가중치가 단일 Hub bucket을 통해 흐르는 완전한 분리형 훈련 (disaggregated training)을 실행했습니다. 공유 클러스터도, RDMA도
Fireworks는 그들의 포스트 "Frontier RL Is Cheaper Than You Think"에서 이와 관련하여 매우 기억에 남는 수치를 제시했습니다. 그들의 설정 기준, fp8 형식의 프론티어(Frontier) 1T-파라미터 체크포인트의 경우, 전체 스냅샷(full snapshot)은 1024 GiB에 달하며, 통념에 따르면 롤아웃 플릿(rollout fleet)을 업데이트할 때마다 이 전체를 전송해야 합니다. 이는 사람들이 메가 클러스터(mega-clusters), RDMA 패브릭(RDMA fabrics), 그리고 전용 교차 지역 링크(dedicated cross-region links)를 포함한 다이어그램을 그리기 시작하게 만드는 수준의 숫자입니다. 그들이 측정한 인접 체크포인트 간의 평균 델타(delta) 값은 **20.3 GiB, 즉 전체 모델의 1.98%**였으며, "bf16 형식의 가중치 중 98% 이상이 연속된 체크포인트 간에 비트 단위로 동일(bit-equivalent)하게 유지됨"을 확인했습니다.
Cursor의 Composer 2 보고서도 이와 유사한 이야기를 들려줍니다. 그들은 훈련(training)과 추론(inference)을 서로 다른 지역에서 실행하며, 이를 공유 S3 버킷 (shared S3 bucket)(그들의 정확한 표현)을 통해 하나로 엮습니다. 트레이너(trainer)는 매 훈련 단계마다 압축된 가중치 차이(weight diffs)를 이 버킷에 업로드합니다. 각 클러스터는 공유된 델타 체인(delta chain)으로부터 독립적으로 다운로드하고 재구성하며, 이는 "훈련 클러스터와의 직접적인 연결이 필요하지 않음"을 의미합니다. 양측은 파라미터에 대해 직접적으로 소통하지 않습니다. 버킷이 곧 통신 선로(wire) 역할을 합니다.
두 논문은 세 가지 사항에 동의하며, 이 포스트의 나머지 내용은 본질적으로 이 내용들을 충실히 오픈 소스로 번역한 것이기에 우리는 이 점을 천천히 반복하고자 합니다:
- 인접한 두 RL 단계 사이에서 대부분의 가중치는 실제로 변하지 않았습니다.
- 변경된 부분만 전송한다면, 대역폭 비용(bandwidth bill)을 대략 두 자릿수(two orders of magnitude) 정도 절감할 수 있습니다.
- 이러한 미세한 차이(diffs)를 공유 오브젝트 스토어(shared object store)를 통해 라우팅한다면, 트레이너와 추론 클러스터가 반드시 동일한 데이터 센터에 존재할 필요가 없습니다.
유일하게 부족했던 점은 이 이야기를 pip install 할 수 있는 버전이었습니다.
그래서 우리가 직접 만들었습니다.
무언가를 연결하기 전에, 왜 이 게임 자체가 승산이 있는지 이해할 가치가 있습니다. "가중치의 98%가 변하지 않는다"는 주장은 데모에서는 작동하지만 실제 환경에서는 무너지는 그런 숫자 중 하나처럼 의심스럽게 들릴 수 있습니다. 하지만 그렇지 않습니다. 이는 RL(강화학습)이 사용하는 학습률(learning rates)에서 bf16 산술 연산이 작동하는 방식에서 비롯됩니다.
bf16 숫자는 7개의 가수(mantissa) 비트를 가집니다. 두 개의 연속된 2의 거듭제곱 사이에는 정확히 $2^7 = 128$개의 표현 가능한 값이 존재하므로, $|w|$ 주변의 인접한 bf16 숫자 사이의 간격은 대략 $|w| imes 2^{-7}$입니다. 업데이트 값이 그 간격의 절반 미만일 때, 즉 $|\Delta w| < |w|/256$일 때, 업데이트는 bf16 캐스팅(cast) 과정에서 흡수되어 버립니다. 이것이 PULSE 연구진이 그들의 Figure 3에서 보여준 "bf16 가시성 임계값(bf16 visibility threshold)"입니다.
이제 Adam 옵티마이저가 무엇을 하는지 살펴보겠습니다. 예를 들어 RL 학습률이 $3 imes 10^{-6}$일 때, 단일 가중치에 대한 업데이트는 다음과 같습니다:
정규화된 단계(normalized step) $\hat{m}/(\sqrt{\hat{v}}+\epsilon)$는 대략 1의 차수이므로, $|\Delta w| \approx \eta \approx 3 \times 10^{-6}$입니다. 대부분의 가중치에 대해 $|w|$는 $10^{-2}$에서 $10^{-1}$ 사이 어딘가에 위치합니다 (PULSE는 대표적인 LLM 가중치에 대해 중앙값으로 0.019를 보고합니다). 해당 크기에서 임계값 $|w|/256$은 약 $4 imes 10^{-5}$에서 $4 imes 10^{-4}$ 정도이며, 이는 업데이트 값보다 더 큽니다.
다시 말해, 옵티마이저는 속삭이고 있지만 bf16은 이를 듣지 못하는 것입니다. 업데이트는 반올림(rounding)에 의해 흡수되고, $w$의 바이트 표현은 변하지 않으며, 추론 엔진(inference engine)의 관점에서는 이 가중치가 움직이지 않은 것이 됩니다. 이를 수억 개의 파라미터에 곱하면, 어떠한 근사(approximation) 없이도 공짜로 99% 이상의 희소성(sparsity) 수치를 얻게 됩니다.
이것은 바로 PULSE 논문(Mihai & Belilovsky, 2026)에서 공식화된 논거입니다. 그들은 두 가지 임계값(threshold)을 정의합니다. 흡수 경계 (absorption bound) $10 ext{η}$는 Adam 업데이트의 보수적인 최악의 경우(worst case)이며, 유효 경계 (effective bound) $ ext{η}$는 실제로 적용되는 영역입니다. **bf16 가시성 임계값 (bf16 visibility threshold)**은 $|w|/256$입니다. 업데이트가 가시성 임계값보다 낮을 때마다 해당 값은 흡수되며, bf16 바이트(byte)는 변경되지 않습니다. 그들의 Figure 3는 대표적인 LLM 가중치(weights) 구름에 대해 두 경계를 모두 플롯하며, 결론은 명확합니다: $ ext{η} = 3 imes 10^{-6}$에서 흡수 경계 자체가 모델의 거의 모든 가중치에 대해 이미 가시성 임계값 아래에 위치합니다. 그들은 Qwen2.5 (0.5B/1.5B/7B), Llama-3.2-3B, 그리고 Gemma-3-4B에 대해 이를 경험적으로 측정하였으며, **400회의 학습 단계(training steps) 동안 표준 편차가 0.2~0.4%인 약 99%의 평균 단계별 희소성 (per-step sparsity)**을 일관되게 발견했습니다. 최악의 경우의 단계에서도 98% 이상을 유지합니다. 따라서 1% 미만이 변경되었다는 것은 운 좋은 측정치가 아니라, 산술적으로 보장되는 결과입니다.
우리는 이를 분석적으로 예측할 필요가 없습니다 (실제로 Adam의 $m$ 및 $v$ 통계치로부터 변경 마스크(change mask)를 예측하려고 시도했으나, 재현율(recall)이 30%에 그쳤습니다. 이에 대해서는 나중에 더 자세히 다루겠습니다). 우리는 그저 어떤 바이트가 뒤집혔는지(flipped) 관찰하기만 하면 됩니다. 이는 옵티마이저 단계(optimizer step) 직후에 계산되는 파라미터당 아주 작은 불리언 텐서(boolean tensor)입니다.
여기서 이야기의 두 번째 조각이 등장하며, 이 포스트는 Fireworks/Cursor의 번역을 넘어 Hugging Face의 이야기로 넘어갑니다.
A **버킷 (Bucket)**은 고빈도 객체 스토리지(high-frequency object storage)를 위해 설계된 Hub의 리포지토리(repo) 유형입니다. 커밋 의식(commit ceremony)도, PR 워크플로우도, LFS의 기이한 동작도 없습니다. 파일을 추가하고, 목록을 나열하고, 다운로드하면 됩니다. Python 인터페이스는 두 개의 함수로 구성됩니다:
from huggingface_hub import batch_bucket_files, download_bucket_files
# Trainer 측
batch_bucket_files("my-org/wordle-deltas", add=[(buffer, "deltas/step_000042.safetensors")])
...
그게 전부입니다. 두 번의 함수 호출만으로 당신의 가중치(weights)가 전송되기 시작합니다.
내부적으로 버킷(buckets)은 Hub의 콘텐츠 정의 청킹(content-defined chunking) 저장 계층인 Xet에 의해 지원됩니다. Xet은 사용자가 업로드하는 모든 파일을 살펴보고, 고정된 오프셋(fixed offsets)이 아닌 실제 콘텐츠를 기반으로 파일을 청크(chunks)로 슬라이스하며, 버킷에 이미 존재하는 모든 데이터와 비교하여 중복을 제거(deduplicate)합니다. 이 맥락에서 매우 반가운 실질적인 결과는, 우리가 희소 인코딩(sparse encoding)을 작성하기 귀찮아서 매 단계마다 전체 앵커(full anchors)를 업로드하더라도, Xet은 여전히 변경된 청크만 전송한다는 점입니다. 희소 인코딩(Sparse encoding) + Xet 스택: 우리는 이동한 만큼만 비용을 지불하며, 그 비용은 단 한 번만 지불합니다.
이는 Fireworks와 Cursor가 지향하는 "공유 S3 버킷(shared S3 bucket)"의 오픈 소스 버전과 유사하지만, 저장 계층이 이미 콘텐츠 해싱(content hashing)을 알고 있고, 기존의 HF 토큰에 이미 권한이 있으며, 나머지 스택(Spaces, datasets, models)과 네이티브하게 결합된다는 점이 다릅니다.
전체 아키텍처는 정확히 세 개의 박스와 하나의 공유 기질(substrate)로 구성됩니다:
Trainer (트레이너). 어디에 있든 상관없습니다. GPU 한 개, GPU 여덟 개, USB로 연결된 H100이 달린 노트북 등 무엇이든 좋습니다. 우리는 판단하지 않습니다. 모델 가중치(weights)를 소유하고, 옵티마이저(optimizer)를 실행하며, 희소 델타(sparse deltas)를 방출합니다.
HF Bucket (HF 버킷). 단일 리포지토리(repo)이며, 두 개의 접두사(prefixes)를 가집니다: 가끔씩 수행하는 전체 스냅샷을 위한 anchors/와 그 사이의 희소 패치(sparse patches)를 위한 deltas/입니다. 이것이 양측이 합의하는 유일한 지점입니다.
vLLM rollout server (vLLM 롤아웃 서버). 어디에 있든 상관없으며, 결정적으로 반드시 트레이너가 있는 곳일 필요는 없습니다. 버킷에서 데이터를 가져와 델타를 적용하고 롤아웃(rollouts)을 서빙합니다.
Environment (환경). 통상적인 방식(HTTP, 함수 호출, 환경이 지원하는 무엇이든)으로 롤아웃 서버에 연결됩니다.
내재화해야 할 속성이자, Cursor의 논문이 강력하게 주장하며 여기서도 그대로 유지되는 속성은 다음과 같습니다: 트레이너와 롤아웃 서버는 가중치(weights)에 대해 서로 대화하지 않습니다. 그들은 {"repo_id": ..., "filename": ...}를 포함하는 아주 작은 POST 요청만을 교환하며, 그것이 제어 평면(control plane)의 전부입니다. 실제 바이트 전송은 각 측과 버킷 사이에서 병렬로 발생하며, 공유 네트워크 패브릭(network fabric)을 사용하지 않습니다.
이것이 실무에서 중요한 이유는 다음과 같습니다:
- 롤아웃 서버 (rollout server)는 다른 리전(region), 다른 클라우드, 또는 Hugging Face Space 내부의 NAT 뒤에 있을 수 있습니다. 시스템은 이를 신경 쓰지 않습니다.
- N개의 추론 복제본 (inference replicas)이 동일한 버킷에서 동일한 델타 (delta)를 가져올 수 있으며, Xet은 이들 전체에 걸쳐 바이트를 중복 제거 (deduplicate) 합니다.
- 트레이너 (trainer)는 추론 복제본이 몇 개인지, 어디에 있는지, 또는 그중 하나가 방금 충돌했는지 여부를 알 필요가 전혀 없습니다.
트레이너는 쓰고, 복제본은 읽습니다. 허브 (Hub)가 배관 (plumbing) 작업을 수행합니다.
이제 내부 구조를 살펴보겠습니다. 프로토콜은 네 부분으로 구성됩니다: 와이어 포맷 (wire format), 버킷 레이아웃 (bucket layout), 30줄짜리 vLLM 확장 (extension), 그리고 트레이너 측 변경 감지기 (change detector)입니다. 솔직히 말해서 들리는 것보다 코드 양이 더 적습니다.
우리는 디스크 및 와이어 포맷으로 safetensors를 선택했습니다. 이는 이미 허브 (Hub)의 표준 체크포인트 (checkpoint) 포맷이며, 모든 합리적인 프레임워크가 이를 읽을 수 있고, 헤더에 임의의 문자열 메타데이터 (metadata)를 담을 수 있습니다. 그 메타데이터 필드가 바로 우리가 프로토콜을 숨기는 곳입니다.
버킷에는 두 종류의 파일이 있습니다.
**앵커 (Anchors)**는 일반적인 체크포인트처럼 보입니다: 파라미터당 하나의 텐서 (tensor), 전체 bf16 가중치 (weights), $N$번의 동기화(sync)마다 작성됩니다 (기본값은 $N=10$입니다).
anchors/step_000010.safetensors
├── model.layers.0.self_attn.q_proj.weight (bf16, full)
├── model.layers.0.self_attn.k_proj.weight (bf16, full)
...
**델타 (Deltas)**가 흥미로운 부분입니다. 실제로 변경된 각 파라미터에 대해 두 개의 항목을 저장합니다: 요소 인덱스 (element indices)를 담은 평탄한 int32 텐서와, 해당 인덱스의 값들을 담은 bf16 텐서입니다.
deltas/step_000011.safetensors
├── model.layers.0.self_attn.q_proj.weight.indices (int32, [num_changed])
├── model.layers.0.self_attn.q_proj.weight.values (bf16, [num_changed])
...
이러한 선택으로 얻는 몇 가지 좋은 결과는 다음과 같습니다:
- 델타는 하나의 파일입니다. Python에서
safe_open(...)을 사용하여 파일을 열고 그 안의 모든 텐서를 검사할 수 있습니다. 독점적인 프레이밍 (framing), 길이 접두사 (length prefixes), 버전 핸드셰이크 (version handshake)가 없습니다. - 메타데이터는 자기 기술적 (self-describing)입니다. 수신자는
sparse=True/False를 읽습니다.
및 브랜치(branches)가 있습니다. 별도의 매니페스트(manifest)는 존재하지 않습니다. 추론(inference) 측면에서는 mmap을 통한 제로 카피(zero-copy) 방식으로 동작하며, 이는 몇 초마다 이 작업을 수행할 때 매우 중요합니다.
전송 주기(cadence)는 간단합니다: 매 N번째 스텝마다 앵커(anchor)를 찍고, 그 사이에는 델타(delta)를 기록합니다. 두 데이터 모두 anchors/ 및 deltas/ 접두사(prefix) 아래의 동일한 버킷(bucket)에 저장됩니다.
각 새로운 추론 레플리카(inference replica)는 가장 최근의 앵커를 가져온 다음, 그 이후의 델타를 다시 재생(replay)하기만 하면 됩니다.
트레이너(trainer)는 어떤 bf16 요소가 실제로 변경되었는지 알아야 합니다. 이를 위해 옵티마이저(optimizer)에 프리 스텝(pre-step) 및 포스트 스텝(post-step) 훅(hook)을 등록하는 아주 작은 BF16ChangeDetector를 사용합니다.
class BF16ChangeDetector:
def __init__(self, model, optimizer):
self._pre_step_bf16: dict[str, torch.Tensor] = {}
...
PR(Pull Request)의 실제 코드는 약간 더 복잡한 배관 작업(plumbing)이 포함되어 있습니다 (data_ptr()를 통해 옵티마이저 파라미터 객체를 모델 파라미터와 매칭해야 하는데, 이는 Accelerate가 이들을 서로 다른 파이썬 객체로 래핑(wrap)하기 때문입니다). 하지만 그 핵심 아이디어는 매우 간단합니다: 스냅샷(snapshot), 스텝(step), 차이(diff).
AI 자동 생성 콘텐츠
본 콘텐츠는 Hugging Face Blog의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.
원문 바로가기