GANのdiscriminatorでsoftplus関数が使われている理由

概要

GANで使われる Discriminator の損失計算の実装を見てみると,softplus関数が使われていることがあります.
これは,sigmoid 関数で binary cross entropy を計算することと同じです.

各関数の定義

  • softplus 関数


\begin{align}
\text{softplus}(x)=\log{(1+\exp{(x)})}
\end{align}

  • sigmoid 関数


\begin{align}
f(x)=\frac{1}{1 + \exp{(-x)} }
\end{align}

  • binary cross entropy (BCE) 関数


\begin{align}
g(x, t)= - t \log{x} - (1-t) \log{(1-x)}
\end{align}

BCE と softplus の関係

ラベル  t \in \{0, 1\} の場合分けをして BCE を計算します. discriminator の出力を  y=h(x) としたときの  g(f(y), 0)


\begin{align}
g(f(y), 0)&=\log{(1-f(y))} \\
&= - \log{ \left( 1 - \frac{1}{1 + \exp{(-y)} } \right) } \\
&= - \log{ \left( \frac{ 1 + \exp{(-y)} - 1 }{ 1 + \exp{(-y)} } \right) }\\
&= - \log{ \left(  \frac{ \exp{(-y)} }{ 1 + \exp{(-y)} } \right) } \\
&= - \log{ \left(  \frac{1}{ \exp{(y)} + 1} \right) } \\
&= \log{ \left(  \frac{1}{ \exp{(y)} + 1} \right)^{-1}} \\
&= \log{ \left( 1 + \exp{(y)} \right)} = \text{softplus}(y)
\end{align}
となります.
また,  g(f(y), 1) は,

\begin{align}
g(f(y), 1)&= - \log{(f(y))} \\
&= - \log{ \left( \frac{1}{1 + \exp{(-y)} } \right) } \\
&= \log{ \left(  \frac{1}{ 1+ \exp{(-y)} } \right)^{-1}} \\
&= \log{ \left( 1 + \exp{(-y)} \right)} = \text{softplus}(-y)
\end{align}
となります.

GAN における discriminator の学習は,真の入力  x_{\text{real}} に対しては discriminator の出力  y_{\text{real}}=h(x_{\text{real}}) を sigmoid 関数に入れた値  f(y_{\text{real}}) を 1 に, 偽の入力  x_{\text{fake}} に対しては discriminator の出力  y_{\text{fake}}=h(x_{\text{fake}}) を sigmoid 関数に入れた値  f(y_{\text{fake}}) を 0 に近づけるようにします.実際に chainer の DCGAN 実装のソースコード https://github.com/chainer/chainer/blob/master/examples/dcgan/updater.py#L13-L19 を見ると,

    def loss_dis(self, dis, y_fake, y_real):
        batchsize = len(y_fake)
        L1 = F.sum(F.softplus(-y_real)) / batchsize
        L2 = F.sum(F.softplus(y_fake)) / batchsize
        loss = L1 + L2
        chainer.report({'loss': loss}, dis)
        return loss

真の入力  x_{\text{real}} については  f(y_{\text{real}}) を 1 に近づけるために  g(f(y_{\text{real}}), 1) = \text{softplus}(-y_{\text{real}}) を損失関数に用いており、 偽の入力  x_{\text{fake}} についても同様に  f(y_{\text{fake}}) を 0 に近づけるために  g(f(y_{\text{fake}}), 0) = \text{softplus}(y_{\text{fake}}) を損失関数に用いていることが分かります.

Generatorの学習は,生成した画像を discriminator に入れたときの出力が 1(本物)になるように学習したいので  \text{softplus}(-x_{\text{fake}}) を使用します.
直感的に考えると生成した画像は活性化関数を通っているとはいえ,ある種の正規化されたピクセルを持っているわけで,各ピクセルが1になるようにするということは,つまり画像を真っ白にすることなのでヤバそうなのですが,Discriminatorの学習もあるので結局良い感じにパラメータが学習されるのかな,と自分は解釈しています. かなり馬鹿なことを書いてました.ちゃんとソースコードを読むと,generatorで生成した  x に対してdiscriminatorを通して損失を計算しているので,画像を真っ白にするとかそんなことはやっていないです.)

追記 (2022/01/08)

  1. この話は某インターンでメンターの方に相談したところ実はこうだよ、と教えてもらい復習がてら書いたのですが、意外と見られていて少し申し訳ない気持ちになっています。
  2. twitter でGANのsoftplusに関するやりとりを見かけたのですが、 下の記事にも同様のことが書かれていることを知りました。自分は softplus を使って計算することで愚直に計算するより nan を避けやすくなることを知りませんでした。

www.monthly-hack.com