$x$-예측 흐름으로서의 마스크 확산 디코딩 (Masked Diffusion Decoding as $x$-Prediction Flow)
요약
MDLMs의 이진 디코딩 방식을 개선하기 위해 마스크 예측을 연속적인 x-예측 흐름으로 재해석한 연구입니다. 신뢰도 기반 비동기식 업데이트와 경량 정책 네트워크를 통해 디코딩 효율성을 극대화했습니다.
핵심 포인트
- 기존의 이진 디코딩 방식을 연속적인 흐름(continuous flow) 방식으로 전환
- 신뢰도 기반 비동기식 업데이트를 통한 토큰별 최적화된 디코딩 수행
- 경량 정책 네트워크 도입 및 강화학습(RL)을 통한 학습 공식화
- LLaDA 적용 시 HumanEval에서 25%의 예산만으로 성능 97% 달성
마스크 확산 언어 모델 (Masked Diffusion Language Models, MDLMs)은 토큰을 반복적으로 언마스킹 (unmasking)하여 텍스트를 생성하지만, 표준 디코더는 각 단계를 이진 동작 (binary action)으로 축소합니다. 즉, 특정 위치는 단일 토큰으로 확정되거나 완전히 마스크 (masked) 상태로 남을 뿐, 그 사이의 부분적인 믿음 (partial belief)에 대한 표현이 없습니다. 이러한 '전부 아니면 전무 (all-or-nothing)' 방식은 풍부한 예측 정보를 버리고 성급하며 돌이킬 수 없는 결정을 강요하여, 제한된 디코딩 예산 (decoding budget) 하에서 낮은 성능을 초래합니다. 본 논문에서 우리는 마스크 예측을 클린 상태 예측 ($x$-prediction)으로 재해석하며, 이것이 입력 임베딩 공간 (input embedding space)에서 연속적인 흐름 (continuous flow)을 유도하는 데 사용될 수 있음을 보여줍니다. 이러한 관점을 바탕으로, 우리는 MDLMs를 위한 연속적 디코딩 프레임워크를 제안하며, 여기서 토큰은 각 확산 단계 (diffusion step)에서 부분적인 진행 상황을 축적할 수 있고 수정 가능한 상태로 유지될 수 있습니다. 언어의 위치별로 불균형한 문맥적 제약 (contextual constraints)에 맞추기 위해, 우리는 이미지 확산 (image diffusion)의 전역적 동기식 스케줄 (globally synchronous schedule)을 확산 진행 상황이 토큰별로 축적되는 신뢰도 기반 비동기식 업데이트 (confidence-based asynchronous update)로 대체합니다. 또한, 경량 정책 네트워크 (lightweight policy network)를 도입하고 그 학습을 강화학습 (Reinforcement Learning, RL) 문제로 공식화합니다. 사전 학습된 LLaDA에 적용했을 때, 우리의 연속적 디코더는 HumanEval 데이터셋에서 디코딩 예산의 25%만 사용하고도 성능의 97%에 도달했습니다.
AI 자동 생성 콘텐츠
본 콘텐츠는 arXiv cs.CL (NLP)의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.
원문 바로가기