欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 教育 > 高考 > matlab从pytorch中导入LeNet-5网络框架

matlab从pytorch中导入LeNet-5网络框架

2025/11/10 23:56:16 来源:https://blog.csdn.net/xy_optics/article/details/147004039  浏览:    关键词:matlab从pytorch中导入LeNet-5网络框架

文章目录

  • 一、Pytorch的LeNet-5网络准备
  • 二、保存用于导入matlab的model
  • 三、导入matlab
  • 四、用matlab训练这个导入的网络

这里演示从pytorch的LeNet-5网络导入到matlab中进行训练用。

一、Pytorch的LeNet-5网络准备

根据LeNet-5的结构图,我们可以写如下结构

import torch
import torch.nn as nnclass LeNet5(nn.Module):def __init__(self, num_classes=10):super(LeNet5, self).__init__()self.feature_extractor = nn.Sequential(# C1: Conv(1→6), 输出 28x28 → 6x28x28nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),nn.BatchNorm2d(6),nn.ReLU(inplace=True),# S2: MaxPool 2x2, 输出 6x14x14nn.MaxPool2d(kernel_size=2, stride=2),# C3: Conv(6→16), 输出 16x10x10nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),nn.BatchNorm2d(16),nn.ReLU(inplace=True),# S4: MaxPool 2x2, 输出 16x5x5nn.MaxPool2d(kernel_size=2, stride=2),# C5: Conv(16→120), 输出 120x1x1(接近 flatten)nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),nn.BatchNorm2d(120),nn.ReLU(inplace=True))self.classifier = nn.Sequential(nn.Flatten(),  # [batch, 120]nn.Linear(120, 84),nn.BatchNorm1d(84),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(84, num_classes))def forward(self, x):x = self.feature_extractor(x)x = self.classifier(x)return xif __name__ == "__main__":model = LeNet5()model.eval()# 示例输入:MNIST 的图像大小 [1, 1, 28, 28]example_input = torch.randn(1, 1, 28, 28)# Tracingtraced_model = torch.jit.trace(model, example_input)# 保存traced_model.save("traced_lenet5.pt")print("✅ traced_lenet5.pt 已成功保存!")

二、保存用于导入matlab的model

在上面的代码中,我们有几行是产生trace model的,即

在这里插入图片描述

torch.jit.trace() 是 PyTorch 的一种 静态图(Static Graph)转换方法,它会:

  • 运行一次前向传播(forward),记录下所有的张量操作;
  • 然后构建一个不可变的计算图(graph),这个图就是所谓的 trace model

保存这个model后,我们就得到了traced_lenet5.pt这个文件。

三、导入matlab

导入matlab可以通过APPS里的Deep Network Designer,如下图

在这里插入图片描述

然后通过From PyTorch这个地方,导入刚才保存的网络结构

在这里插入图片描述

点开From PyTorch后, 我们可以复制刚才保存的traced_lenet5.pt这个文件的绝对路径用于导入,如下图

在这里插入图片描述

然后,import就会有,如下结果

在这里插入图片描述

然后,点击红色方框那部分,进行一下输入尺寸的修改

在这里插入图片描述

导入的这个网络框架,我们还要在末尾段加入softmax层,这个层在原pytorch框架里没写

在这里插入图片描述

这样,我们就完成了LeNet5从Pytorch里导入到matlab了。接着我们可以通过Analyze按钮分析这个网络,如下图

在这里插入图片描述

没有问题后,我们就可以Export这个网络到工作区了,输出的网络自动命名为net_1。

在这里插入图片描述

四、用matlab训练这个导入的网络

训练的代码如下

% 创建一个图像数据存储对象 `imds`,用于从名为 "DigitsData" 的文件夹中加载图像数据
imds = imageDatastore("DigitsData", ...IncludeSubfolders=true, ...  % 指定在加载数据时包含子文件夹中的图像LabelSource="foldernames");  % 使用子文件夹的名称作为图像的标签(自动分类)% 获取数据集中所有的类别名称(即文件夹名),并将其存储在变量 classNames 中
classNames = categories(imds.Labels);  % 将 imds.Labels%%
% 使用 splitEachLabel 函数将原始图像数据集 imds 随机划分为训练集、验证集和测试集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");% 设置用于网络训练的选项,这里使用的是随机梯度下降动量法(SGDM)
% 最大训练轮数(epoch):训练过程中将整个训练集完整迭代 4 次
% 指定验证数据集,用于在训练过程中评估模型的泛化能力
% 每训练 30 个 mini-batch 执行一次验证评估
% 在训练过程中显示实时图形界面,包括损失值和准确率的变化曲线
% 指定训练期间关注的评估指标为准确率(accuracy)
% 禁止在命令行窗口输出详细训练信息(安静模式)
options = trainingOptions("sgdm", ...  MaxEpochs = 4, ...  ValidationData = imdsValidation, ... ValidationFrequency = 30, ...  Plots = "training-progress", ...  Metrics = "accuracy", ...  Verbose = false); % 使用 trainnet 函数对神经网络进行训练
net = trainnet(imdsTrain, net_1, "crossentropy", options);%%
% 使用 testnet 函数对训练好的神经网络进行验证,并评估其准确率
accuracy = testnet(net, imdsTest, "accuracy");%%
% 对测试集进行批量预测,输出每个图像对应的类别得分(概率)
scores = minibatchpredict(net, imdsTest);% 将得分(scores)转换为类别标签,使用 classNames 映射到原始类名
YTest = scores2label(scores, classNames);% 获取测试集图像的总数量
numTestObservations = numel(imdsTest.Files);% 从测试集中随机选取 9 个样本用于可视化
idx = randi(numTestObservations, 9, 1);% 创建一个新的图形窗口
figure
tiledlayout("flow")  % 使用自动流式布局排列子图(tiled layout)% 遍历 9 张图像,显示图像并在标题中标注预测类别
for i = 1:9nexttile  % 在下一个网格位置准备绘图img = readimage(imdsTest, idx(i));  % 读取第 idx(i) 张图像imshow(img)  % 显示图像title("Predicted Class: " + string(YTest(idx(i))))  % 设置标题,显示预测类别
end

上面用到的数据集是0-9的数字图片,如下图

在这里插入图片描述

训练的详细信息如下

在这里插入图片描述

预测结果显示

在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词