pytorch の NLLLoss の挙動

Negative Log Likelihood(NLL) Lossという名前から自分が予測する関数の挙動と実際の関数の挙動が毎回異なって少し困っていたので、具体例を交えつつ挙動の確認を行う。

環境

  • pytorch stable (1.4)

概要

  • NLL Loss は対数は取らず負の符号は取り、ベクトルの重み付き平均 or 和を計算する。

  • 関数名に対数が付いているのは、何らかの確率に対して対数を取ったものを入力とすることが期待されているから。

詳細

pytorchでは nll_loss 関数と NLLLoss クラスが定義されている。NLLLoss クラスは nn.Module を継承しており、 forward メソッドで nll_loss を呼び出すので中身は殆ど同じである。 公式ドキュメントの NLLLoss クラスの所に書かれている通り、NLLLossクラスのインスタンスで計算される損失関数は、


\begin{align}
\ell(x, y) &= \begin{cases}
    \sum_{n=1}^N \frac{1}{\sum_{m=1}^N w_{y_m}}l_{n}, & \text{if reduction = 'mean'}\\
    \sum_{n=1}^N l_{n}, & \text{if reduction = 'sum'}
  \end{cases}\\
l_n &= -w_{y_n}x_{n, y_n}
\end{align}

で定義される。

ここで, N, C は正の整数でそれぞれデータ数、クラス数を表す。  x N\times C の実行列(たとえばニューラルネットワークの出力など)、 y N 次元のベクトルで各要素  y_n は整数 0 \leq y_n \lt C を取る。 w C 次元の実ベクトルだが、デフォルト値は全て 1 なので値の指定をしない限りは今は特に気にする必要はない。 なので 基本的に


\begin{align}
l_n = -x_{n, y_n}
\end{align}

と考えれば良い。よって損失関数  \ell(x, y) は、


\begin{align}
\ell(x, y) &= \begin{cases}
    \frac{1}{N} \sum_{n=1}^N -x_{n, y_n}, & \text{if reduction = 'mean'}\\
    \sum_{n=1}^N -x_{n, y_n}, & \text{if reduction = 'sum'}
  \end{cases}
\end{align}

となる。

ある程度ソースコードを読むと,実際の処理はtorch._C._nn.nll_loss という関数に処理を投げていることがわかる。これはどこにあるかというと、https://discuss.pytorch.org/t/where-is-torch-c-nn-nll-loss/9769 に書かれている通り, ClassNLLCriterion に投げられているらしい。が,結局これは cuda に投げているだけなので、この関数の動作を理解するという目的からすると本質的な情報ではない。

ある程度の挙動が分かってきたので,具体例と実際の計算結果から NLLLoss の挙動を把握する。

例 1:和の計算

クラス数  C=3、データ数  N=1 、重み  w_c は全て 1 , reduction = 'sum'とする。

 x = (0.1, 0.8, 0.1), y = (1) を与えたとき、 \ell(x, y) はどうなるか?

答え


\begin{align}
\ell(x, y) &= \sum_{i=1}^N l_n \\
&= l_1\\
&= -1\cdot x_{1, 1} = -0.8
\end{align}
C = 3
N = 3
x = torch.Tensor([
                  [0.1, 0.8, 0.1],
                ])
target = torch.Tensor([1]).to(dtype=torch.long)
output = F.nll_loss(x, target)
print(output) # tensor(-0.8000)

例 2:平均の計算

クラス数  C=3、データ数  N=3 w_c は全て 1 、 reduction = 'mean'とする。

 x = \left(
    \begin{array}{ccc}
      0.1 & 0.8 & 0.1 \\
      0.6 & 0.2 & 0.2 \\
      0.25 & 0.25 & 0.5
    \end{array}
  \right)

 y = (1, 0, 2) を与えたとき、 \ell(x, y) はどうなるか?

答え


\begin{align}
\ell(x, y) &= \sum_{i=1}^N l_n = \frac{1}{3}l_1+\frac{1}{3}l_2+\frac{1}{3}l_3 \\
&= -\frac{1}{3} \cdot (x_{1, 1} + x_{2, 0} + x_{3, 2}) \approx -0.633
\end{align}
C = 3
N = 3
x = torch.Tensor([
                  [0.1, 0.8, 0.1],
                  [0.6, 0.2, 0.2],
                  [0.25, 0.25, 0.5],
                ])
target = torch.Tensor([1, 0, 2]).to(dtype=torch.long)
output = F.nll_loss(x, target)
print(output) # tensor(-0.6333)

例3:重み付き平均の計算

クラス数  C=3、 データ数  N=3 \mathbf{w} = (0.1, 0.1, 0.2)reduction = 'mean'とする。

 x = \left(
    \begin{array}{ccc}
      0.1 & 0.8 & 0.1 \\
      0.6 & 0.2 & 0.2 \\
      0.25 & 0.25 & 0.5
    \end{array}
  \right)

 y = (1, 0, 2) を与えたとき、 \ell(x, y) はどうなるか?

答え


\begin{align}
\sum_{n=1}^N w_{y_n} &= w_{1} + w_{0} + w_{2} = 0.4\\
\ell(x, y) &= \sum_{i=1}^N l_n = \frac{1}{0.4} l_1+\frac{1}{0.4}l_2+\frac{1}{0.4}l_3\\
 &= -\frac{1}{0.4} \cdot (0.1\cdot x_{1, 1} + 0.1\cdot x_{2, 0} + 0.2\cdot x_{3, 2}) = -0.6
\end{align}
C = 3
N = 3
x = torch.Tensor([
                  [0.1, 0.8, 0.1],
                  [0.6, 0.2, 0.2],
                  [0.25, 0.25, 0.5],
                ])
weight = torch.Tensor([0.1, 0.1, 0.2])
target = torch.Tensor([1, 0, 2]).to(dtype=torch.long)
output = F.nll_loss(x, target, weight=weight)
print(output) # -0.6

NLLLoss の使い所

xをネットワークの出力、yをラベルとしたとき、nll_loss(log_softmax(x), y) のように使われることが多い。これは cross_entropy(softmax(x), y) の計算結果と同じである。

なぜこのような実装になっているのか

NLLLoss が使われるタイミングとして、上にも書いたように log_softmax 関数の後に適用されることが多いです。 素直に考えると、そんなことはせず単純に softmax 関数を出力に適用した後にクロスエントロピーの式に投げて計算すれば良いのでは?と思いますが、softmaxとlogを別々に適用すると数値計算が不安定になるようです。 なぜかというと、softmax は


\begin{align}
\text{softmax}(x_i) = \frac{\exp{(x_i)}}{\sum_j \exp{(x_j)}}
\end{align}

に従って計算されますが、 x_i が大きな値を取る場合に  \exp(x_i) がオーバーフローするからです。

これを防ぐために最終的な計算結果を変えることなくオーバーフローしないように  \exp の中の数値を変換します。ある実数  b に対して  \exp{(x_i})=\exp{(x_i)}\exp{(-b)}\exp{(b)}=\exp{(x_i-b)}\exp{(b)} が成り立つため、


\begin{align}
\text{softmax}(x_i) = \frac{\exp{(x_i-b)}\exp{(b)}}{\sum_j \exp{(x_j-b)}\exp{(b)}} = \frac{\exp{(x_i-b)}}{\sum_j \exp{(x_j-b)}}
\end{align}

が得られます。そこで  b=\max_{i}(x_i) のように  b を定めると、全ての  i に対して  \exp{(x_i - b)} \leq 1 が成り立ち、オーバーフローを回避できます。 これを exp-normalize trickと呼んでいるようです。

このような exp-normalize によってオーバーフローの問題を回避できるわけですが、log_softmax でも同様の方法によってオーバーフローを回避できます。


\begin{align}
\text{log_softmax}(x_i) &= \log\left(\frac{\exp{(x_i)}}{\sum_j \exp{(x_j)}}\right)\\
&= x_i - \log \left( \sum_j \exp(x_j) \right) \\
&= x_i - \log \left( \sum_j \exp(x_j - b)\exp(b) \right) \\
&= x_i - \log \left( \sum_j \exp(x_j - b) \right) - b \\
\end{align}

 b は exp-normalize と同じように決定すれば良いです。この場合は、log-sum-exp と呼ばれているようです(そのままですね)。

Exp-normalize と log-sum-exp どちらが良いのかという話ですが、log-sum-exp を使った方が exp-normalize と比べて exp の呼び出し回数が減るので全体的な学習時間も少しだけ早くなることが期待され、少しメリットがあるような気がします。

クロスエントロピーの計算においては NLLLoss を使った計算をする方が良いのは分かりましたが、NLLLoss がなぜこのような名前のまま放置されているのかはあまり自分もよく分かっていないままです...。

The input given through a forward call is expected to contain log-probabilities of each class

公式ドキュメントには logを取った確率の入力が期待されていると書かれており、この入力の log 要素が残ったものなのかなと思っています。

所感

振り返ると、log と softmax を分けると数値計算が不安定になるというのは結構嘘な気がしますが(別に exp-normalize でも良さそうなので)(追記: Justification for LogSoftmax being better than Log(Softmax) - PyTorch Forums に書いてあるように、別々で適用した場合は exp-normalize した場合であっても log にほぼほぼ 0 の値が入るような状況であれば発散するので、数値計算は不安定になります。)、log_softmax にすることでそれなりのメリットがあるということが知れて良かったです。

参照

  1. https://pytorch.org/docs/stable/nn.html#torch.nn.NLLLoss
  2. https://discuss.pytorch.org/t/why-there-is-no-log-operator-in-implementation-of-torch-nn-nllloss/16610
  3. https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/