用 PyTorch 轻松实现 MNIST 手写数字识别
引言
在深度学习领域,MNIST 数据集就像是 “Hello World” 级别的经典入门项目。它包含大量手写数字图像及对应标签,非常适合新手学习如何搭建和训练神经网络模型。本文将基于 PyTorch 框架,详细拆解如何完成 MNIST 手写数字识别任务,让你轻松入门深度学习实践。
1. 数据加载与预处理
首先,我们利用torchvision
库中的datasets.MNIST
函数来加载 MNIST 数据集。代码如下:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortraining_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)
在这段代码中,root="data"
指定了数据集的存储路径;train=True
表示加载训练集,train=False
则用于加载测试集;download=True
确保如果本地没有数据集,会自动从网络下载;transform=ToTensor()
将图像数据转换为 PyTorch 能够处理的张量格式,同时将像素值从 0-255 归一化到 0-1 区间 。
为了直观感受数据集,我们还可以绘制几张图像:
python
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()
上述代码从训练集中选取了 9 张图像,绘制出图像及其对应的标签,方便我们对数据有更直观的认识。
接下来,使用DataLoader
对数据集进行封装,以方便后续按批次训练和测试:
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
batch_size=64
表示每次训练或测试时,模型会同时处理 64 个样本,这有助于提高计算效率和训练稳定性。
2. 模型构建
我们定义一个简单的全连接神经网络类NeuralNetwork
:
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28 * 28, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = self.flatten(x)x = self.hidden1(x)x = torch.relu(x)x = self.hidden2(x)x = torch.relu(x)x = self.out(x)return x
在__init__
函数中,nn.Flatten()
用于将输入的二维图像张量展平为一维向量;nn.Linear()
是全连接层,我们依次构建了两个隐藏层和一个输出层,输出层有 10 个神经元,对应 0-9 这 10 个数字类别。在forward
函数中,定义了数据的前向传播过程,包括线性变换和激活函数torch.relu()
的应用,激活函数能为模型引入非线性,使其能够学习更复杂的模式。
然后将模型移动到合适的设备(GPU 或 CPU)上:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model = NeuralNetwork().to(device)
print(model)
3. 训练与测试
3.1 训练函数
def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1
在训练函数中,首先通过model.train()
将模型设置为训练模式,然后遍历数据加载器中的每一批数据。对于每一批数据,将数据和标签移动到指定设备上,进行前向传播计算预测值,通过损失函数nn.CrossEntropyLoss()
计算预测值与真实标签之间的损失。接着使用optimizer.zero_grad()
清空梯度,loss.backward()
进行反向传播计算梯度,最后optimizer.step()
根据计算得到的梯度更新模型参数。每训练 100 个批次,打印当前的损失值。
3.2 测试函数
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return test_loss, correct
测试函数中,先将模型设置为评估模式model.eval()
,关闭一些在训练过程中使用的操作(如 Dropout)。在测试过程中,不需要计算梯度,因此使用with torch.no_grad()
。通过遍历测试数据加载器,计算模型预测结果与真实标签之间的损失,并统计正确预测的样本数量,最后计算平均损失和准确率并打印输出。
3.3 执行训练与测试
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n--------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)
我们选择交叉熵损失函数nn.CrossEntropyLoss()
作为损失计算方式,Adam 优化器torch.optim.Adam()
来更新模型参数,学习率设置为 0.01。通过循环 10 个训练周期,不断训练模型,训练完成后进行测试,得到模型在测试集上的准确率和平均损失。
4. 总结
通过上述步骤,我们基于 PyTorch 完成了 MNIST 手写数字识别任务。从数据加载、模型构建,到训练和测试,每个环节都紧密相连。这个项目不仅让我们熟悉了 PyTorch 的基本使用流程,也对神经网络的工作原理有了更直观的认识。后续我们可以通过调整模型结构、超参数等方式进一步优化模型性能,探索更多深度学习的奥秘。