欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > IT业 > 李沐老师动手深度学习pytorch版本的读取fashion_mnist数据并用AlexNet模型训练,其中修改为利用本地的数据集训练

李沐老师动手深度学习pytorch版本的读取fashion_mnist数据并用AlexNet模型训练,其中修改为利用本地的数据集训练

2025/9/23 8:36:14 来源:https://blog.csdn.net/weixin_44628096/article/details/141189787  浏览:    关键词:李沐老师动手深度学习pytorch版本的读取fashion_mnist数据并用AlexNet模型训练,其中修改为利用本地的数据集训练

李沐老师的d2l.load_data_fashion_mnist里面没有root参数,所以只会下载,不能利用本地的fashion_mnist数据。所以我使用torchvision 的datasets里面FashionMNIST方法,又由于李沐老师此处是利用AlexNet模型来训练fashion_mnist数据,所以我们需要调整数据集的大小

导入必要的库和模块

import torch 
from torch import nn 
from d2l import torch as d2l
import numpy as np  
from torch.utils.data import Dataset, DataLoader  
import torchvision
import torchvision.transforms as transforms

转换数据

由于我们需要在加载数据同时定义数据转换,可以使用transforms.Compose来组合多个转换操作,使用Resize方法来调整图片大小,使其可以符合AlexNet的输入尺寸

 transform = transforms.Compose([  transforms.Resize((224, 224)),  # 将图片调整为224x224的大小  transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0., 1.]  # 可以选择添加transforms.Normalize进行进一步的标准化操作  
])

加载本地的fashion_mnist数据集

注意root参数,PyTorch 期望 FashionMNIST 数据集在指定的根目录下以特定的方式组织,确保下载好的文件存在于你指定的 root 路径下的 FashionMNIST/raw/ 文件夹中。可以参考我这篇博客:pytorch加载本地文件的root设置

通过train=False/True来设置训练集和测试集
通过设置download=True/False来确定找不到本地数据集的时候是否从网络下载
通过transform 指定特征和标签转换

#加载本地数据集,注意root参数
minist_train = torchvision.datasets.FashionMNIST(root='F:\\deeplearning\\fashion_mnist',train=True,download=False,transform=transform)
minist_test = torchvision.datasets.FashionMNIST(root='F:\\deeplearning\\fashion_mnist',train=False,download=False,transform=transform)print(type(minist_train))
print(len(minist_train),len(minist_test))

创建数据加载器,方便批次化处理

# 创建数据加载器  
batch_size = 128  
train_iter = DataLoader(minist_train, batch_size=batch_size, shuffle=True)  
test_iter = DataLoader(minist_test, batch_size=batch_size, shuffle=False) 

大致实现AlexNet网络架构

李沐老师在pytorch版本的动手深度学习中实现的模型

net = nn.Sequential(
# 这里使用一个11*11的更大窗口来捕捉对象。
# 同时,步幅为4,以减少输出的高度和宽度。
# 另外,输出通道的数目远大于LeNet
nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
# 减小卷积窗口,使用填充为2来使得输入与输出的高和宽一致,且增大输出通道数
nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
# 使用三个连续的卷积层和较小的卷积窗口。
# 除了最后的卷积层,输出通道的数量进一步增加。
# 在前两个卷积层之后,汇聚层不用于减少输入的高度和宽度
nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Flatten(),
# 这里,全连接层的输出数量是LeNet中的好几倍。使用dropout层来减轻过拟合
nn.Linear(6400, 4096), nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(4096, 4096), nn.ReLU(),
nn.Dropout(p=0.5),
# 最后是输出层。由于这里使用Fashion-MNIST,所以用类别数为10,而非论文中的1000
nn.Linear(4096, 10))

设置学习率,epoch并在GPU上训练

lr, num_epochs = 0.01, 10
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

结果如下

在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词