目录
一、过拟合检测与诊断
1. 过拟合判断标准
2. 同步监控实现代码
3. 过拟合典型表现
二、模型保存与加载技术
1. 三种保存方式对比
2. 具体实现代码
a) 仅保存权重
b) 保存完整模型
c) 保存Checkpoint
3. 生产环境最佳实践
三、早停策略实现
1. 早停算法原理
2. 完整实现代码
3. 早停策略调参指南
四、综合训练模板
完整训练流程示例
五、常见问题解决方案
1. 模型保存加载问题
2. 早停策略问题
3. 过拟合解决方案
一、过拟合检测与诊断
1. 过拟合判断标准
2. 同步监控实现代码
def train_and_validate(model, train_loader, val_loader, epochs):train_losses, val_losses = [], []for epoch in range(epochs):# 训练阶段model.train()train_loss = 0for batch in train_loader:loss = train_step(model, batch)train_loss += loss.item()# 验证阶段model.eval()val_loss = 0with torch.no_grad():for batch in val_loader:outputs = model(batch)val_loss += criterion(outputs, targets).item()# 记录指标train_loss /= len(train_loader)val_loss /= len(val_loader)train_losses.append(train_loss)val_losses.append(val_loss)# 实时打印对比print(f"Epoch {epoch+1}:")print(f" Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")print(f" Train Acc: {calc_accuracy(model, train_loader):.2f}% | "f"Val Acc: {calc_accuracy(model, val_loader):.2f}%")return train_losses, val_losses
3. 过拟合典型表现
现象 | 训练集 | 验证集 |
---|---|---|
正常拟合 | 指标持续改善 | 指标同步改善 |
过拟合 | 指标持续改善 | 指标先升后降 |
欠拟合 | 指标改善缓慢 | 指标改善缓慢 |
二、模型保存与加载技术
1. 三种保存方式对比
保存类型 | 代码实现 | 文件内容 | 适用场景 |
---|---|---|---|
仅权重 | torch.save(model.state_dict(), path) | 模型参数 | 推理部署 |
完整模型 | torch.save(model, path) | 模型结构+参数 | 快速保存 |
Checkpoint | torch.save({...}, path) | 全训练状态 | 训练中断恢复 |
2. 具体实现代码
a) 仅保存权重
# 保存
torch.save(model.state_dict(), 'model_weights.pth')# 加载
model = MyModel()
model.load_state_dict(torch.load('model_weights.pth'))
b) 保存完整模型
# 保存
torch.save(model, 'full_model.pth')# 加载
model = torch.load('full_model.pth')
c) 保存Checkpoint
# 保存
checkpoint = {'epoch': current_epoch,'model_state': model.state_dict(),'optim_state': optimizer.state_dict(),'loss': best_loss
}
torch.save(checkpoint, 'checkpoint.pth')# 加载
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optim_state'])
start_epoch = checkpoint['epoch'] + 1
3. 生产环境最佳实践
# 安全保存方案
def safe_save(model, path):from tempfile import NamedTemporaryFilewith NamedTemporaryFile('wb', suffix='.pth', delete=False) as f:torch.save(model.state_dict(), f)temp_path = f.nameimport shutilshutil.move(temp_path, path) # 原子操作替换旧文件# 跨设备加载
def load_on_device(path, device):return torch.load(path, map_location=device)
三、早停策略实现
1. 早停算法原理
2. 完整实现代码
class EarlyStopping:def __init__(self, patience=5, delta=0):self.patience = patienceself.delta = delta # 最小改善阈值self.counter = 0self.best_score = Noneself.early_stop = Falsedef __call__(self, val_loss):score = -val_loss # 损失越小越好if self.best_score is None:self.best_score = scoreelif score < self.best_score + self.delta:self.counter += 1print(f'EarlyStopping counter: {self.counter}/{self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.counter = 0# 使用示例
early_stopping = EarlyStopping(patience=7, delta=0.001)for epoch in range(epochs):val_loss = validate(model, val_loader)early_stopping(val_loss)if early_stopping.early_stop:print("Early stopping triggered")break
3. 早停策略调参指南
参数 | 影响 | 推荐值 | 调整建议 |
---|---|---|---|
patience | 容忍退化轮数 | 5-10 | 数据量大可增加 |
delta | 最小改善量 | 0-0.01 | 指标波动大时增加 |
restore_best | 恢复最佳权重 | True | 推荐启用 |
四、综合训练模板
完整训练流程示例
def train_with_early_stopping(model, train_loader, val_loader, max_epochs=100):optimizer = torch.optim.Adam(model.parameters())criterion = nn.CrossEntropyLoss()early_stopping = EarlyStopping(patience=5)best_loss = float('inf')for epoch in range(max_epochs):# 训练阶段model.train()for batch in train_loader:# ... 训练步骤 ...# 验证阶段model.eval()val_loss = 0with torch.no_grad():for batch in val_loader:# ... 验证步骤 ...# 早停检查early_stopping(val_loss)if early_stopping.early_stop:break# 保存最佳模型if val_loss < best_loss:best_loss = val_losstorch.save({'epoch': epoch,'model_state': model.state_dict(),'optim_state': optimizer.state_dict(),'loss': val_loss}, 'best_model.pth')# 恢复最佳模型checkpoint = torch.load('best_model.pth')model.load_state_dict(checkpoint['model_state'])return model
五、常见问题解决方案
1. 模型保存加载问题
问题 | 解决方案 |
---|---|
加载架构不匹配 | 先初始化模型再加载权重 |
设备不匹配 | 使用map_location 参数 |
版本不兼容 | 保持PyTorch版本一致 |
2. 早停策略问题
现象 | 调整方向 |
---|---|
过早停止 | 增加patience或delta |
停止过晚 | 减小patience |
指标波动大 | 增加delta或使用平滑处理 |
3. 过拟合解决方案
方法 | 实现方式 |
---|---|
数据增强 | torchvision.transforms |
正则化 | L2正则/ Dropout |
早停 | 验证集监控 |
简化模型 | 减少参数量 |
@浙大疏锦行