欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 产业 > 用 PyTorch 轻松实现 MNIST 手写数字识别

用 PyTorch 轻松实现 MNIST 手写数字识别

2025/5/6 2:11:52 来源:https://blog.csdn.net/ycy200377/article/details/147701156  浏览:    关键词:用 PyTorch 轻松实现 MNIST 手写数字识别

用 PyTorch 轻松实现 MNIST 手写数字识别

引言

在深度学习领域,MNIST 数据集就像是 “Hello World” 级别的经典入门项目。它包含大量手写数字图像及对应标签,非常适合新手学习如何搭建和训练神经网络模型。本文将基于 PyTorch 框架,详细拆解如何完成 MNIST 手写数字识别任务,让你轻松入门深度学习实践。

1. 数据加载与预处理

首先,我们利用torchvision库中的datasets.MNIST函数来加载 MNIST 数据集。代码如下:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortraining_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)

在这段代码中,root="data"指定了数据集的存储路径;train=True表示加载训练集,train=False则用于加载测试集;download=True确保如果本地没有数据集,会自动从网络下载;transform=ToTensor()将图像数据转换为 PyTorch 能够处理的张量格式,同时将像素值从 0-255 归一化到 0-1 区间 。

为了直观感受数据集,我们还可以绘制几张图像:

python

from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()

上述代码从训练集中选取了 9 张图像,绘制出图像及其对应的标签,方便我们对数据有更直观的认识。

接下来,使用DataLoader对数据集进行封装,以方便后续按批次训练和测试:

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

batch_size=64表示每次训练或测试时,模型会同时处理 64 个样本,这有助于提高计算效率和训练稳定性。

2. 模型构建

我们定义一个简单的全连接神经网络类NeuralNetwork

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28 * 28, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = self.flatten(x)x = self.hidden1(x)x = torch.relu(x)x = self.hidden2(x)x = torch.relu(x)x = self.out(x)return x

__init__函数中,nn.Flatten()用于将输入的二维图像张量展平为一维向量;nn.Linear()是全连接层,我们依次构建了两个隐藏层和一个输出层,输出层有 10 个神经元,对应 0-9 这 10 个数字类别。在forward函数中,定义了数据的前向传播过程,包括线性变换和激活函数torch.relu()的应用,激活函数能为模型引入非线性,使其能够学习更复杂的模式。

然后将模型移动到合适的设备(GPU 或 CPU)上:

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model = NeuralNetwork().to(device)
print(model)

3. 训练与测试

3.1 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1

在训练函数中,首先通过model.train()将模型设置为训练模式,然后遍历数据加载器中的每一批数据。对于每一批数据,将数据和标签移动到指定设备上,进行前向传播计算预测值,通过损失函数nn.CrossEntropyLoss()计算预测值与真实标签之间的损失。接着使用optimizer.zero_grad()清空梯度,loss.backward()进行反向传播计算梯度,最后optimizer.step()根据计算得到的梯度更新模型参数。每训练 100 个批次,打印当前的损失值。

3.2 测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return test_loss, correct

测试函数中,先将模型设置为评估模式model.eval(),关闭一些在训练过程中使用的操作(如 Dropout)。在测试过程中,不需要计算梯度,因此使用with torch.no_grad()。通过遍历测试数据加载器,计算模型预测结果与真实标签之间的损失,并统计正确预测的样本数量,最后计算平均损失和准确率并打印输出。

3.3 执行训练与测试

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n--------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

我们选择交叉熵损失函数nn.CrossEntropyLoss()作为损失计算方式,Adam 优化器torch.optim.Adam()来更新模型参数,学习率设置为 0.01。通过循环 10 个训练周期,不断训练模型,训练完成后进行测试,得到模型在测试集上的准确率和平均损失。

4. 总结

通过上述步骤,我们基于 PyTorch 完成了 MNIST 手写数字识别任务。从数据加载、模型构建,到训练和测试,每个环节都紧密相连。这个项目不仅让我们熟悉了 PyTorch 的基本使用流程,也对神经网络的工作原理有了更直观的认识。后续我们可以通过调整模型结构、超参数等方式进一步优化模型性能,探索更多深度学习的奥秘。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词