
PyTorch와 Lightning AI를 이용한 LSTM 구축 Part 4: Training Step 및 초기 예측
요약
PyTorch와 Lightning AI를 사용하여 LSTM 모델의 training_step 구현 및 초기 예측 과정을 설명합니다. 손실 계산, 로그 기록 방법, 그리고 학습 전 모델의 예측 성능을 확인하는 단계를 다룹니다.
핵심 포인트
- training_step을 통한 배치 데이터 처리 및 손실 계산 방법
- Lightning AI의 log() 함수를 이용한 학습 로그 기록
- forward() 메서드를 활용한 모델의 초기 예측 수행
- 계산 그래프 제거를 위한 .detach() 사용법
이전 기사에서 우리는 LSTM 셀을 완성하고, 모델을 위한 forward 메서드와 Adam optimizer를 살펴보았습니다.
이번 기사에서는 training_step() 함수를 살펴보고, 학습 없이 모델을 실행해 보겠습니다.
training_step() 함수는 두 회사 중 한 곳의 훈련 데이터 배치(batch)와 해당 배치의 인덱스를 가져옵니다.
그런 다음 forward() 함수를 사용하여 해당 훈련 예시에 대한 예측을 수행합니다.
def training_step(self, batch, batch_idx):
input_i, label_i = batch
output_i = self.forward(input_i[0])
...
다음으로, 예측값과 관측값 사이의 잔차 제곱(squared residual)인 손실(loss)을 계산합니다.
또한 학습 중에 손실이 어떻게 변하는지 쉽게 추적할 수 있도록 손실을 로그(log)로 남길 수 있습니다.
Lightning은 이 목적을 위해 log() 함수를 제공합니다. 이 함수는 로그를 lightning_logs 디렉토리에 자동으로 저장합니다.
Company A와 Company B에 대한 예측값과 같은 다른 값들도 로그로 남길 수 있습니다.
마지막으로, 손실을 반환합니다.
def training_step(self, batch, batch_idx):
input_i, label_i = batch
output_i = self.forward(input_i[0])
...
지금까지 우리는 다음 사항들을 구현했습니다:
- 가중치(weight) 및 편향(bias) 텐서 초기화
lstm_unit()내 LSTM 계산 구현- 펼쳐진(unrolled) LSTM을 통한 순전파(forward pass)를 수행하는
forward()메서드 생성 configure_optimizers()를 사용한 Adam optimizer 설정training_step()을 사용한 훈련 손실(training loss) 계산 및 로그 기록
이제 모델을 사용해 봅시다.
model = LSTMByHand()
print("\nComparing observed and predicted values")
...
여기서 우리는 1일부터 4일까지의 주가를 포함하는 텐서를 전달합니다. 그러면 모델은 5일째의 값을 예측합니다.
모델은 예측값(prediction)과 그와 관련된 계산 그래프(computation graph)를 모두 반환합니다. 우리는 .detach()를 호출하여 계산 그래프를 제거하고 예측값만을 가져옵니다.
코드를 실행하면 다음과 같은 출력이 생성됩니다:
Comparing observed and predicted values
Company A: Observed = 0, Predicted = tensor(-0.2321)
Company B: Observed = 1, Predicted = tensor(-0.2360)
Company A에 대한 예측값은 관측값(observed value)과 상당히 유사합니다.
하지만, Company B에 대한 예측값은 예상값(expected value)과 상당히 거리가 있습니다.
다음 기사에서는 이러한 예측을 개선하기 위해 모델을 학습시킬 것입니다.
AI 에이전트는 코드를 빠르게 작성합니다. 하지만 사용자에게 알리지 않고 조용히 로직을 제거하거나, 동작을 변경하고, 버그를 유발하기도 합니다. 이는 종종 프로덕션(production) 환경에서 발견되곤 합니다.
git-lrc가 이 문제를 해결합니다. 이 도구는 git 커밋에 후킹(hook)하여 모든 차이점(diff)이 반영되기 전에 검토합니다. 설정은 60초면 충분하며, 완전히 무료입니다.
모든 피드백과 기여자를 환영합니다! 온라인에서 소스 코드를 확인할 수 있으며 누구나 사용할 준비가 되어 있습니다.
AI 자동 생성 콘텐츠
본 콘텐츠는 Dev.to AI tag의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.
원문 바로가기