欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > IT业 > python打卡 DAY 37 早停策略和模型权重的保存

python打卡 DAY 37 早停策略和模型权重的保存

2025/9/14 15:05:58 来源:https://blog.csdn.net/2301_76970300/article/details/148799751  浏览:    关键词:python打卡 DAY 37 早停策略和模型权重的保存

目录

一、过拟合检测与诊断

1. 过拟合判断标准

2. 同步监控实现代码

3. 过拟合典型表现

二、模型保存与加载技术

1. 三种保存方式对比

2. 具体实现代码

a) 仅保存权重

b) 保存完整模型

c) 保存Checkpoint

3. 生产环境最佳实践

三、早停策略实现

1. 早停算法原理

2. 完整实现代码

3. 早停策略调参指南

四、综合训练模板

完整训练流程示例

五、常见问题解决方案

1. 模型保存加载问题

2. 早停策略问题

3. 过拟合解决方案


一、过拟合检测与诊断

1. 过拟合判断标准

2. 同步监控实现代码

def train_and_validate(model, train_loader, val_loader, epochs):train_losses, val_losses = [], []for epoch in range(epochs):# 训练阶段model.train()train_loss = 0for batch in train_loader:loss = train_step(model, batch)train_loss += loss.item()# 验证阶段model.eval()val_loss = 0with torch.no_grad():for batch in val_loader:outputs = model(batch)val_loss += criterion(outputs, targets).item()# 记录指标train_loss /= len(train_loader)val_loss /= len(val_loader)train_losses.append(train_loss)val_losses.append(val_loss)# 实时打印对比print(f"Epoch {epoch+1}:")print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")print(f"  Train Acc: {calc_accuracy(model, train_loader):.2f}% | "f"Val Acc: {calc_accuracy(model, val_loader):.2f}%")return train_losses, val_losses

3. 过拟合典型表现

现象训练集验证集
正常拟合指标持续改善指标同步改善
过拟合指标持续改善指标先升后降
欠拟合指标改善缓慢指标改善缓慢

二、模型保存与加载技术

1. 三种保存方式对比

保存类型代码实现文件内容适用场景
仅权重torch.save(model.state_dict(), path)模型参数推理部署
完整模型torch.save(model, path)模型结构+参数快速保存
Checkpointtorch.save({...}, path)全训练状态训练中断恢复

2. 具体实现代码

a) 仅保存权重
# 保存
torch.save(model.state_dict(), 'model_weights.pth')# 加载
model = MyModel()
model.load_state_dict(torch.load('model_weights.pth'))
b) 保存完整模型
# 保存
torch.save(model, 'full_model.pth')# 加载
model = torch.load('full_model.pth')
c) 保存Checkpoint
# 保存
checkpoint = {'epoch': current_epoch,'model_state': model.state_dict(),'optim_state': optimizer.state_dict(),'loss': best_loss
}
torch.save(checkpoint, 'checkpoint.pth')# 加载
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optim_state'])
start_epoch = checkpoint['epoch'] + 1

3. 生产环境最佳实践

# 安全保存方案
def safe_save(model, path):from tempfile import NamedTemporaryFilewith NamedTemporaryFile('wb', suffix='.pth', delete=False) as f:torch.save(model.state_dict(), f)temp_path = f.nameimport shutilshutil.move(temp_path, path)  # 原子操作替换旧文件# 跨设备加载
def load_on_device(path, device):return torch.load(path, map_location=device)

三、早停策略实现

1. 早停算法原理

2. 完整实现代码

class EarlyStopping:def __init__(self, patience=5, delta=0):self.patience = patienceself.delta = delta  # 最小改善阈值self.counter = 0self.best_score = Noneself.early_stop = Falsedef __call__(self, val_loss):score = -val_loss  # 损失越小越好if self.best_score is None:self.best_score = scoreelif score < self.best_score + self.delta:self.counter += 1print(f'EarlyStopping counter: {self.counter}/{self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.counter = 0# 使用示例
early_stopping = EarlyStopping(patience=7, delta=0.001)for epoch in range(epochs):val_loss = validate(model, val_loader)early_stopping(val_loss)if early_stopping.early_stop:print("Early stopping triggered")break

3. 早停策略调参指南

参数影响推荐值调整建议
patience容忍退化轮数5-10数据量大可增加
delta最小改善量0-0.01指标波动大时增加
restore_best恢复最佳权重True推荐启用

四、综合训练模板

完整训练流程示例

def train_with_early_stopping(model, train_loader, val_loader, max_epochs=100):optimizer = torch.optim.Adam(model.parameters())criterion = nn.CrossEntropyLoss()early_stopping = EarlyStopping(patience=5)best_loss = float('inf')for epoch in range(max_epochs):# 训练阶段model.train()for batch in train_loader:# ... 训练步骤 ...# 验证阶段model.eval()val_loss = 0with torch.no_grad():for batch in val_loader:# ... 验证步骤 ...# 早停检查early_stopping(val_loss)if early_stopping.early_stop:break# 保存最佳模型if val_loss < best_loss:best_loss = val_losstorch.save({'epoch': epoch,'model_state': model.state_dict(),'optim_state': optimizer.state_dict(),'loss': val_loss}, 'best_model.pth')# 恢复最佳模型checkpoint = torch.load('best_model.pth')model.load_state_dict(checkpoint['model_state'])return model

五、常见问题解决方案

1. 模型保存加载问题

问题解决方案
加载架构不匹配先初始化模型再加载权重
设备不匹配使用map_location参数
版本不兼容保持PyTorch版本一致

2. 早停策略问题

现象调整方向
过早停止增加patience或delta
停止过晚减小patience
指标波动大增加delta或使用平滑处理

3. 过拟合解决方案

方法实现方式
数据增强torchvision.transforms
正则化L2正则/ Dropout
早停验证集监控
简化模型减少参数量

@浙大疏锦行

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词