MedQA: Fine-Tuning a Clinical AI on AMD ROCm — No CUDA Required
요약
본 기술 기사는 AMD Instinct MI300X GPU와 ROCm 환경을 사용하여 임상 질문-답변 모델(MedQA)을 LoRA 방식으로 미세 조정하는 과정을 상세히 설명합니다. 이 프로젝트의 핵심은 기존 의료 AI 작업이 NVIDIA CUDA에 의존하던 관행을 깨고, 코드 변경이나 커스텀 커널 없이 순수하게 AMD ROCm 환경에서 HuggingFace 생태계(Transformers, PEFT 등)를 활용하여 성공적으로 구현했다는 점입니다. 이를 통해 192GB의 대용량 VRAM을 활용하여 Qwen3-1.7B와 같은 모델을 FP16으로 학습하고, LoRA 기법을 적용해 효율적이고 접근성이 높은 임상 AI 개발 파이프라인을 제시합니다.
핵심 포인트
- AMD ROCm 환경에서 CUDA 의존성 없이 LLM 미세 조정(LoRA)을 성공적으로 수행하여 의료 AI 분야의 하드웨어 종속성을 탈피했습니다.
- AMD MI300X의 192GB VRAM을 활용하여 Qwen3-1.7B 모델을 FP16으로 학습하며, 대용량 메모리 자원의 이점을 극대화했습니다.
- HuggingFace 생태계(Transformers, PEFT 등)가 ROCm 환경에서 원활하게 작동함을 입증함으로써, 오픈소스 AI 개발의 범용성을 크게 확장했습니다.
- MedQA 데이터셋을 사용하여 의사 질문 답변 및 임상 추론 설명을 생성하는 고성능 의료 AI 파이프라인을 구축하고 그 과정을 공유했습니다.
AMD Developer Hackathon 에서 lablab.ai 에 구축된 AMD MI300X 로 MedMCQA 를 사용하여 Qwen3-1.7B 의 LoRA fine-tuning 완전 가이드.
의사 질문 답변 (medical question answering) 은 확률적으로 매우 높은 stakes 가 있는 작업 중 하나입니다. 임상 MCQ 에서 잘못된 답을 자신 있게 선택하는 모델은 단순히 틀린 것—not dangerous — 입니다. 동시에, 대부분의 오픈소스 의료 AI 작업은 NVIDIA GPU 가 있다고 가정합니다. CUDA 는 기본이고, 나머지는 부차적인 고려사항입니다.
이 프로젝트는 그 가정을 도전합니다.
MedQA 는 ROCm 을 사용하여 AMD 하드웨어 전체로 구축된 LoRA fine-tuned 임상 질문-답변 모델입니다. 다중 선택식 의료 질문을 받아 올바른 답자 and 추론의 임상 설명을 반환합니다. 전체 학습 파이프라인 — 데이터 로딩부터 adapter export 까지 — CUDA 의존성 없이 AMD Instinct MI300X 에서 실행됩니다.
- 🤗
**HuggingFace Hub 모델:**HK2184/medqa-qwen3-lora - 🚀
**Live Demo:**HuggingFace Spaces - 💻
**GitHub:**MedQA-Medical-AI-on-AMD-ROCm
AMD Instinct MI300X 는 놀라운 하드웨어입니다: 단일 장치에서 192 GB 의 HBM3 메모리. LLM fine-tuning 에서는 VRAM 이 종종 결합 제약 (binding constraint) 입니다 — batch size, sequence length, 그리고 quantize 할지 여부를 결정합니다. 192 GB 가 사용 가능하므로 Qwen3-1.7B 를 LoRA 로 full fp16 으로 학습하고 4-bit 또는 8-bit quantization hacks 없이 수행했습니다.
더 중요한 것은 HuggingFace 생태계 — Transformers, PEFT, TRL, Accelerate — ROCm 에서 원활하게 작동한다는 것을 증명하는 것이 목표였습니다. 작동합니다. CUDA 에서 실행되는 동일한 학습 코드는 세 환경 변수를 설정하여 ROCm 에서 실행됩니다:
os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"
그것이 전부입니다. 코드 변경 없음. 커스텀 커널 없음. CUDA 호환성 shim 없음.
MedMCQA 는 인도 의대 입시 시험 (AIIMS, USMLE-style) 에서 유래한 대규모 다중 선택식 질문 데이터셋입니다. 각 예시는 다음을 포함합니다:
- 임상 질문
- 네 가지 답변 옵션 (A–D)
- 올바른 답자 인덱스
- 선택적 자유 텍스트 설명 (
exp
필드)
이 프로젝트에서는 2,000 개의 학습 샘플을 사용했습니다 — 의미 있는 fine-tuning 을 빠르게 달성할 수 있음을 보여주는 의도적으로 작은 슬라이스입니다. MI300X 에서 약 5 분 동안 학습되었습니다.
기반 모델은 Alibaba 의 최신 소형 언어 모델인 Qwen/Qwen3-1.7B 입니다. 17 억 개의 파라미터를 가진 이 모델은 저렴하게 미세 조정할 수 있는 컴팩트한 크기를 가지면서도 일관된 임상적 추론을 생성할 수 있는 능력을 갖추고 있습니다. trust_remote_code=True 를 지원하며 HuggingFace Transformers 와 깔끔하게 로드됩니다.
명령微调 (instruction fine-tuning) 을 위한 프롬프트 포맷의 일관성은 매우 중요합니다. 모든 학습 예제와 모든 추론 호출은 동일한 템플릿을 사용합니다:
### Question:
{question}
### Options:
...
학습 중에는 모델이 답변과 설명까지 포함한 전체 시퀀스를 확인합니다. 추론 시에는 ### Answer:\n 까지 제공하고 모델을 거기서 완성하게 합니다.
모든 15 억 개의 파라미터를 미세 조정하는 대신 PEFT 라이브러리를 통해 LoRA (Low-Rank Adaptation) 를 사용합니다. LoRA 는 주의를 기울이는 레이어에 작은 훈련 가능한 rank-decomposition 행렬을 주입하여 기반 가중치를 고정합니다.
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
...
모델의 15 억 개 파라미터 중 약 220 만 개만 훈련됩니다. 이는 메모리 사용량을 낮추고 훈련 속도를 빠르게 합니다.
from transformers import TrainingArguments
args = TrainingArguments(
output_dir="./outputs",
...
참고할 만한 몇 가지 사항:
— 우리는 표준 fp16 을 사용합니다. 초기 bfloat16 실험에서 NaN 손실을 겪었으며, fp16 으로 전환하여 완전히 해결했습니다.fp16=True, bf16=False
— 계산량을 메모리로 교환합니다. MI300X 의 192 GB VRAM 을 고려하면 필수적이지는 않지만, 작은 GPU 에서 재현성을 위해 좋은 관행입니다.gradient_checkpointing=True
— 물리적 배치 크기가 4 인 경우 유효 배치 크기는 16 입니다.gradient_accumulation_steps=4
Warmup 과 함께 Cosine LR 스케줄링— 짧은 훈련 실행에 비해 평평한 스케줄보다 더 부드러운 수렴입니다.
from transformers import DataCollatorForSeq2Seq, Trainer
collator = DataCollatorForSeq2Seq(
tokenizer,
...
훈련 후 ./outputs 에 LoRA 어댑터 가중치가 포함되어 있습니다. 전체 다중 GB 모델 체크포인트가 아닌 몇 MB 의 파일들입니다.
추론 시에는 기반 모델을 로드하고 LoRA 어댑터를 연결하며 가중치를 선택적으로 병합합니다:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
...
생성 (Generation) 은 반복 패널티를 사용하여 모델이 루프를 방지하는 greedy decoding (do_sample=False) 을 사용합니다.
def generate(prompt, model, tokenizer):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
...
Question: 고혈압 위기 (hypertensive emergency)의 첫 선 치료법은 무엇인가요?
A) 경구 암로디핀
B) 정맥 라벨롤 또는 정맥 니트로프루시드
C) 경구 니피디핀
...
Model Output:
B) IV labetalol or IV nitroprusside
Explanation:
정맥 라벨롤 (베타 차단제) 또는 니트로프루시드는 혈압을 신속하게 감소시킵니다.
...
모델은 단순히 알파벳만 출력하지 않습니다. 대신 왜 그런지 설명합니다. 이것이 임상적으로 유용한 이유입니다.
파인튜닝 어댑터는 공개되어 있습니다. 리포를 클론하지 않고 직접 로드할 수 있습니다:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
...
AMD ROCm 프로젝트에는 반드시 '전쟁 스토리' 섹션이 포함되어야 합니다. 우리가 겪었던 문제:
| Challenge | Root Cause | Fix |
|---|---|---|
| NaN loss | Mixed precision instability | Switched from bfloat16 → fp16 |
| GPU not detected | Missing ROCm env variables | Set ROCR_VISIBLE_DEVICES, HIP_VISIBLE_DEVICES, HSA_OVERRIDE_GFX_VERSION |
| bitsandbytes unsupported | No ROCm build of bitsandbytes | Dropped quantization entirely — MI300X has enough VRAM |
| Garbage inference output | Tokenizer padding misconfigured | Set pad_token = eos_token and fixed padding_side |
| Trainer eval errors | Transformers version mismatch | Pinned transformers>=4.40.0 |
bitsandbytes 문제는 별도의 주석에 가치가 있습니다: NVIDIA 하드웨어에서는 4-bit 양자화는 메모리에 모델을 맞출 때 필수로 자주 요구됩니다. MI300X는 192 GB HBM3를 가지고 있으므로, 이는 단순히 불필요합니다. 이것은 진정한 하드웨어 우위입니다 — 더 깨끗한 훈련, 양자화 아티팩트 없음.
| Metric | Value |
|---|---|
| Trainable parameters | ~2.2M (0.15% of total) |
| ... | |
| No GPU? No problem. HuggingFace Spaces (CPU inference)에서 실행되는 라이브 Gradio 데모: |
AMD 하드웨어를 가지고 있나요? 리포를 클론하고 네이티브로 실행하세요:
git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
...
이 프로젝트는 파이프라인이 작동함을 증명합니다. 다음 단계는 확장 및 강화입니다.
더 큰 데이터셋 (Larger dataset) — 전체 MedMCQA 코퍼스 (~180,000 개 질문) 를 사용하여 학습하고 PubMedQA 신뢰도 점수 (Confidence scoring) — 답변과 함께 교정된 신뢰도 추정치를 추가합니다.
RAG 통합 (RAG integration) — 실시간 의학 문헌 검색을 바탕으로 답변을 근거화합니다.
평가 도구 (Evaluation harness) — 훈련 분포를 넘어 올바른 홀드아웃 정확도 벤치마킹을 수행합니다.
MedQA 는 오픈소스 AMD 하드웨어로 설명 가능한 의료 AI 를 구축하는 것이 가능하다는 것을 보여줍니다. HuggingFace 생태계의 ROCm 호환성은 genuinely good 입니다. MI300X 의 메모리 여유 공간은 엔지니어링 문제의 한 전체 카테고리를 제거합니다. 그리고 LoRA 는 1.7B 모델의 파인튜닝을 5 분 작업으로 만듭니다.
AMD ROCm 에서 구축 중이고 벽에 부딪힌다면, 위의 수정 사항들은 시간 절약에 도움이 될 것입니다. 그리고 의료 AI 를 구축 중이라면 정확도보다 설명에 대한 강조는 진지하게 받아들이는 가치가 있습니다.
lablab.ai 의 AMD Developer Hackathon 을 위해 제작 · AMD ROCm + HuggingFace 생태계 지원
- Harikrishna Sivanand Iyer 와 Srijan Sivaram A
AI 자동 생성 콘텐츠
본 콘텐츠는 Hugging Face Blog의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.
원문 바로가기