본문으로 건너뛰기

© 2026 Molayo

Zenn헤드라인2026. 06. 05. 00:37

Just Train Twice 논문 해설: 실패 사례를 무겁게 보는 것만으로 group robustness를 개선하는 심플한 수법

요약

가상 상관(Spurious Correlation)으로 인해 특정 그룹에서 성능이 저하되는 문제를 해결하기 위한 Just Train Twice(JTT) 논문을 해설합니다. JTT는 그룹 어노테이션 없이 ERM 모델의 실패 사례에 가중치를 두어 재학습함으로써 Group Robustness를 개선하는 단순하고 강력한 방법론을 제시합니다.

핵심 포인트

  • ERM은 평균 정확도는 높지만 특정 그룹에서 실패할 수 있음
  • 가상 상관은 모델이 데이터의 본질이 아닌 배경 등에 의존하게 만듦
  • JTT는 ERM의 오분류 샘플을 식별하여 가중치를 높여 재학습함
  • 그룹 어노테이션 없이도 Group Robustness를 개선 가능함

Just Train Twice 논문 해설: 실패 사례를 무겁게 보는 것만으로 group robustness를 개선하는 심플한 수법

머신러닝 (Machine Learning) 모델의 성능을 평가할 때, 가장 흔히 사용되는 지표는 평균 정확도이다. 훈련 데이터 전체, 혹은 테스트 데이터 전체에 대해 얼마나 정확하게 예측했는지를 본다. 이 평가는 단순하고 이해하기 쉬우며, 많은 상황에서 유용하다.

하지만 평균 정확도가 높다는 것이 모든 데이터 군에 대해 안정적으로 높은 성능을 나타낸다는 것을 의미하지는 않는다. 특히 입력과 라벨 사이에 가상 상관 (Spurious Correlation) 이 존재할 경우, ERM (Empirical Risk Minimization)에 의해 학습된 모델은 그 가상 상관에 의존함으로써 평균적으로는 높은 정확도를 달성하면서도, 특정 그룹에 대해서는 크게 실패할 수 있다.

예를 들어, 새 분류에서 "수조류는 물가에 있다", "육상조류는 육지에 있다"라는 상관관계가 훈련 데이터에 강하게 존재할 경우, 모델은 새 그 자체가 아니라 배경에 의존하여 분류할 가능성이 있다. 이때 육지에 있는 수조류나 물가에 있는 육상조류처럼 훈련 데이터 중에서는 소수파인 샘플에 대해 성능이 크게 저하될 수 있다. 평균 정확도만 보고 있으면 이러한 실패는 보이지 않기 쉽다.

이 문제에 대해, Group DRO 는 각 그룹의 손실을 고려하여 가장 성능이 나쁜 그룹의 손실을 직접 작게 만드는 수법이다. 이는 worst-group performance를 개선하는 데 강력한 방법이다. 한편, Group DRO에는 큰 제약이 있다. 각 훈련 샘플이 어느 그룹에 속하는지를 나타내는 그룹 어노테이션 (Group Annotation) 이 필요하다.

Just Train Twice (JTT) 는 이 제약을 완화하기 위해 제안된 매우 단순한 수법이다. JTT에서는 먼저 일반적인 ERM 모델을 짧게 학습시키고, 그 모델이 오분류한 훈련 샘플을 특정한다. 그다음, 그 오분류된 샘플의 가중치를 크게 하여 모델을 다시 한번 학습시킨다. 이름 그대로 하는 일은 "두 번 학습하는 것" 뿐이다.

언뜻 보면 이것은 단순한 실패 사례의 재가중치 부여(re-weighting)로 보인다. 하지만 중요한 점은 ERM이 실패한 샘플 집합이 단순한 노이즈의 모임이 아니라, 가상 상관 하에서 성능이 무너지기 쉬운 그룹을 반영하고 있을 가능성이 있다는 점이다. 즉, JTT는 훈련 그룹 어노테이션을 직접 사용하지 않고, ERM의 실패를 실마리로 삼아 group robustness를 개선하려는 수법이다.

본 기사에서는 JTT의 알고리즘 자체뿐만 아니라, 원 논문에서 비교되고 있는 ERM, CVaR DRO[1], Learning from Failure (LfF)[2], Group DRO[3]와의 관계를 정리한다. 그 위에서 JTT가 왜 작동하는지, 왜 단순히 고손실 샘플을 동적으로 가중치 부여하는 CVaR DRO와는 다른지, 그리고 JTT가 어떤 한계를 갖는지 살펴본다.

1. Setup: group robustness와 가상 상관

1.1 ERM이 보고 있는 것

먼저, 표준적인 경험적 위험 최소화 (Empirical Risk Minimization; ERM)가 무엇을 최적화하고 있는지 확인한다.

입력을

이 목적 함수는 훈련 데이터 전체에 대한 평균 손실을 작게 만드는 것을 목표로 한다. 따라서 데이터셋 중에서 다수를 차지하는 샘플에 대해 성능이 높으면 전체 평균 정확도도 높아지기 쉽다.

하지만 평균 손실을 최소화하는 것이 모든 부분 집단에 대해 균등하게 좋은 성능을 내는 것을 의미하지는 않는다. 어떤 그룹의 샘플 수가 적은 경우, 그 그룹에서 큰 오차가 발생하더라도 전체 평균에 대한 기여는 작다. 그 때문에 ERM은 평균적으로는 좋은 모델을 반환하는 한편, 특정 그룹에서는 크게 실패할 수 있다.

이 문제는 입력과 라벨 사이에 가상 상관 (Spurious Correlation) 이 존재할 때 특히 현저해진다. 모델은 태스크에 본질적인 특징이 아니라, 훈련 데이터 중에서 우연히 라벨과 상관되어 있는 특징을 사용하여 분류할 수 있기 때문이다.

예를 들어, 새 분류 태스크에서 수조류의 이미지에는 물가 배경이 많고, 육상조류의 이미지에는 육지 배경이 많다고 하자. 이때 모델은 새의 형태나 종류 그 자체가 아니라 배경을 사용하여 분류해도 훈련 데이터상에서는 높은 정확도를 달성할 수 있다. 하지만 그렇게 학습된 모델은 육지에 있는 수조류나 물가에 있는 육상조류에 대해 실패하기 쉽다.

즉, ERM의 문제는 단순히 "성능이 낮다"는 것이 아니다. 평균 정확도만 보면 고성능으로 보임에도 불구하고, 특정 그룹에서 구조적으로 실패하는 것이 문제인 것이다.

1.2 worst-group error

이러한 실패를 포착하기 위해, 본 논문에서는 평균 정확도가 아닌 worst-group performance (최악 그룹 성능)에 주목한다.

미리 정의된 그룹의 집합을

여기서,

로 정의된다. 즉, 모델

이 지표는 각 그룹에서의 분류 오차를 계산하고, 그중에서 가장 큰 것을 본다. 따라서 전체 평균 정확도가 높더라도 어느 한 그룹에서 성능이 크게 무너져 있다면, 그 실패가 직접 평가에 반영된다.

여기서 중요한 것은, worst-group error (최악 그룹 오차)가 평균 성능과는 다른 평가 축을 가지고 있다는 점이다. ERM (Empirical Risk Minimization)이 최소화하는 것은 모든 샘플에 걸친 평균 손실(loss)이다. 반면, worst-group error가 보는 것은 성능이 가장 나쁜 그룹이다.

이 차이는 모델 선택에서도 중요해진다. 평균 validation accuracy (검증 정확도)를 기반으로 모델을 선택하는 것과, worst-group validation accuracy를 기반으로 모델을 선택하는 것은 일반적으로 같지 않다. 평균 정확도를 최대화하는 모델이 반드시 worst-group performance에서 우수하다고 할 수는 없다.

JTT가 다루는 것은 이 worst-group performance를 개선하는 문제이다. 다만, 여기서 중요한 제약이 있다. 훈련 데이터 전체에 대해 각 샘플이 어느 그룹에 속하는지를 나타내는 **그룹 어노테이션 (group annotation)**은 주어지지 않았다는 설정이다.

즉, 본 논문의 목표는 다음과 같이 정리할 수 있다.

다만, 하이퍼파라미터 조정을 위해 소규모의 검증 세트에는 그룹 어노테이션이 부여되어 있다고 가정한다. 이 검증 세트를 사용함으로써 worst-group validation error를 계산하고, 모델 선택이나 하이퍼파라미터 조정을 수행한다.

이 점은 JTT를 이해하는 데 있어 매우 중요하다. JTT는 훈련 그룹 어노테이션을 사용하지 않는 수법이다. 하지만 완전히 그룹 정보를 불필요하게 만드는 수법은 아니다. 적어도 원 논문의 설정에서는 소규모 검증 세트상의 그룹 어노테이션을 사용하여, worst-group performance에 기반한 모델 선택을 수행한다.

1.3 의사 상관에 의한 그룹

본 논문이 주로 다루는 것은 라벨

의사 속성(pseudo-attribute)을

이때, 각 그룹

따라서 그룹 집합은 다음과 같이 표현된다.

이 정의에 의해, 단순히 "물 배경의 이미지"나 "waterbird의 이미지"를 보는 것이 아니라, "물 배경의 waterbird", "육지 배경의 waterbird", "물 배경의 landbird", "육지 배경의 landbird"와 같이 의사 속성과 라벨의 조합마다 성능을 평가할 수 있다.

Waterbirds의 예로 생각하면, 라벨

  • land background (육지 배경) 위의 landbird (녹색)
  • water background (물 배경) 위의 waterbird (녹색)
  • land background (육지 배경) 위의 waterbird (빨간색)
  • water background (물 배경) 위의 landbird (빨간색)

훈련 데이터에서는 통상적으로 waterbird는 water background와, landbird는 land background와 강하게 상관되어 있다. 즉, 많은 훈련 샘플에서 새의 종류와 배경이 일치한다.

Figure 1. Waterbirds 데이터셋에서의 4가지 그룹

이러한 상황에서는 배경이 라벨 예측에 사용할 수 있는 강력한 단서가 된다. 하지만 그것이 태스크에 본질적인 특징이라고 할 수는 없다. 배경과 새 종류의 상관관계는 훈련 데이터 중에는 성립하더라도, 테스트 시에 항상 성립한다고 볼 수 없기 때문이다.

이처럼 훈련 데이터 중에는 라벨과 상관되어 있지만, 예측 대상 그 자체를 본질적으로 결정한다고 할 수는 없는 속성을 여기서는 **의사 속성 (pseudo-attribute)**이라고 부른다. 그리고 그 의사 속성과 라벨의 상관관계를 **의사 상관 (spurious correlation)**이라고 부른다.

의사 상관이 존재하면 ERM은 다수파 그룹에 적합되기 쉽다. Waterbirds의 경우, 물 배경의 waterbird나 육지 배경의 landbird는 분류하기 쉽다. 반면, 육지 배경의 waterbird나 물 배경의 landbird는 훈련 데이터 중에서 소수파가 되기 쉬우며, 모델이 배경에 의존하고 있을 경우 오분류되기 쉽다.

따라서 이 설정에서의 group robustness 문제는 다음과 같이 말할 수 있다.

JTT는 이 문제에 대해 훈련 그룹 어노테이션 (training group annotation)을 사용하지 않고 접근한다. 다음 장에서는 그 위치를 명확히 하기 위해, 원 논문에서 비교 대상인 ERM, CVaR DRO, LfF, Group DRO를 정리한다.

2. 비교 방법론: ERM / CVaR DRO / LfF / Group DRO

JTT를 이해하기 위해서는 단순히 알고리즘의 절차를 보는 것만으로는 불충분하다. 중요한 것은 JTT가 어떤 기존 방법론들 사이에 위치하고 있는가이다.

원 논문에서는 JTT의 비교 대상으로 주로 4가지 방법을 정리하고 있다. ERM, CVaR DRO, Learning from Failure (LfF), Group DRO이다. 이들은 각각 JTT의 서로 다른 측면을 이해하기 위한 기준이 된다.

ERM은 평균 손실 (average loss)을 최소화하는 표준적인 베이스라인이다. CVaR DRO는 훈련 그룹 어노테이션 없이 고손실 샘플 (high-loss samples)을 무겁게 보는 방법으로, JTT와 표면적으로 유사하다. LfF는 첫 번째 모델의 실패를 이용하여 두 번째 모델을 학습한다는 점에서 JTT와 가장 가까운 비교 대상이다. Group DRO는 훈련 그룹 어노테이션을 사용하여 최악 그룹 손실 (worst-group loss)을 직접 최소화하는 방법으로, JTT가 어디까지 도달할 수 있는지를 보여주는 오라클 (oracle)적인 기준이 된다.

이 장에서는 이 4가지 방법을 차례대로 정리하여 JTT의 위치를 명확히 한다.

2.1 ERM: 평균 손실을 최소화하는 표준 방법

ERM은 훈련 데이터 전체에 대한 평균 손실을 최소화하는 표준적인 학습 방법이다. 손실 함수를

ERM의 장점은 그 단순함에 있다. 그룹 정보도, 의사 속성 (pseudo-attribute) 정보도, 특별한 재가중치 부여 (re-weighting)도 필요하지 않다. 일반적인 지도 학습 (supervised learning)으로서 평균 훈련 손실을 낮추기만 하면 된다.

하지만 제1장에서 언급했듯이, ERM은 평균 손실을 보고 있다. 그 때문에 데이터셋 내에서 다수를 차지하는 그룹에 대해 성능이 높다면, 소수파 그룹에서 큰 오차가 발생하더라도 전체 평균 손실은 충분히 작아질 수 있다. 의사 상관관계 (spurious correlation)가 있는 경우 이 문제는 더욱 현저해진다. 모델이 라벨 그 자체에 대응하는 본질적인 특징이 아니라, 라벨과 상관된 의사 속성을 사용하여 예측하더라도 다수파 그룹에서는 높은 정확도를 달성할 수 있기 때문이다.

따라서 ERM은 JTT에 있어 이중적인 의미를 갖는다. 하나는 개선해야 할 표준 베이스라인이며, 다른 하나는 JTT 내부에서 실제로 사용되는 학습 절차이기도 하다. JTT는 ERM을 버리는 방법이 아니다. 오히려 ERM을 한 번 학습하고, 그 실패를 이용하여 다시 한번 ERM을 수행하는 방법이다.

이런 의미에서 JTT는 ERM의 대립물이 아니라, ERM의 실패를 정보로서 재사용하는 방법이다.

2.2 CVaR DRO: 고손실 샘플을 동적으로 무겁게 보는 방법

CVaR DRO는 훈련 그룹 어노테이션을 사용하지 않고 최악 그룹 성능 (worst-group performance)을 개선하려는 자연스러운 접근 방식이다.

DRO, 즉 분포 강건 최적화 (distributionally robust optimization)에서는 경험 분포 (empirical distribution) 그 자체에 대한 평균 손실이 아니라, 경험 분포 주변의 불확실성 집합 (uncertainty set) 내의 최악 케이스 분포에 대한 손실을 작게 만들고자 한다. CVaR는 그 대표적인 목적 함수 중 하나이다.

레벨

여기서,

이 목적 함수는 직관적으로 손실이 가장 큰

CVaR DRO에서의 q의 의미

보충하자면, 여기서 adversarial하게 선택되는 샘플 가중치이다. 따라서 CVaR DRO는 현재 모델이 어려워하는 고손실 샘플에 가중치를 집중시켜, 그 부분의 손실을 낮추도록 학습하는 목적 함수라고 볼 수 있다.

이 발상은 JTT와 매우 유사하다. JTT 또한 일반적인 ERM에서 실패한 샘플을 무겁게 보기 때문이다. 둘 다 훈련 그룹 어노테이션을 필요로 하지 않으며, 고손실 샘플을 통해 최악 그룹 성능을 개선하려 한다.

하지만 양자 사이에는 중요한 차이가 있다. CVaR DRO는 현재 모델에게 손실이 큰 샘플 집합을 학습 중에 동적으로 업데이트한다. 반면, JTT는 첫 번째 식별 모델이 오분류한 샘플 집합을 한 번만 생성하고, 그 집합을 고정한다.

이 차이는 원 논문에서 매우 중요한 차이점으로 다루어진다. CVaR DRO는 '현재 모델이 어려워하는 점'을 계속 추적한다. 반면, JTT는 '초기 ERM이 실패한 점'을 고정하여 무겁게 본다.

이후의 실험에서 볼 수 있듯이, 이 고정된 error set (오류 집합)이 JTT의 성능에 있어 중요하다. 단순히 고손실 (high-loss) 샘플을 무겁게 보면 되는 것이 아니다. 어느 시점의, 어느 모델에 대한 실패를 사용할지가 중요한 것이다.

2.3 LfF: 실패로부터 배우지만, JTT보다 복잡한 수법

Learning from Failure (LfF)는 JTT와 개념적으로 유사한 수법이다. 둘 다 첫 번째 모델의 실패를 이용하여 두 번째 모델의 학습을 보조한다는 구조를 가진다.

LfF의 기본적인 생각은 의도적으로 편향된 (biased) 모델을 학습시키고, 그 모델이 어려워하는 샘플을 어려운 샘플로 취급하는 것이다. 유사 상관관계 (spurious correlation)가 존재하는 경우, 편향된 모델은 학습하기 쉬운 유사 속성 (spurious attribute)에 의존하기 쉬워진다. 그 결과, 유사 상관관계가 성립하지 않는 소수 그룹 (minority group)의 샘플에 대해 실패하기 쉬워진다. 이 실패를 이용하여 두 번째 모델을 재가중치화 (re-weighting)한다.

LfF에서는 두 개의 모델을 고려한다. 첫 번째는 biased model, 두 번째는 debiased model이다. biased model은

biased model은 generalized cross-entropy (GCE) loss를 사용하여 학습된다.

여기서,

GCE loss는 무엇을 하고 있는가

여기서

라고 두면,

GCE loss는

로 정의된다. 여기서

일반적인 cross entropy는

가 되어, 손실이 상한 (upper bounded)을 가진다. 따라서 정답 확률이 매우 낮은 어려운 예시에 대해, cross entropy만큼 극단적으로 큰 손실을 주지 않는다.

이러한 성질 덕분에 GCE loss로 학습된 biased model은 어려운 예시를 강하게 추적하기보다, 빠른 단계에서 정답을 맞히기 쉬운 쉬운 예시에 편향되기 쉽다. 유사 상관관계가 본질적 특징 (essential feature)보다 배우기 쉬운 경우, 그러한 쉬운 예시는 bias-aligned sample에 대응하기 쉽다. 그 결과, biased model은 유사 상관관계를 강하게 이용하는 모델로서 학습된다.

다음으로, debiased model은 biased model과 자신의 예측을 사용하여 각 샘플을 재가중치화하며 학습된다. 원 논문에서 제시된 가중치는 다음과 같은 형태이다.

이 가중치화를 통해, biased model이 제대로 다루지 못하는 샘플, 즉 유사 상관관계가 성립하지 않는 샘pxle이나 어려운 그룹에 속하는 샘플이 상대적으로 무겁게 다뤄질 것으로 기대된다.

LfF에서의 가중치에 대한 직관

LfF에서는 debiased model을 학습할 때 다음 가중치를 사용한다.

여기서

언뜻 보기에는 이해하기 어렵지만, 확률의 로그는

를 사용하여 다시 쓰면,

가 된다. 즉, 이 가중치는 "biased model의 손실이 biased model과 debiased model의 손실 합계에서 차지하는 비율"로 간주할 수 있다.

biased model이 해당 샘플을 쉽게 분류할 수 있는 경우,

LfF에서는 biased model을 의도적으로 유사 상관관계 쪽으로 치우치게 학습시킨다. 따라서 biased model이 실패하기 쉬운 샘플은 유사 상관관계만으로는 분류하기 어려운 bias-conflicting sample일 가능성이 높다. 위의 가중치는 이러한 샘플을 debiased model의 학습에서 상대적으로 무겁게 보는 메커니즘이다.

LfF와 JTT는 둘 다 **"첫 번째 모델의 실패를 사용한다"**는 점에서는 비슷하다. 하지만 JTT는 LfF보다 훨씬 단순하다.

LfF에서는 biased model을 의도적으로 만들기 위해 GCE loss를 도입해야 하며, biased model과 debiased model을 조합하여 학습해야 한다. 또한 두 모델을 동시에 다루기 때문에 학습 절차도 복잡해진다.

반면, JTT에서는 특별한 loss를 도입하지 않는다. 의도적으로 편향된 모델을 설계하지도 않는다. 표준적인 ERM 모델을 짧게 학습시키고, 그 모델이 오분류한 샘플을 고정하여 무겁게 볼 뿐이다.

즉, LfF는 "실패하기 쉬운 모델을 의도적으로 만드는" 수법이며, JTT는 "통상적인 ERM의 초기 실패를 그대로 사용하는" 수법이다.

이 차이는 JTT의 중요한 특징이다. JTT는 failure-based 재가중치화 수법이면서도 학습 절차는 매우 단순하다.

2.4 Group DRO: 훈련 그룹 정보를 사용하는 oracle

Group DRO는 지금까지의 수법과는 달리, 훈련 그룹 어노테이션 (group annotation)을 사용한다. 즉, 각 훈련 샘플

Group DRO의 목적은 훈련 데이터상의 worst-group loss를 직접 최소화하는 것이다. 그룹

ERM이 모든 샘플의 평균 손실을 보는 것에 반해, Group DRO는 그룹별 평균 손실을 계산하여 그 최댓값을 작게 만들려고 한다. 따라서 Group DRO는 worst-group performance를 개선하기 위한 직접적인 목적 함수 (objective function)이다.

이 점에서 Group DRO는 매우 강력하다. 평균 손실이 아니라 가장 성능이 나쁜 그룹의 손실을 명시적으로 최적화하기 때문에, 가상 상관 (spurious correlation)이 성립하지 않는 소수 그룹에도 학습 신호를 주기 쉽다.

하지만 그 대가로 훈련 그룹 어노테이션이 필요하다. 이는 실제 응용에서 큰 제약이다. 예를 들어, 모든 훈련 샘플에 대해 배경 속성, 인구 통계적 속성, 측정 조건, 도메인, 서브그룹 등을 부여해야 하는 경우 그 비용은 매우 크다.

따라서 JTT의 원 논문에서는 Group DRO를 훈련 그룹 어노테이션이 없는 수법에 기대할 수 있는 성능의 상한을 제공하는 oracle method로 취급한다. 즉, JTT는 Group DRO와 동일한 정보 조건에서 경쟁하고 있는 것이 아니다. Group DRO는 훈련 그룹 정보를 사용할 수 있는 수법이며, JTT는 그것을 사용할 수 없는 수법이다.

이 차이를 고려하면 JTT의 평가에서 중요한 것은 JTT가 Group DRO를 완전히 능가하느냐가 아니다. 오히려 훈련 그룹 어노테이션을 사용하지 않음에도 불구하고, ERM과 Group DRO 사이의 worst-group accuracy 차이를 어디까지 메울 수 있는가이다.

2.5 JTT의 위치づけ

이상의 4가지 비교 수법을 정리하면 JTT의 위치づけ가 보인다.

ERM은 가장 표준적이고 단순한 평균 손실 최소화이다. 하지만 가상 상관 하에서는 worst-group performance가 무너질 수 있다.

CVaR DRO는 훈련 그룹 어노테이션 없이 고손실 샘플을 무겁게 보는 수법이다. JTT와 마찬가지로 그룹 정보 없이 poor group performance에 대처하려 하지만, 가중치 부여 대상을 동적으로 업데이트한다는 점이 다르다.

LfF는 첫 번째 모델의 실패를 사용하여 두 번째 모델을 학습한다는 점에서 JTT와 유사하다. 하지만 LfF는 의도적으로 편향된 (biased) 모델이나 교차 업데이트를 필요로 하기 때문에 JTT보다 복잡하다.

Group DRO는 훈련 그룹 어노테이션을 사용하여 worst-group loss를 직접 최소화하는 수법이다. JTT에게 있어서는 동일한 조건에서 비교하는 베이스라인이라기보다, 훈련 그룹 정보가 있다면 어디까지 가능한지를 보여주는 oracle적인 기준이다.

이상을 종합하면, JTT는 ERM의 단순함을 유지하면서 실패 사례에 기반한 가중치 부여를 도입하는 수법으로 위치づけ된다.

JTT는 Group DRO처럼 그룹별 손실을 직접 최소화하는 것이 아니다. 또한 CVaR DRO처럼 현재 모델이 큰 손실을 내고 있는 샘플을 동적으로 추적하는 것도 아니다. 나아가 LfF처럼 의도적으로 가상 상관 쪽으로 치우친 biased model을 설계하는 것도 아니다.

JTT가 수행하는 것은 초기 ERM이 오분류한 훈련 샘플 집합을 고정하고, 그 집합을 무겁게 하여 다시 한번 ERM을 수행하는 것뿐이다. 이 단순함이 JTT의 큰 특징이다.

다음 장에서는 이 JTT의 구체적인 알고리즘을 살펴보겠다.

3. JTT: Just Train Twice

전 장에서는 JTT의 비교 대상으로 ERM, CVaR DRO, LfF, Group DRO를 정리했다. 여기서부터는 JTT 그 자체를 살펴보겠다.

JTT는 훈련 시 그룹 어노테이션을 사용하지 않고 worst-group performance를 개선하기 위한 2단계 수법이다. 이름 그대로 기본적으로는 모델을 두 번 학습할 뿐이다.

단, JTT의 본질은 단순히 "두 번 학습하는 것"이 아니다. 중요한 것은 첫 번째 ERM 모델이 오분류한 훈련 샘플을 error set으로 고정하고, 그 샘플 집합을 무겁게 하여 두 번째 ERM을 수행한다는 점이다.

3.1 기본 아이디어

JTT의 기본적인 흐름은 다음과 같다.

  • 일반적인 ERM (Empirical Risk Minimization)을 통해 식별 모델 (discriminative model)을 짧게 학습한다.
  • 해당 식별 모델이 오분류한 훈련 샘플을 모아 error set을 만든다.
  • error set에 포함된 샘플의 가중치를 높인다.
  • 가중치가 부여된 훈련 데이터 위에서, 최종 모델을 다시 한번 ERM으로 학습한다.

이 절차는 매우 단순하다. 특수한 손실 함수 (loss function)를 도입할 필요가 없다. 훈련 데이터 전체에 대한 그룹 어노테이션 (group annotation)도 필요 없다. 첫 번째 모델을 학습하고, 그 모델이 틀린 샘플을 무겁게 하여 다시 한번 학습할 뿐이다.

JTT의 직관은, ERM이 실패하는 샘플에는 가상 상관 (spurious correlation) 하에서 성능이 무너지기 쉬운 그룹의 정보가 포함되어 있다는 것이다.

가상 상관이 존재하는 데이터에서 ERM은 학습하기 쉬운 상관관계에 의존하기 쉽다. 예를 들어 Waterbirds 데이터셋에서는 새 그 자체보다 배경을 단서로 분류해 버릴 수 있다. 이 경우, 배경과 라벨의 상관관계가 성립하는 다수 그룹 (majority group)에서는 높은 정확도를 낼 수 있지만, 상관관계가 성립하지 않는 소수 그룹 (minority group)에서는 오분류하기 쉬워진다.

따라서 첫 번째 ERM 모델이 오분류한 샘플 집합에는, worst-group examples가 일반적인 훈련 데이터보다 더 높은 비율로 포함되어 있을 가능성이 있다. JTT는 이 성질을 이용한다.

3.2 Stage 1: 식별 모델로 error set 만들기

JTT의 첫 번째 단계에서는, 훈련 데이터 상에서 일반적인 ERM을 수행하여 식별 모델을 학습한다. 이 식별 모델을

[IMG:1]

다만, 이 모델은 최종적인 분류기로 사용하기 위한 것이 아니다. 목적은 고성능 모델을 얻는 것이 아니라, ERM이 어떤 훈련 샘플에서 실패하는지를 특정하는 것이다.

식별 모델

[IMG:2]

이 집합

[IMG:3]

여기서 중요한 것은 식별 모델을 훈련 데이터에 완전히 적합(fit)시키지 않는 것이다. 만약 식별 모델을 너무 길게 학습하면, 훈련 데이터 상의 오분류가 거의 없어지게 되어 error set이 빈 집합에 가까워진다. 그 경우 어떤 샘플을 무겁게 만들어야 하는지에 대한 정보가 손실된다.

따라서 JTT에서는 식별 모델을 미리 정해진

[IMG:4]

이 조기 종료 (early stopping)는 단순한 계산 비용 절감이 아니다. JTT가 보고 싶은 것은 충분히 학습된 ERM이 최종적으로 어디서 실패하느냐가 아니라, 학습 초기 단계의 ERM이 어떤 샘플을 아직 분류하지 못하고 있는가이다. 다시 말해, JTT는 "수렴 후의 ERM의 실패"가 아니라, **"짧게 학습한 ERM의 실패"**를 이용하여 error set을 만든다.

이 점은 JTT의 직관을 이해하는 데 중요하다. 학습 초기 단계의 ERM은 우선 학습하기 쉬운 상관관계에 의존하기 쉽다. 따라서 가상 상관이 성립하는 다수 그룹에는 빠르게 적합되는 반면, 가상 상관이 성립하지 않는 그룹이나 소수 그룹에는 적합되기 어렵다. JTT는 그 시점에서의 오분류 집합을 error set으로 추출한다.

3.3 Stage 2: error set을 무겁게 하여 재학습하기

두 번째 단계에서는, 첫 번째 단계에서 얻은 error set

[IMG:5]

가중치 ERM의 목적 함수는 다음과 같이 쓸 수 있다.

$$ ext{arg min}{\theta} \sum{i=1}^{n} w_i \mathcal{L}(f_{\theta}(x_i), y_i)$$

여기서,

$$w_i = \begin{cases} \lambda_{\text{up}} & \text{if } x_i \in \text{error set} \ 1 & \text{otherwise} \end{cases}$$

구현상으로는 이 목적 함수를 명시적인 가중치 손실 (weighted loss)로 구현할 필요는 없다. 원 논문에서는 error set에 포함된 각 샘플을

[IMG:6]

즉, JTT의 두 번째 단계는 다음과 같은 데이터셋 위에서 일반적인 ERM을 수행하는 것과 대응된다.

  • error set에 포함된 샘플: $\lambda_{\text{up}}$번 나타나게 함
  • error set에 포함되지 않은 샘플: 1번만 나타나게 함

이렇게 보면 JTT는 구현 측면에서도 상당히 단순하다. 특별한 optimizer도, 특수한 로버스트 최적화 절차도 필요 없다. 데이터의 샘플링 비율을 변경한 뒤 일반적인 ERM을 수행하면 된다.

3.4 JTT의 알고리즘

JTT의 절차를 정리하면 다음과 같다.

  1. Stage 1: 훈련 데이터에 대해 짧게 ERM을 수행하여 식별 모델 $f_{\theta_{1}}$을 학습한다.
  2. Error Set 생성: $f_{\theta_{1}}$이 오분류한 샘플들의 집합 $E$를 구한다.
  3. Stage 2: 샘플 $x_i \in E$의 가중치를 $\lambda_{\text{up}}$으로, $x_i \notin E$의 가중치를 1로 설정하여 다시 ERM을 수행, 최종 모델 $f_{\theta_{2}}$를 학습한다.

이 알고리즘을 보면 JTT는 매우 단순하다. 첫 번째 모델로 실패 사례를 찾고, 두 번째 모델에서 그 실패 사례를 무겁게 볼 뿐이다.

다만, 이 단순함이 JTT가 아무런 가정을 하지 않았음을 의미하지는 않는다. JTT는 초기 ERM이 오분류하는 샘플 집합에 worst-group performance 개선을 위해 유용한 정보가 포함되어 있다는 경험적 성질에 의존하고 있다.

3.5 무엇이 JTT다운 것인가

JTT의 특징을 정리하면 중요한 점은 세 가지가 있다.

  • JTT는 훈련 그룹 어노테이션을 사용하지 않는다.

Group DRO와 같이, 각 훈련 샘플이 어느 그룹에 속하는지 알 필요는 없다. 사용하는 것은 일반적인 입력 (input)과 라벨 $x_i$, 그리고 식별 모델의 오분류 (misclassification) 정보뿐이다. $y_i$ -
JTT는 특수한 실패 모델을 설계하지 않는다.

LfF와 같이, 의도적으로 편향된 (biased) 모델을 만들기 위한 특수한 손실 함수 (loss function)를 도입하는 것이 아니다. 표준적인 ERM 모델을 짧게 학습하고, 그 오분류를 사용할 뿐이다. -
JTT는 error set을 고정한다.

이는 CVaR DRO와의 중요한 차이점이다. CVaR DRO는 현재 모델에게 손실 (loss)이 큰 샘플 집합을 학습 중에 동적으로 업데이트한다. 반면 JTT는 첫 번째 단계에서 얻은 error set을 고정하며, 두 번째 단계에서는 그 집합을 바꾸지 않는다.

이처럼 JTT는 그룹을 직접 추정하는 것이 아니라, 짧게 학습한 ERM의 오분류 집합을 단서로 사용한다. 가상 상관 (spurious correlation)이 존재하는 상황에서는, 이 error set에 ERM이 취약한 그룹의 샘플이 포함되기 쉽다. JTT는 그 집합을 고정하여 무겁게 다룸으로써, worst-group performance의 개선을 노린다.

다음 장에서는 이 JTT가 실험적으로 어느 정도 유효한지를 ERM, CVaR DRO, LfF, Group DRO와의 비교를 통해 살펴본다.

4. 실험 결과: JTT는 어디까지 개선하는가

AI 자동 생성 콘텐츠

본 콘텐츠는 Zenn AI의 원문을 AI가 자동으로 요약·번역·분석한 것입니다. 원 저작권은 원저작자에게 있으며, 정확한 내용은 반드시 원문을 확인해 주세요.

원문 바로가기
0

댓글

0