欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > 能源 > 基于PyTorch的食物图像分类实战:从数据处理到模型训练

基于PyTorch的食物图像分类实战:从数据处理到模型训练

2025/5/4 19:15:47 来源:https://blog.csdn.net/2201_75345884/article/details/147687980  浏览:    关键词:基于PyTorch的食物图像分类实战:从数据处理到模型训练

基于PyTorch的食物图像分类实战:从数据处理到模型训练

在深度学习领域,图像分类是一个经典且应用广泛的任务。无论是在电商平台的商品分类、医疗影像诊断,还是在农业作物识别等场景中,图像分类模型都发挥着重要作用。本文将以食物图像分类为例,详细介绍如何使用PyTorch框架构建一个卷积神经网络(CNN)模型,实现对20种不同食物的分类。

一、数据准备与预处理

1.1 数据集介绍

本项目使用的食物图像数据集包含20类不同的食物,如八宝粥、巴旦木、白萝卜等。数据集分为训练集和测试集,分别用于模型的训练和评估。每类食物都有一定数量的图像样本,这些图像涵盖了不同角度、光照条件下的食物外观,为模型的训练提供了丰富的多样性。

1.2 数据路径处理

在代码中,最初有一个用于处理数据路径并生成数据索引文件的函数(虽然被注释掉,但可以理解其功能):

# def train_test_file(root,dir):
#     file_txt=open(dir+'.txt','w')
#     path=os.path.join(root,dir)
#     for roots,directories,files in os.walk(path):
#         if len(directories) !=0:
#             dirs=directories
#         else:
#             now_dir=roots.split('\\')
#             for file in files:
#                 path_1=os.path.join(roots,file)
#                 print(path_1)
#                 file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')
#     file_txt.close()
# root=r'.\食物分类\food_dataset'
# train_dir='train'
# test_dir='test'
# train_test_file(root,train_dir)
# train_test_file(root,test_dir)

该函数通过遍历指定目录下的文件和文件夹,将图像文件的路径及其对应的类别索引写入文本文件中。这样做的目的是为后续的数据加载提供便利,使得数据加载器能够快速获取图像和标签信息。

1.3 数据转换与加载

data_transforms={'train':transforms.Compose([transforms.Resize([256, 256]),transforms.ToTensor(),]),'valid':transforms.Compose([transforms.Resize([256, 256]),transforms.ToTensor(),])
}

这里定义了数据转换操作,对于训练集和验证集(测试集),都将图像尺寸调整为256x256像素,并将图像转换为PyTorch的张量格式。图像尺寸的统一有助于在神经网络中进行高效的计算,而张量格式是PyTorch处理数据的标准形式。

class food_dataset(Dataset):def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image=Image.open(self.imgs[idx])if self.transform:image=self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label,dtype=np.int64))return image,labeltraining_data=food_dataset(file_path='trainda.txt', transform=data_transforms['train'])
test_data=food_dataset(file_path='testda.txt', transform=data_transforms['valid'])

通过自定义food_dataset类继承自Dataset,实现了数据加载的逻辑。在初始化函数中,读取数据索引文件,将图像路径和标签分别存储在列表中。__len__方法返回数据集的样本数量,__getitem__方法根据索引获取图像和对应的标签,并对图像应用转换操作。最后,创建训练集和测试集的实例,为后续的数据加载做好准备。

train_dataloader=DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader=DataLoader(test_data,batch_size=64,shuffle=True)

使用DataLoader将训练集和测试集打包成批次,每个批次包含64个样本,并且在训练过程中对数据进行打乱(shuffle=True),这样可以增加数据的随机性,有助于模型更好地学习数据特征,避免模型过拟合。

二、模型构建

2.1 设备选择

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device") 

这段代码用于检测当前设备是否支持GPU(CUDA)或苹果M系列芯片的GPU(MPS),如果都不支持,则使用CPU进行计算。将模型和数据移动到合适的设备上,可以显著提高训练和推理的速度。

2.2 卷积神经网络定义

class CNN(nn.Module):def __init__(self):   super(CNN,self).__init__()   self.conv1=nn.Sequential(    nn.Conv2d(in_channels=3,   out_channels=16,   kernel_size=5,     stride=1,          padding=2,        ),                     nn.ReLU(),            nn.MaxPool2d(kernel_size=2),        )self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),  nn.ReLU(),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  nn.ReLU())self.out = nn.Linear(64 * 128 * 128, 20)def forward(self,x):    x=self.conv1(x)x=self.conv2(x)x=self.conv3(x)x=x.view(x.size(0),-1)output=self.out(x)return output

定义了一个名为CNN的类,继承自nn.Module。在__init__方法中,构建了网络的结构:

  • 卷积层(nn.Conv2d:通过设置不同的输入通道数、输出通道数、卷积核大小、步长和填充参数,逐步提取图像的特征。例如,conv1层以3通道(彩色图像)作为输入,输出16个特征图,使用5x5的卷积核。
  • 激活函数层(nn.ReLU:为网络引入非线性,增强模型的表达能力。
  • 池化层(nn.MaxPool2d:在conv1之后使用最大池化层,对特征图进行下采样,减少数据量和计算量。
  • 全连接层(nn.Linear:将卷积层和池化层提取的特征进行线性变换,输出20个节点,对应20种食物类别。

forward方法定义了数据在网络中的前向传播路径,确保数据依次经过各层处理,最终输出预测结果。

三、模型训练与评估

3.1 损失函数与优化器

loss_fn=nn.CrossEntropyLoss()   
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)   

选择nn.CrossEntropyLoss作为损失函数,它适用于多分类任务,能够计算模型预测结果与真实标签之间的差距。torch.optim.Adam作为优化器,通过调整模型的参数来最小化损失函数,学习率设置为0.001,这个值需要根据实际情况进行调整,以平衡模型的训练速度和收敛效果。

3.2 训练函数

def train(dataloader,model,loss_fn,optimizer):model.train()   batch_size_num=1for X,y in dataloader:       X,y=X.to(device),y.to(device)    pred=model(X)    loss=loss_fn(pred,y)    optimizer.zero_grad()    loss.backward()    optimizer.step()loss=loss.item()   if batch_size_num %50 ==0:print(f'loss:{loss:>7f} [number:{batch_size_num}]')batch_size_num+=1

在训练函数中:

  • model.train()将模型设置为训练模式,此时模型中的一些层(如Dropout层,如果有使用的话)会按照训练规则工作。
  • 遍历训练数据加载器中的每一个批次数据,将数据和标签移动到指定设备上。
  • 通过模型进行预测,计算损失值。
  • 使用optimizer.zero_grad()清零梯度,loss.backward()进行反向传播计算梯度,optimizer.step()根据梯度更新模型参数。
  • 每隔50个批次,打印当前的损失值,以便观察训练过程中损失的变化趋势,了解模型的训练情况。

3.3 测试函数

def test(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader) model.eval() test_loss,correct=0,0with torch.no_grad():    for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model(X)test_loss+=loss_fn(pred,y).item()   correct+=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)   b=(pred.argmax(1)==y).type(torch.float)result = zip(pred.argmax(1).tolist(), y.tolist())for i in result:print(f"当前测试的结果为:{food_type[i[0]]}\t,当前真实的结果为:{food_type[i[1]]}")test_loss /=num_batchescorrect /=sizeprint(f'Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}')

测试函数中:

  • model.eval()将模型设置为测试模式,关闭一些在训练过程中起作用但在测试时不需要的操作(如Dropout)。
  • 使用with torch.no_grad()上下文管理器,关闭梯度计算,因为在测试阶段不需要更新模型参数,这样可以节省计算资源。
  • 遍历测试数据,计算每个批次的损失值并累加,同时统计预测正确的样本数量。此外,还将预测结果和真实标签进行对应输出,方便直观查看模型的预测效果。
  • 最后计算并打印测试集上的平均损失和准确率,评估模型在未知数据上的性能表现。

3.4 执行训练与测试

epoch=10
for i in range(epoch):print(i + 1)train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)

通过设置训练轮数为10,循环调用训练函数进行模型训练,每完成一轮训练,模型对数据的特征学习就更深入一些。训练完成后,调用测试函数评估模型在测试集上的性能,查看模型的分类准确率和损失情况,判断模型是否达到预期效果。

四、总结与展望

通过以上步骤,我们成功使用PyTorch构建了一个食物图像分类模型,从数据处理、模型构建到训练和评估,完整地展示了深度学习图像分类任务的实现过程。当然,当前模型还有很多可以优化的地方,例如调整网络结构、增加数据增强操作、调整超参数等,以进一步提高模型的性能。未来,我们可以将这种图像分类模型应用到更多实际场景中,如智能厨房设备中的食物识别、饮食健康管理中的食物记录等,为生活带来更多便利和创新。希望本文能为大家在深度学习图像分类领域的学习和实践提供有益的参考。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词