pytorchの誤差関数についてまとめた
pytorchで
- 最後の層の活性化関数
- 誤差関数
の適切な組み合わせについて度々わからなくなってその都度調べているので、ここでまとめておく。
2クラス分類
これは、sigmoidをつけているかどうかで2通りしかない。
活性化関数 | 損失関数 | yに対する成約 | 備考 |
---|---|---|---|
out = torch.sigmoid(x) | nn.BCELoss() | outと同じdtype・shapeにする | Binary Cross Entropy Loss |
out = x (活性化関数なし) |
nn.BCEWithLogitsLoss() | outと同じdtype・shapeにする |
多クラス分類
ここがちょっとややこしい。
活性化関数 | 損失関数 | yに対する成約 | 備考 |
---|---|---|---|
out = F.log_softmax(x, dim=-1) | nn.NLLLoss() | dtype=torch.int64、 shape=(n_samples, ) にする |
Negative Log Likelihood Loss |
out = x (活性化関数なし) |
nn.CrossEntropyLoss() | dtype=torch.int64、 shape=(n_samples, ) にする |
公式ドキュメントより。 This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class. |
基本的にpytorchでF.softmax
は使わないものと思ったら良い。