データサイエンティストAnchorBluesのブログ

とある民間企業で数学とコンピュータサイエンスをやっている研究員のブログです。

pytorchの誤差関数についてまとめた

pytorchで

  • 最後の層の活性化関数
  • 誤差関数

の適切な組み合わせについて度々わからなくなってその都度調べているので、ここでまとめておく。

2クラス分類

これは、sigmoidをつけているかどうかで2通りしかない。

活性化関数 損失関数 yに対する成約 備考
out = torch.sigmoid(x) nn.BCELoss() outと同じdtype・shapeにする Binary Cross Entropy Loss
out = x
(活性化関数なし)
nn.BCEWithLogitsLoss() outと同じdtype・shapeにする

多クラス分類

ここがちょっとややこしい。

活性化関数 損失関数 yに対する成約 備考
out = F.log_softmax(x, dim=-1) nn.NLLLoss() dtype=torch.int64、
shape=(n_samples, )
にする
Negative Log Likelihood Loss
out = x
(活性化関数なし)
nn.CrossEntropyLoss() dtype=torch.int64、
shape=(n_samples, )
にする
公式ドキュメントより。
This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class.

基本的にpytorchでF.softmaxは使わないものと思ったら良い。

参考URL