I wrote a function that wraps the PyTorch deep learning training workflow: given the input arguments, it trains a model and visualizes the training process, which makes it quick and convenient to check how well a model works and helps speed up tasks such as choosing a model architecture and tuning hyperparameters. I'm sharing it here for reference; criticism and discussion are welcome.
Classification is a very important application of artificial intelligence. The function shared in this post targets deep learning models for classification and provides the following features:
- Trains a classification model from the given dataset, model, optimizer, loss function, and other arguments;
- Visualizes the training process with visdom, plotting accuracy curves, loss curves, the confusion matrix, and ROC curves in real time;
- Supports both binary and multi-class classification;
- Accepts datasets either as np.ndarray in the form (X, y) or as torch.utils.data.Dataset in the form (train_data, test_data), so you can conveniently use PyTorch's built-in datasets or your own (see the sketch after this list);
- Supports GPU acceleration for training the model.
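As a sketch of the two accepted data formats, the inputs could be prepared like this (Iris and MNIST are just illustrative choices, not requirements; the Dataset form assumes datasets that expose classes and class_to_idx attributes, as torchvision datasets do):

import numpy as np
from sklearn.datasets import load_iris
from torchvision import datasets, transforms

# 1) (X, y) as np.ndarray: the function splits, scales, and wraps them itself
X, y = load_iris(return_X_y=True)
data_arrays = (X, y)

# 2) (train_data, test_data) as torch.utils.data.Dataset, e.g. torchvision's built-in MNIST
transform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
data_datasets = (train_data, test_data)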
Without further ado, let's first look at the output:


A complete deep learning workflow usually includes the following steps:
- Collect data
- Preprocess the data
- Choose a model
- Train the model
- Evaluate the model
- Tune hyperparameters
- Test the model
This function wraps the model training and evaluation steps:
- If the dataset is given as (X, y), split it into training and test sets (the test set takes 20%), standardize the data, and wrap both sets into datasets;
- Wrap the training and test sets in data loaders;
- Iterate over the training loader, compute the output and loss for each batch, and back-propagate to update the network parameters;
- Every 100 iterations, evaluate the model on the test set and plot the accuracy curves, loss curves, confusion matrix, and ROC curves (a minimal standalone sketch of the visdom plotting calls follows this list).
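Before the full code, here is a minimal, standalone sketch of the visdom line-plot pattern used for the real-time curves. It assumes the visdom server is already running, and the numbers are made up:

import numpy as np
from visdom import Visdom

vis = Visdom()
for step in range(1, 11):
    # Append one (x, y) point per series to the window named 'loss'
    vis.line(
        X=np.array([[step, step]]),
        Y=np.array([[1.0 / step, 1.2 / step]]),
        win='loss',
        update='append',
        opts=dict(title='Loss', legend=['train_loss', 'test_loss']),
    )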
The full code is as follows:
from functools import partial
from typing import Union, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.base import TransformerMixin
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, TensorDataset, Dataset
from visdom import Visdom


def classify(
    data: tuple[Union[np.ndarray, Dataset], Union[np.ndarray, Dataset]],
    model: nn.Module,
    optimizer: Optimizer,
    criterion: nn.Module,
    scaler: Optional[TransformerMixin] = None,
    batch_size: int = 64,
    epochs: int = 10,
    device: Optional[torch.device] = None,
) -> nn.Module:
    """Training function for classification tasks.

    :param data: either (X, y) as np.ndarray, or (train_data, test_data) as torch.utils.data.Dataset
    :param model: classification model
    :param optimizer: optimizer
    :param criterion: loss function
    :param scaler: data scaler
    :param batch_size: batch size
    :param epochs: number of training epochs
    :param device: training device
    :return: the trained classification model
    """
    if isinstance(data[0], np.ndarray):
        X, y = data
        # Handle the classes
        classes = np.unique(y)
        classes_str = [str(i) for i in classes]
        num_classes = len(classes)
        # Split into training and test sets; fix the random seed for reproducibility
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        # Standardize the data
        if scaler is not None:
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
        # Convert to tensors
        X_train = torch.from_numpy(X_train.astype(np.float32))
        X_test = torch.from_numpy(X_test.astype(np.float32))
        y_train = torch.from_numpy(y_train.astype(np.int64))
        y_test = torch.from_numpy(y_test.astype(np.int64))
        # Wrap X and y into TensorDatasets
        train_dataset = TensorDataset(X_train, y_train)
        test_dataset = TensorDataset(X_test, y_test)
    elif isinstance(data[0], Dataset):
        train_dataset, test_dataset = data
        classes = list(train_dataset.class_to_idx.values())
        classes_str = train_dataset.classes
        num_classes = len(classes)
    else:
        raise ValueError('Unsupported data type')

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )

    model.to(device)
    vis = Visdom()

    # Train the model
    for epoch in range(epochs):
        for step, (batch_x_train, batch_y_train) in enumerate(train_loader):
            batch_x_train = batch_x_train.to(device)
            batch_y_train = batch_y_train.to(device)
            # Forward pass
            output = model(batch_x_train)
            loss = criterion(output, batch_y_train)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            niter = epoch * len(train_loader) + step + 1  # global iteration count
            if niter % 100 == 0:
                # Evaluate the model
                model.eval()
                with torch.no_grad():
                    eval_dict = {
                        'test_loss': [],
                        'test_acc': [],
                        'test_cm': [],
                        'test_roc': [],
                    }
                    for batch_x_test, batch_y_test in test_loader:
                        batch_x_test = batch_x_test.to(device)
                        batch_y_test = batch_y_test.to(device)
                        test_output = model(batch_x_test)
                        predicted = torch.argmax(test_output, 1)
                        test_predicted_tuple = (batch_y_test.cpu().numpy(), predicted.cpu().numpy())
                        # Compute and record loss, accuracy, confusion matrix and ROC curve
                        eval_dict['test_loss'].append(criterion(test_output, batch_y_test).item())
                        eval_dict['test_acc'].append(accuracy_score(*test_predicted_tuple))
                        eval_dict['test_cm'].append(confusion_matrix(*test_predicted_tuple, labels=classes))
                        if num_classes == 2:
                            # eval_dict['test_roc'] has shape (len, (fpr, tpr), 3)
                            eval_dict['test_roc'].append(roc_curve(*test_predicted_tuple)[:2])
                        else:
                            # Multi-class ROC curves need one-hot encoded labels
                            y_test_one_hot, predicted_one_hot = map(
                                partial(label_binarize, classes=classes), test_predicted_tuple)
                            fpr_list = []
                            tpr_list = []
                            for i in range(num_classes):
                                fpr, tpr, _ = roc_curve(y_test_one_hot[:, i], predicted_one_hot[:, i])
                                # If a (fpr, tpr) point is missing, pad with a (0, 0) point
                                if len(fpr) != 3:
                                    fpr = np.insert(fpr, 0, 0)
                                    tpr = np.insert(tpr, 0, 0)
                                fpr_list.append(fpr)
                                tpr_list.append(tpr)
                            # eval_dict['test_roc'] has shape (len, (fpr, tpr), num_classes, 3)
                            eval_dict['test_roc'].append((fpr_list, tpr_list))

                    # Plot the loss curves
                    vis.line(
                        X=torch.ones((1, 2)) * (niter // 100),
                        Y=torch.tensor((loss.item(), np.mean(eval_dict['test_loss']))).unsqueeze(0),
                        win='loss',
                        update='append',
                        opts=dict(title='Loss', legend=['train_loss', 'test_loss']),
                    )
                    # Plot the accuracy curves
                    train_acc = accuracy_score(batch_y_train.cpu().numpy(),
                                               torch.argmax(output, 1).cpu().numpy())
                    vis.line(
                        X=torch.ones((1, 2)) * (niter // 100),
                        Y=torch.tensor((train_acc, np.mean(eval_dict['test_acc']))).unsqueeze(0),
                        win='accuracy',
                        update='append',
                        opts=dict(title='Accuracy', legend=['train_acc', 'test_acc'], ytickmin=0, ytickmax=1),
                    )
                    # Plot the confusion matrix
                    vis.heatmap(
                        X=np.add.reduce(eval_dict['test_cm']),
                        win='confusion_matrix',
                        opts=dict(title='Confusion Matrix', columnnames=classes_str, rownames=classes_str),
                    )
                    # Plot the ROC curves
                    test_roc_arr = np.array(eval_dict['test_roc'])
                    zeros_df = pd.DataFrame({'fpr': [0], 'tpr': [0]})  # (0, 0) padding point
                    ones_df = pd.DataFrame({'fpr': [1], 'tpr': [1]})   # (1, 1) padding point
                    if num_classes == 2:
                        plot_arr = test_roc_arr[:, :, 1]  # extract the (fpr, tpr) points, shape (len, (fpr, tpr))
                        cats = pd.qcut(plot_arr[:, 0], q=10, labels=False, duplicates='drop')  # bin fpr into 10 equal-sized groups
                        groups = pd.DataFrame(plot_arr, columns=['fpr', 'tpr']).groupby(cats).mean()  # per-bin means, shape (10, (fpr, tpr))
                        plot_df = pd.concat([zeros_df, groups, ones_df])  # prepend (0, 0) and append (1, 1), shape (12, (fpr, tpr))
                        x = plot_df['fpr']
                        Y = plot_df['tpr']
                    else:
                        plot_df_list = []
                        plot_arr = test_roc_arr[:, :, :, 1].swapaxes(1, 2)  # extract (fpr, tpr) points and swap axes, shape (len, num_classes, (fpr, tpr))
                        for i in range(num_classes):
                            cats = pd.qcut(plot_arr[:, i, 0], q=10, labels=False, duplicates='drop')
                            groups = pd.DataFrame(plot_arr[:, i, :], columns=['fpr', 'tpr']).groupby(cats).mean()  # shape (10, (fpr, tpr))
                            plot_df = pd.concat([zeros_df, groups, ones_df])  # shape (12, (fpr, tpr))
                            add_num = 12 - len(plot_df)
                            # If shorter than 12 rows, pad with (0, 0) points
                            if add_num > 0:
                                plot_df = pd.concat([zeros_df] * add_num + [plot_df])
                            plot_df_list.append(plot_df)  # shape (num_classes, 12, (fpr, tpr))
                        plot_arr_sum = np.stack(plot_df_list, axis=1)  # shape (12, num_classes, (fpr, tpr))
                        x = plot_arr_sum[:, :, 0]
                        Y = plot_arr_sum[:, :, 1]
                    vis.line(
                        X=x,
                        Y=Y,
                        win='ROC',
                        opts=dict(title='ROC', legend=classes_str),
                    )
                model.train()  # switch back to training mode after evaluation
    return model
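To give an idea of how the function is called, here is a minimal usage sketch on the Iris data. The small MLP, learning rate, batch size, and epoch count are arbitrary placeholders (the epoch count is set high so that the 100-iteration evaluation actually triggers on such a small dataset), and the visdom server must already be running (see the note below):

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)

# A tiny placeholder MLP: 4 input features, 3 classes
model = nn.Sequential(
    nn.Linear(4, 32),
    nn.ReLU(),
    nn.Linear(32, 3),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

trained_model = classify(
    data=(X, y),
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    scaler=StandardScaler(),
    batch_size=16,
    epochs=200,
    device=device,
)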
Note: before running the code, start the visdom server by entering python -m visdom.server on the command line, then open the link it prints in your browser.
A successful run looks like this: