cross-entropy 손실함수 완전 정복 — 심화 학습서

04. cross-entropy 실전 심화 — 수치안정 · label smoothing · 변형 · 함정

목차

03장까지로 이론은 끝났다. cross-entropy가 무엇이고(01), 로짓에서 어떻게 계산되며(02), gradient가 왜 qyq-y로 깔끔한지(03)를 다 봤다. 이 장은 그 깔끔한 식을 실제 코드로 옮길 때 만나는 현실의 벽을 다룬다. log0\log 0으로 NaN이 터지는 걸 막는 log-sum-exp 트릭, 모델의 과확신을 누르는 label smoothing, 클래스 불균형용 weighted CE, CE의 친척들, 그리고 조용히 학습을 망치는 함정 목록이다. 핵심 식은 본문에 다시 적어 이 장만 읽어도 따라오게 했다.

4.0 표기 약속

03장에서 gradient를 qyq-y로 정리할 때 쓴 기호를 그대로 이어 간다. 여기서 한 줄씩 다시 적어 둔다.

기호 의미
z=(z1,,zK)z=(z_1,\dots,z_K) 로짓(logits, 정규화 안 된 점수): 음수도 되고 합이 1이 아니어도 된다
qi=softmax(z)i=ezijezjq_i=\text{softmax}(z)_i=\dfrac{e^{z_i}}{\sum_j e^{z_j}} 예측 확률
y=(y1,,yK)y=(y_1,\dots,y_K) 목표 분포(one-hot 또는 soft)
CE=kyklogqk\text{CE}=-\sum_k y_k\log q_k cross-entropy 손실
KK 클래스 수
m=maxjzjm=\max_j z_j 로짓의 최댓값

이 장의 log\log는 자연로그(밑 ee, 단위 nat)로 통일한다. 밑 2(bit)와의 차이는 perplexity를 다루는 §4.4에서만 짚는다.


광고 · Advertisements

4.1 수치 안정성 — 실무에서 가장 먼저 만나는 벽

이론식은 logqt-\log q_t 하나로 깔끔하지만, 이걸 float으로 계산하면 두 곳에서 터진다. 먼저 무엇이 왜 터지는지 보고, 그다음 한 트릭으로 둘을 동시에 막는다.

log0\log 0 문제: 왜 NaN이 터지나

CE의 정의는 kyklogqk-\sum_k y_k\log q_k다. one-hot이면 정답 클래스 tt에 대해 식이 CE=logqt\text{CE}=-\log q_t로 줄어든다. 모델이 정답 확률을 qt0q_t\to 0으로 예측하면 logqt+-\log q_t\to+\infty가 된다.

float에서 log(0)\log(0)-inf다. 이 값이 손실로 들어오면 gradient가 infNaN이 된다. 그 스텝 이후 모든 파라미터가 NaN으로 오염되고, 한 번 오염되면 되돌릴 수 없다. 학습 전체가 죽는다.

# float64 기준
-log(1e-12) = 27.63        # 작지만 유한 — 아직 살아 있음
-log(0.0)   = inf          # 죽음

qq가 0이 되나? softmax는 수학적으로는 늘 qi>0q_i>0이다. 그런데 float에서는 0이 찍힐 수 있다. 어떤 로짓이 다른 로짓보다 압도적으로 크면 ezime^{z_i-m}이 언더플로로 0.0이 된다. 아니면 반올림으로 정확히 0.0이 찍힌다. 그 뒤 log\log에서 터진다.

막는 방법은 두 가지다.

방법 평가
확률 클리핑 q^=clip(q,ε,1ε)\hat q=\text{clip}(q,\varepsilon,1-\varepsilon), 보통 ε=107\varepsilon=10^{-7} 임시 방편. 손실값이 살짝 편향된다. 이미 softmax를 거친 확률밖에 없을 때의 차선책
로짓에서 직접 계산(log-softmax) logqi=zilogsumexp(z)\log q_i=z_i-\text{logsumexp}(z) 정석. log\log를 따로 취하지 않으므로 log0\log 0이 구조적으로 생기지 않는다

클리핑은 0을 ε\varepsilon로 바꿔 막는 사후 처치다. 로짓 직접 계산은 0이 만들어지는 경로 자체를 없애는 근본 처치다. 둘 중 가능하면 후자를 쓴다. 왜 후자가 되는지는 바로 다음 트릭이 설명한다.

log-sum-exp 트릭 — 유도

핵심은 한 줄짜리 보조정리다. 임의의 상수 mm에 대해 다음이 성립한다.

logjezj=logjezjmem=log ⁣(emjezjm)=m+logjezjm.\log\sum_j e^{z_j} =\log\sum_j e^{z_j-m}e^{m} =\log\!\Big(e^{m}\sum_j e^{z_j-m}\Big) = m+\log\sum_j e^{z_j-m}.

이것은 항등식이다. 근사가 아니다. 어떤 mm을 골라도 좌변과 우변의 값이 정확히 같다. 그래서 우리는 수치적으로 가장 편한 mm을 마음대로 고를 수 있다. m=maxjzjm=\max_j z_j를 고르면 두 가지가 따라온다.

  • zjm0z_j-m\le 0이므로 모든 ezjme^{z_j-m}(0,1](0,1] 안에 들어온다. 지수가 커질 일이 없으니 overflow가 불가능하다.
  • 가장 큰 항은 emm=e0=1e^{m-m}=e^0=1이다. 합이 최소 1이므로 log\log의 인자가 0이 되지 않는다. log0\log 0도 불가능하다.

왜 overflow가 막히나, 직관으로. naive ezj\sum e^{z_j}는 가장 큰 항이 e100210435e^{1002}\approx 10^{435}처럼 float64 한계(1.8×10308\approx 1.8\times10^{308})를 훌쩍 넘겨 inf가 된다. mm을 빼면 그 거대한 공통 인자 eme^mlog\log 밖으로 빠져나와 덧셈으로 바뀐다. 그러면 지수 연산이 다루는 수는 항상 [0,1][0,1] 안에 묶인다. 큰 수는 log\log 바깥의 덧셈이 맡고, 지수는 안전한 범위만 본다.

log-softmax와 융합 CE

이 결과로 softmax의 로그를 직접 쓴다.

  logsoftmax(z)i=zilogjezj=zilogsumexp(z)  \boxed{\;\log\text{softmax}(z)_i = z_i-\log\sum_j e^{z_j} = z_i-\text{logsumexp}(z)\;}

CE를 "softmax → log → NLL" 3단으로 따로 짜면 중간마다 overflow나 log0\log 0 지뢰를 밟는다. 안정 경로는 log-softmax와 NLL을 한 식으로 합친다.

CE(z,t)=logsoftmax(z)t=logsumexp(z)zt.\text{CE}(z,t)= -\log\text{softmax}(z)_t = \text{logsumexp}(z)-z_t.

이 형태는 softmax나 확률을 명시적으로 만들지 않고 손실을 바로 얻는다. 불안정의 원천이 사라진다. 02장 §2.5에서 본 "PyTorch가 로짓을 직접 받는" 이유가 바로 이것이다. CrossEntropyLoss는 내부에서 log_softmaxnll_loss를 합쳐 계산하고, BCEWithLogitsLoss는 log-sum-exp 트릭을 융합한다(공식 문서가 그렇게 명시한다).

수치예제 — z=(1000,1001,1002)z=(1000,1001,1002)

float64로 직접 계산한 결과다.

naive exp(z)       = [inf, inf, inf]          # e^1000 already overflow
naive softmax      = [nan, nan, nan]          # inf/inf = nan  <- training death

m = max(z) = 1002
stable softmax     = exp(z - m)/sum(exp(z - m))
                   = [0.09003, 0.24473, 0.66524]   # fine
logsumexp(z)       = m + log(sum exp(z - m))
                   = 1002 + log(e^-2 + e^-1 + 1)
                   = 1002.40761
log_softmax(z)     = z - logsumexp(z)
                   = [-2.40761, -1.40761, -0.40761]
CE (true class=0)  = -log_softmax(z)[0] = 2.40761
CE (true class=2)  = -log_softmax(z)[2] = 0.40761

같은 입력에서 naive 경로는 NaN을 내고, 안정 경로는 멀쩡한 손실 0.408과 2.408을 낸다. 여기서 짚을 점이 하나 있다. 세 로짓의 차이는 1과 2뿐이라 softmax 자체는 [0.090, 0.245, 0.665]로 지극히 평범하다. 그런데도 절대 크기가 1000이라는 이유만으로 naive가 죽는다. 상대 차이가 작아도 절대 크기가 크면 터진다는 뜻이다.

아래 그림은 같은 로짓이 두 경로로 갈렸을 때 어디서 NaN이 생기고 어디서 안전한지를 나란히 보여 준다.

SAFE: fused path (PyTorch CrossEntropyLoss)

UNSAFE: split path

z large: exp overflow -> inf

q ~ 0: log(0) = -inf

logits z = (z_1..z_K)

exp(z_i)

softmax q = exp / sum exp

log(q)

NLL = -log q_t

logsumexp(z) = m + log sum exp(z - m)

log_softmax_i = z_i - logsumexp(z)

CE = logsumexp(z) - z_t

inf / inf = NaN

-inf -> NaN gradients

위쪽(UNSAFE)은 exp에서 overflow, log에서 log0\log 0으로 두 번 죽고, 아래쪽(SAFE)은 융합 한 줄로 살아남는다.


4.2 label smoothing — 과확신 누르기

수치 안정은 손실이 죽는 걸 막는 문제였다. 다음은 손실이 너무 잘 들어서 생기는 문제다. 정답을 1로 밀어붙이면 모델이 과하게 확신하게 된다. label smoothing은 그 정답을 살짝 무르게 만든다.

개념과 식

one-hot 타깃은 정답 클래스에 1, 나머지에 0을 둔다. 그러면 CE는 정답 확률을 1로 밀어붙인다. 모델이 과확신(over-confident)해지고 로짓 간 격차가 무한정 벌어진다. softmax는 정답 확률이 1에 끝내 도달하지 못하므로, 식 위에서는 학습이 멈추지 않고 계속 격차를 키운다. label smoothing은 목표를 한 발 물려 준다.

  ykLS=(1ε)yk+εK  \boxed{\;y^{LS}_k=(1-\varepsilon)\,y_k+\frac{\varepsilon}{K}\;}

여기서 ε\varepsilon은 스무딩 강도이고 보통 0.1을 쓴다. 각 클래스가 어떻게 바뀌는지 보면 이렇다.

  • 정답 클래스: 1(1ε)+εK1\to(1-\varepsilon)+\dfrac{\varepsilon}{K}
  • 오답 클래스: 0εK0\to\dfrac{\varepsilon}{K}
  • 합은 (1ε)1+ε=1(1-\varepsilon)\cdot 1+\varepsilon=1로 보존된다(여전히 확률분포다).

이건 "one-hot과 균등분포 u=1/Ku=1/K의 가중평균" yLS=(1ε)y+εuy^{LS}=(1-\varepsilon)\,y+\varepsilon\,u로도 쓸 수 있다. 정답 쪽으로 쏠린 one-hot을 균등분포 쪽으로 ε\varepsilon만큼 끌어당기는 셈이다.

CE에 미치는 영향 — soft label CE로의 연결

타깃이 soft가 됐으니 01장의 soft label cross-entropy kykLSlogqk-\sum_k y^{LS}_k\log q_k가 그대로 적용된다(01장 §1.7③에서 본 "모든 항이 살아남는" 경우다). 이 식은 깔끔하게 두 조각으로 분해된다.

CE(yLS,q)=(1ε)(logqt)원래 hard CE+ε(1Kklogqk)균등분포와의 CE.\text{CE}(y^{LS},q)=(1-\varepsilon)\underbrace{\big(-\log q_t\big)}_{\text{원래 hard CE}}+\varepsilon\underbrace{\Big(-\tfrac1K\textstyle\sum_k\log q_k\Big)}_{\text{균등분포와의 CE}}.

첫 항은 원래 하던 hard CE 그대로다. 둘째 항이 새로 붙은 것으로, 예측이 한쪽으로 쏠리는 데에 벌점을 주는 정규화 역할을 한다. 모든 클래스 확률을 0에서 약간 떨어뜨려 과확신과 log0\log 0 위험을 동시에 줄인다.

수치예제 — K=3K=3, ε=0.1\varepsilon=0.1, 정답 클래스 t=2t=2

먼저 스무딩된 타깃을 만든다.

y=(0,0,1)    (10.1)y+0.1/3    yLS=(0.0333,0.0333,0.9333),=1.0.y = (0,0,1) \;\xrightarrow{\;(1-0.1)y + 0.1/3\;}\; y^{LS} = (0.0333,\,0.0333,\,0.9333), \quad \textstyle\sum = 1.0.

모델 예측을 q=(0.2,0.3,0.5)q=(0.2,0.3,0.5)라 하자. 같은 손실을 두 길로 구해 서로 맞는지 본다.

hard CE(원래 one-hot 기준)는 정답 항 하나뿐이다. logqt=log0.5=0.6931-\log q_t = -\log 0.5 = \mathbf{0.6931}.

soft CE(스무딩 타깃 기준)를 직접 합산하면 이렇다.

softCE=kykLSlogqk=[0.0333ln0.2+0.0333ln0.3+0.9333ln0.5]=0.7407.\text{softCE} = -\sum_k y^{LS}_k\log q_k = -[\,0.0333\ln 0.2 + 0.0333\ln 0.3 + 0.9333\ln 0.5\,] = \mathbf{0.7407}.

이제 같은 값을 분해식으로 다시 검산한다. 분해 softCE=(1ε)CEhard+εCEuniform\text{softCE}=(1-\varepsilon)\,\text{CE}_{\text{hard}}+\varepsilon\,\text{CE}_{\text{uniform}}의 두 항을 따로 구한다. 균등분포와의 CE는 세 클래스의 음로그를 평균낸 값이다. 정답뿐 아니라 모든 클래스를 본다는 점에 주의한다.

CEuniform=13klogqk=13(ln0.2+ln0.3+ln0.5)=1.1689(=1.168853).\text{CE}_{\text{uniform}} = -\tfrac13\sum_k\log q_k = -\tfrac13(\ln 0.2+\ln 0.3+\ln 0.5) = \mathbf{1.1689}\quad(=1.168853).

계산을 펼치면 ln0.2=1.6094\ln 0.2=-1.6094, ln0.3=1.2040\ln 0.3=-1.2040, ln0.5=0.6931\ln 0.5=-0.6931이다. 셋을 더하면 3.5066-3.5066, 1-1을 곱하고 3으로 나누면 1.16891.1689다. 이제 분해를 더해 본다.

0.9×0.6931CEhard+0.1×1.1689CEuniform=0.62379+0.11689=0.7407 0.9\times \underbrace{0.6931}_{\text{CE}_\text{hard}} + 0.1\times \underbrace{1.1689}_{\text{CE}_\text{uniform}} = 0.62379 + 0.11689 = \mathbf{0.7407}\ \checkmark

직접 합산값 0.7407과 정확히 맞는다. 학습자가 분해식을 손으로 더해도 같은 답이 나오려면 중간 항 CEuniform=1.1689\text{CE}_{\text{uniform}}=1.1689를 정확히 써야 한다. 이 중간 항을 잘못 적으면 "=0.7407=0.7407 ✓"가 산술적으로 거짓이 되어 검산이 깨진다.

결과를 해석하면 이렇다. 정답 타깃이 0.9333으로 낮아졌으니 손실이 hard 대비 살짝 커진다(0.693 → 0.741). 그만큼 모델은 정답 확률을 1까지가 아니라 0.93 부근까지만 밀게 된다.

왜 좋은가, 부작용은 무엇인가

이점은 Müller et al. 2019가 정리했다. 첫째, 일반화가 좋아진다. 둘째, 보정(calibration)이 좋아진다. 예측 확률이 실제 정확도에 더 가까워져 beam-search 같은 디코딩에 유리하다. 셋째, 같은 클래스의 표현이 더 조밀한 군집을 이룬다.

부작용도 있다. ε\varepsilon을 너무 키우면 정답 신호를 과하게 희석해 정확도가 떨어진다. 그래서 보통 0.1에 둔다. 또 한 가지, label smoothing으로 학습한 teacher는 지식증류에 해롭다는 보고가 있다(Müller et al. 2019). 로짓에 담긴 클래스 간 유사도 정보가 군집화로 지워져, student로 넘길 정보가 사라지기 때문이다. 증류 자체는 이 책의 범위 밖이라 사실만 적어 둔다.

PyTorch는 CrossEntropyLoss(label_smoothing=ε) 인자로 이걸 직접 지원한다.


광고 · Advertisements

4.3 가중·불균형 대응

클래스마다 데이터 수가 크게 다르면 흔한 클래스가 손실을 지배한다. 손실에 가중치를 곱해 드문 클래스의 목소리를 키우는 게 가장 간단한 대응이다.

weighted CE

클래스 불균형에서는 클래스별 가중치 wkw_k를 손실에 곱한다.

CEw=kwkyklogqk(one-hot)  wtlogqt.\text{CE}_w=-\sum_k w_k\,y_k\log q_k\quad\Rightarrow\quad \text{(one-hot)}\;-w_t\log q_t.

희소 클래스에 큰 ww를 주면 그 클래스의 손실 기여가 커진다. 예를 들어 w=[1,1,5]w=[1,1,5], 정답이 2, q2=0.5q_2=0.5라면 5log0.5=3.466-5\log 0.5=3.466이다. 가중치가 없으면 같은 상황이 0.693이니, 5배만큼 손실이 커졌다. 가중치를 정하는 두 가지 표준이 있다. class-balanced weight는 클래스 빈도의 역수에 비례시킨다(wk1/nkw_k\propto 1/n_k). 또는 Cui et al. 2019의 "effective number" wk(1β)/(1βnk)w_k\propto(1-\beta)/(1-\beta^{n_k})를 쓴다. PyTorch는 CrossEntropyLoss(weight=...)로 받는다.

Focal Loss — CE의 확장(다리만)

weighted CE가 클래스 단위로 가중한다면, Focal Loss는 샘플 단위로 가중한다. 이미 잘 맞히는 쉬운 샘플의 손실을 깎고 어려운 샘플에 집중한다(02장 §2.3에서 놓은 다리를 한 번 더 짚는다). 정답 확률 pt=qtp_t=q_t에 대해 식은 이렇다.

FL(pt)=(1pt)γlogpt.\text{FL}(p_t)=-(1-p_t)^{\gamma}\log p_t.

원래 CE인 logpt-\log p_t에 변조항 (1pt)γ(1-p_t)^\gamma를 곱한 형태다. ptp_t가 1에 가까우면, 즉 쉬운 샘플이면 (1pt)γ0(1-p_t)^\gamma\to0이라 손실이 거의 사라진다. γ=2\gamma=2로 둔 표를 보면 효과가 한눈에 들어온다.

gamma=2:  p_t=0.9 -> CE=0.105, FL=0.0011   (easy: ~99% suppressed)
          p_t=0.5 -> CE=0.693, FL=0.173
          p_t=0.1 -> CE=2.303, FL=1.865    (hard: almost kept)

쉬운 샘플(pt=0.9p_t=0.9)은 손실이 0.105에서 0.0011로 99% 가까이 깎이고, 어려운 샘플(pt=0.1p_t=0.1)은 2.303에서 1.865로 거의 그대로 남는다. Focal Loss 자체의 유도와 anchor 불균형 맥락은 이 책의 범위 밖이다(별도 주제). 여기서는 "CE에 ptp_t 변조를 더한 확장"이라는 위치만 표시한다.


4.4 cross-entropy의 친척들

CE 하나만 보면 외따로 떨어진 손실 같지만, 실은 같은 logq-\log q 골격을 공유하는 가족이 있다. 표로 한자리에 모은 뒤 관계를 푼다.

손실/지표 식 (one-hot/단일항) 입력 언제 쓰나
CE (categorical) kyklogqk-\sum_k y_k\log q_k 로짓→softmax 다중클래스 단일라벨 분류
NLL loss logqt-\log q_t (qq는 log-softmax 출력) log-확률 log-softmax를 이미 적용한 뒤. NLLlog_softmax=CE\text{NLL}\circ\text{log\_softmax}=\text{CE}
KL divergence kyklogykqk\sum_k y_k\log\dfrac{y_k}{q_k} 두 분포 soft target·지식증류. KL(yq)=CE(y,q)H(y)\text{KL}(y\Vert q)=\text{CE}(y,q)-H(y)
BCE [ylogq+(1y)log(1q)]-[y\log q+(1-y)\log(1-q)] (채널별) 로짓→sigmoid 다중라벨(클래스 독립), 이진 분류
perplexity exp(CE)\exp(\text{CE}) 언어모델 평가 지표(손실 아님)

핵심 관계를 하나씩 본다.

  • CE = NLL ∘ log-softmax. 둘은 같은 양을 다른 데서 자른 것이다. NLL은 log-확률을 받아 정답 항만 떼고, log-softmax는 로짓을 log-확률로 바꾼다. 둘을 이으면 CE다. PyTorch에서도 CrossEntropyLoss = LogSoftmax + NLLLoss다.
  • CE = KL + H(y). 타깃 분포 yy가 고정이면 H(y)H(y)는 상수다. 그러면 CE를 최소화하는 것과 KL을 최소화하는 것이 같아진다(01장 §1.5의 중심 항등식을 타깃 쪽에서 본 것이다). soft target을 쓰는 증류에서는 보통 KL을 직접 최소화한다.
  • BCE와 CE의 갈림. CE는 클래스가 상호배타일 때 쓴다(softmax, 합이 1). BCE는 클래스가 서로 독립일 때 쓴다(채널마다 sigmoid). 한 샘플에 여러 라벨이 동시에 붙는 다중라벨이면 BCE다(02장 §2.4).
  • perplexity = exp(CE). 자연로그 CE면 PP=eCE\text{PP}=e^{\text{CE}}, 밑 2 CE(bit)면 PP=2CE\text{PP}=2^{\text{CE}}다. "평균 분기 계수"로 읽으면 직관이 선다. PP가 10이면 매 토큰마다 10지선다를 균등하게 찍는 것과 같은 혼란도라는 뜻이다. 예를 들어 CE가 1.5 nat이면 PP=e1.5=4.4817\text{PP}=e^{1.5}=4.4817이다.

아래 그림은 로짓 하나에서 이 친척들이 어떻게 갈라져 나오는지 보여 준다.

softmax

log_softmax

sigmoid per channel

-sum y log q

NLL: -log q_t

CE = NLL o log_softmax

sum y log(y/q)

CE = KL + H(y)

-[y log q + (1-y) log(1-q)]

exp(CE)

logits z

q (probs, sum=1)

log q

q_c independent

Cross-Entropy

NLL Loss

KL divergence

BCE (multi-label)

Perplexity (LM eval)

가운데 CE가 허브이고, softmax·log-softmax·sigmoid 어디로 가느냐에 따라 친척이 갈린다.


4.5 흔한 함정 체크리스트 (증상 → 원인 → 해결)

지금까지가 "맞게 짠 CE"였다면, 여기는 "조용히 틀린 CE"다. 에러도 안 나는데 학습이 안 되는 경우가 대부분이라, 증상·원인·해결을 짝으로 묶어 둔다.

softmax 이중 적용

  • 증상: 손실이 잘 안 떨어지고 학습이 둔하다. gradient가 약하고 정확도가 정체된다.
  • 원인: 모델 마지막에 이미 softmax(또는 log_softmax)를 넣었는데, 그걸 nn.CrossEntropyLoss(내부에서 또 log-softmax)에 통과시킨 경우다. softmax가 두 번 적용돼 분포가 균등 쪽으로 뭉개진다.
  • 해결: CrossEntropyLoss에는 마지막 layer의 raw 로짓을 그대로 넘긴다(softmax 제거). 확률화는 추론할 때만 한다. 이미 확률밖에 없다면 NLLLoss(log(prob))로 우회한다.

확률을 로짓 자리에 넣음

  • 증상: 손실이 비정상적으로 작거나 거의 0이고, gradient가 죽는다. 학습이 사실상 멈춘다.
  • 원인: CrossEntropyLossBCEWithLogitsLoss는 로짓을 기대하는데 [0,1][0,1] 확률을 넣은 경우다. 내부의 log-softmax나 log-sigmoid가 의미를 잃는다.
  • 해결: 로짓 입력 클래스에는 로짓을 넣는다. 확률밖에 없으면 BCELoss(로짓 아님)나 NLLLoss로 짝을 맞춘다.

라벨 형식 혼동 (정수 인덱스 vs one-hot)

  • 증상: 차원 불일치 에러가 나거나, 에러 없이 조용히 잘못된 손실이 나온다.
  • 원인: PyTorch CrossEntropyLoss는 타깃으로 정수 클래스 인덱스 (N,)도 받고 soft 확률 (N,C)도 받는다. 정수 인덱스 자리에 one-hot을 잘못 넣거나, soft target API와 혼동한 경우다.
  • 해결: 정수 인덱스 모드면 target.dtype=long, shape (N,)로 둔다. one-hot이나 soft가 필요하면 (N,C) float 타깃을 쓴다. 어느 모드를 쓰는지 명시적으로 확인한다.

클래스 차원 축 실수

  • 증상: 손실이 말이 안 되거나 shape 에러가 난다. 세그멘테이션·시퀀스에서 자주 본다.
  • 원인: CrossEntropyLoss는 클래스 축이 dim=1이어야 한다((N, C, d1, d2, ...)). (N, ..., C)처럼 클래스를 마지막에 두면 softmax가 엉뚱한 축에 적용된다.
  • 해결: 입력을 (N, C, *) 형태로 permutetranspose한다. 토큰 분류 등에서는 logits.reshape(-1, C), labels.reshape(-1)로 평탄화한다.

평균 vs 합 reduction 혼동

  • 증상: 손실 스케일과 학습률이 기대와 안 맞는다. 배치 크기를 바꾸면 동작이 달라진다.
  • 원인: reduction='sum'은 배치 크기에 비례해 손실과 gradient가 커진다. 'mean'(기본)은 배치 평균을 낸다. 두 모델을 비교할 때 reduction이 다르면 학습률도 달라야 한다.
  • 해결: 기본으로 'mean'을 쓴다. 마스킹(padding)할 때는 유효 토큰 수로 직접 나눠 일관성을 지킨다.

NaN 디버깅 순서 (권장 절차)

NaN이 떴을 때는 다음 순서로 좁혀 간다.

  1. 입력 점검: 로짓에 이미 infNaN이 들어 있나? 이전 layer나 데이터 문제다. torch.isnan(logits).any()로 본다.
  2. loss 자리 점검: 확률을 로짓 자리에 넣었나? softmax를 두 번 했나?
  3. 타깃 점검: 라벨이 [0, C-1] 범위인가? ignore_index 밖 값이나 음수 라벨이 끼었나?
  4. 학습률·gradient 폭발: LR을 낮춰 보고 clip_grad_norm_을 건다. 첫 NaN이 forward에서 났는지 backward에서 났는지 격리한다.
  5. 확률 경로면 클리핑: CE를 손으로 짰다면 log(q^)\log(\hat q)에서 q^=clip(q,ε,1ε)\hat q=\text{clip}(q,\varepsilon,1-\varepsilon)을 쓴다. 가능하면 §4.1의 로짓-융합 경로로 갈아탄다.

아래 그림은 이 순서를 분기로 그린 것이다. 위에서 아래로 한 갈래씩 배제해 가면 원인에 닿는다.

yes

no

yes

no

no

yes

NaN / inf in loss

logits already
NaN/inf?

fix upstream layer / data;
check exploding activations

probs fed where
logits expected?
double softmax?

pass raw logits;
remove extra softmax

labels in range,
correct dtype/shape,
class dim = 1?

fix target index/one-hot,
permute to (N,C,*)

lower LR, clip_grad_norm;
if hand-rolled CE: clip probs to [eps, 1-eps]


이 장의 실전 지식을 한 줄로 줄이면 이렇다. 로짓을 손실에 직접 넣어 log-sum-exp로 NaN을 막고, 과확신은 label smoothing(ε=0.1\varepsilon=0.1, 분해 검산 0.7407)으로 누르고, 불균형은 weighted CE나 Focal로 다룬다. 이제 흩어진 조각들 — softmax, CE, gradient qyq-y, 한 스텝 업데이트 — 을 한 예제 z=(2,1,0.1)z=(2,1,0.1)로 처음부터 끝까지 끊김 없이 이어 볼 차례다. 다음 05장이 그 walkthrough다.

출처