PyTorch 提供了丰富的损失函数用于不同类型的机器学习任务。下面我将全面介绍 PyTorch 中的主要损失函数,包括它们的数学表达式、使用场景和实际代码示例。
一、回归任务损失函数
1. MSELoss (均方误差损失)
torch.nn.MSELoss(reduction='mean')
-
公式:
loss = (x - y)²
-
特点: 对异常值敏感,惩罚大误差更重
-
应用: 一般回归问题
criterion = nn.MSELoss() loss = criterion(outputs, targets)
2. L1Loss (平均绝对误差)
torch.nn.L1Loss(reduction='mean')
-
公式:
loss = |x - y|
-
特点: 对异常值更鲁棒
-
应用: 需要减少异常值影响的回归问题
3. SmoothL1Loss (Huber损失)
torch.nn.SmoothL1Loss(reduction='mean', beta=1.0)
公式:
-
特点: 结合L1和L2的优点
-
应用: 目标检测(如Faster R-CNN)
二、分类任务损失函数
1. CrossEntropyLoss (交叉熵损失)
torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean')
-
公式:
loss = -log(exp(x[class]) / ∑exp(x[j]))
-
特点: 自动应用softmax
-
应用: 多分类问题
criterion = nn.CrossEntropyLoss() loss = criterion(outputs, targets) # targets是类别索引
2. BCELoss (二元交叉熵)
torch.nn.BCELoss(weight=None, reduction='mean')
-
公式:
-
要求: 输入需经过sigmoid(0-1之间)
-
应用: 二分类问题
3. BCEWithLogitsLoss
torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)
-
特点: 结合sigmoid和BCELoss,数值更稳定
-
应用: 推荐用于二分类问题
三、其他重要损失函数
1. NLLLoss (负对数似然损失)
torch.nn.NLLLoss(weight=None, ignore_index=-100, reduction='mean')
-
要求: 输入需经过log-softmax
-
应用: 通常与LogSoftmax配合使用
2. KLDivLoss (KL散度)
torch.nn.KLDivLoss(reduction='mean')
-
公式:
loss = y * (log(y) - x)
-
应用: 衡量概率分布差异,如VAE
3. MarginRankingLoss
torch.nn.MarginRankingLoss(margin=0.0, reduction='mean')
-
应用: 排序任务
4. TripletMarginLoss
torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False)
-
应用: 度量学习,人脸识别
5. CosineEmbeddingLoss
torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')
-
应用: 相似度学习
四、损失函数选择指南
任务类型 | 推荐损失函数 | 备注 |
---|---|---|
回归问题 | MSELoss/L1Loss/SmoothL1Loss | 根据异常值情况选择 |
二分类 | BCEWithLogitsLoss | 优于BCELoss |
多分类 | CrossEntropyLoss | 最常用 |
多标签分类 | BCEWithLogitsLoss | 每个类别独立判断 |
分布匹配 | KLDivLoss | 如VAE |
相似度学习 | TripletMarginLoss/CosineEmbeddingLoss | 度量学习 |
五、自定义损失函数示例
class CustomLoss(nn.Module):def __init__(self, weight=1.0):super().__init__()self.weight = weightdef forward(self, inputs, targets):# 计算L1损失l1_loss = torch.abs(inputs - targets)# 计算特殊惩罚项penalty = torch.where(targets > inputs, 2.0 * l1_loss, l1_loss)# 组合损失return (penalty.mean() + self.weight * l1_loss.mean())