
PyTorch와 Lightning AI로 LSTM 구축하기 파트 7: 체크포인트를 사용한 학습 재개
요약
PyTorch와 Lightning AI를 사용하여 LSTM 모델 학습을 중단된 지점부터 재개하는 방법을 설명합니다. 체크포인트를 활용해 에포크를 연장하고 모델의 예측 정확도를 높이는 과정을 다룹니다.
핵심 포인트
- Lightning AI의 체크포인트 기능을 통한 학습 재개 방법
- best_model_path를 사용하여 최신 체크포인트 경로 확보
- ckpt_path 인자를 활용한 모델 최적화 지속
- 학습 연장을 통한 예측값의 수렴 성능 개선 확인
이전 기사에서 우리는 TensorBoard를 사용하여 학습 과정을 분석했습니다. 그래프를 바탕으로 모델이 완전히 수렴(converged)하지 않았으며, 추가적인 학습 에포크(epochs)를 통해 이득을 얻을 수 있다는 결론을 내렸습니다.
이번 기사에서 그 작업을 계속해 보겠습니다.
Lightning의 장점 중 하나는 처음부터 다시 시작하지 않고 학습을 계속할 수 있다는 점입니다.
이는 Lightning이 학습 중에 **체크포인트 (checkpoints)**를 자동으로 저장하기 때문에 가능합니다.
이 체크포인트들을 사용하면 학습을 중단했던 지점부터 다시 시작하여 모델을 계속 최적화할 수 있습니다.
체크포인트 가져오기
먼저, 최신 체크포인트의 경로를 찾아야 합니다.
path_to_best_checkpoint = trainer.checkpoint_callback.best_model_path
여기서 best_model_path는 Lightning이 저장한 최신 체크포인트의 경로를 제공합니다.
에포크 수 늘리기
이제 새로운 트레이너(trainer)를 생성하고 최대 에포크(epochs) 수를 3000으로 늘립니다.
trainer = L.Trainer(max_epochs=3000)
처음부터 시작하는 대신, 저장된 체크포인트로부터 학습을 재개합니다.
trainer.fit(
model,
train_dataloaders=dataloader,
...
ckpt_path를 지정함으로써, Lightning은 모델을 다시 초기화하는 대신 저장된 체크포인트부터 학습을 계속합니다.
업데이트된 예측값 확인하기
이제 다시 한번 예측값을 출력해 보겠습니다.
print("\nComparing observed and predicted values")
print(
...
그러면 다음과 같은 출력이 생성됩니다:
Comparing observed and predicted values
Company A: Observed = 0, Predicted = tensor(0.0009)
...
Company A에 대한 예측값이 목표값인 0에 훨씬 더 가까워졌습니다.
마찬가지로, Company B에 대한 예측값도 목표값인 1에 더 가까워졌습니다.
TensorBoard 그래프 비교하기
다시 TensorBoard를 살펴보겠습니다.
Company A
Company A
더 많은 에포크(epochs) 동안 훈련한 결과, 예측값이 목표값인 0에 더 가까워졌습니다.
Company B
마찬가지로, Company B의 예측값은 목표값인 1에 더 가까워졌습니다.
훈련 손실 (Training Loss)
train_loss 그래프에서도 추가 훈련 후 손실(loss)이 더욱 감소했음을 보여줍니다.
모델 성능은 향상되었지만, 예측을 더욱 정교하게 다듬기 위해 여전히 더 많은 에포크 동안 훈련할 수 있습니다.
다음 글에서는 모델 개선을 계속하고, 또한 Lightning이 PyTorch의 내장 nn.LSTM() 모듈을 사용하여 LSTM 구현을 어떻게 간소화할 수 있는지 탐구할 예정입니다.
AI 에이전트들은 코드를 빠르게 작성합니다. 또한 사용자에게 알리지 않고 로직을 제거하거나, 동작 방식을 변경하고, 버그를 도입하기도 합니다. 이는 프로덕션 환경에서 자주 발견됩니다.
git-lrc가 이를 해결합니다. 이 도구는 git 커밋에 연결되어 커밋이 반영되기 전에 모든 diff를 검토합니다. 설정은 60초 만에 끝납니다. 완전히 무료입니다.
피드백이나 기여는 언제든지 환영합니다! 온라인이며, 소스 코드가 공개되어 있어 누구나 사용할 수 있습니다.
AI 자동 생성 콘텐츠
본 콘텐츠는 Dev.to AI tag의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.
원문 바로가기

