GANのdiscriminatorでsoftplus関数が使われている理由
概要
GANで使われる Discriminator の損失計算の実装を見てみると,softplus関数が使われていることがあります.
これは,sigmoid 関数で binary cross entropy を計算することと同じです.
各関数の定義
- softplus 関数
- sigmoid 関数
- binary cross entropy (BCE) 関数
BCE と softplus の関係
ラベル の場合分けをして BCE を計算します. discriminator の出力を としたときの は
となります.また, は,となります.
GAN における discriminator の学習は,真の入力 に対しては discriminator の出力 を sigmoid 関数に入れた値 を 1 に, 偽の入力 に対しては discriminator の出力 を sigmoid 関数に入れた値 を 0 に近づけるようにします.実際に chainer の DCGAN 実装のソースコード https://github.com/chainer/chainer/blob/master/examples/dcgan/updater.py#L13-L19 を見ると,
def loss_dis(self, dis, y_fake, y_real): batchsize = len(y_fake) L1 = F.sum(F.softplus(-y_real)) / batchsize L2 = F.sum(F.softplus(y_fake)) / batchsize loss = L1 + L2 chainer.report({'loss': loss}, dis) return loss
真の入力 については を 1 に近づけるために を損失関数に用いており、 偽の入力 についても同様に を 0 に近づけるために を損失関数に用いていることが分かります.
Generatorの学習は,生成した画像を discriminator に入れたときの出力が 1(本物)になるように学習したいので を使用します.
(直感的に考えると生成した画像は活性化関数を通っているとはいえ,ある種の正規化されたピクセルを持っているわけで,各ピクセルが1になるようにするということは,つまり画像を真っ白にすることなのでヤバそうなのですが,Discriminatorの学習もあるので結局良い感じにパラメータが学習されるのかな,と自分は解釈しています. かなり馬鹿なことを書いてました.ちゃんとソースコードを読むと,generatorで生成した に対してdiscriminatorを通して損失を計算しているので,画像を真っ白にするとかそんなことはやっていないです.)