在torch.nn
中提供了诸多的损失函数,但有时需要根据实际问题自定义损失函数,其流程和自定义网络模型一样,通过继承nn.Module
类,并提供前向计算forward
方法,就可以像积木一样放入整体模型中,参与整个动态流图的计算。
下面是一个简单的例子:
import torch
import torch.nn.functional as F
# 定义一个二元交叉熵损失函数
class MyLoss(nn.Module):
def __init__(self):
super(MyLoss, self).__init__()
def forward(self, preds, targets, masks=None):
preds, targets = preds.float(), targets.float()
if not masks:
masks = torch.ones_like(preds)
loss = F.binary_cross_entropy_with_logits(preds, targets, reduction="mean", weight=masks)
return loss