欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 教育 > 锐评 > python第51天

python第51天

2025/11/3 21:01:00 来源:https://blog.csdn.net/zdy1263574688/article/details/148640131  浏览:    关键词:python第51天

1.读取数据

使用CIFAR-10图像数据

import torch
from torchvision import datasets, transforms# 定义图像预处理流程
image_transform = transforms.Compose([transforms.ToTensor(),  # 将PIL图像转换为张量transforms.Normalize(mean=(0.5, 0.5, 0.5),  # RGB三通道均值std=(0.5, 0.5, 0.5))   # RGB三通道标准差
])# 获取训练数据集
trainset = datasets.CIFAR10(root='./data',  # 数据集存储路径train=True,     # 使用训练集transform=image_transform,download=True   # 如果本地不存在则下载
)# 获取测试数据集
testset = datasets.CIFAR10(root='./data',train=False,    # 使用测试集transform=image_transform,download=True
)# 配置数据加载器
train_loader = torch.utils.data.DataLoader(dataset=trainset,batch_size=128,    # 每批样本数量shuffle=True       # 训练时打乱顺序
)test_loader = torch.utils.data.DataLoader(dataset=testset,batch_size=128,shuffle=False      # 测试时保持原始顺序
)

2.模型建立

(1)建立CNN模型

import torch
import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 3, padding=1)self.fc1 = nn.Linear(32 * 8 * 8, 256)self.fc2 = nn.Linear(256, 10)self.relu = nn.ReLU()def forward(self, x):x = self.pool(self.relu(self.conv1(x)))  # 16x16x16x = self.pool(self.relu(self.conv2(x)))  # 32x8x8x = x.view(-1, 32 * 8 * 8)x = self.relu(self.fc1(x))x = self.fc2(x)return x

 @浙大疏锦行

热搜词