文章目录
- 前言
- MATLAB环境配置
- 一、MATLAB 神经网络工具箱概述
- 二、核心功能与 API
- 1. 网络创建与训练
- 2. 数据处理
- 3. 训练与评估
- 4. 可视化
- 三、典型应用场景
- 四、实战案例:手写数字识别(MNIST)
- 五、高级技巧
- 迁移学习:
- 超参数优化:
- 模型解释:
- GPU 加速:
前言
以下是关于 MATLAB 神经网络的系统总结,涵盖核心功能、应用场景及典型案例:
MATLAB环境配置
MATLAB下载安装教程:https://blog.csdn.net/tyatyatya/article/details/147879353
一、MATLAB 神经网络工具箱概述
MATLAB 提供了全面的神经网络工具,支持从基础网络到深度学习的各类模型,主要包括:
- 基础神经网络:前馈网络、径向基函数网络、递归网络等。
- 深度学习:卷积神经网络 (CNN)、循环神经网络 (RNN)、LSTM、Transformer 等。
- 预训练模型:AlexNet、ResNet、VGG 等,支持迁移学习。
- 可视化工具:网络结构可视化、训练过程监控、决策边界绘制。
- 部署功能:模型导出为 C/C++、Python、TensorFlow 格式,或部署到 GPU / 嵌入式设备。
二、核心功能与 API
1. 网络创建与训练
% 创建前馈神经网络(分类任务)
net = patternnet(hiddenSizes); % hiddenSizes为隐含层神经元数量% 创建CNN(图像分类)
layers = [imageInputLayer([224 224 3])convolution2dLayer(3, 16, 'Padding', 'same')reluLayermaxPooling2dLayer(2, 'Stride', 2)fullyConnectedLayer(3)softmaxLayerclassificationLayer
];
2. 数据处理
% 数据划分
net.divideFcn = 'dividerand'; % 随机划分
net.divideParam.trainRatio = 0.7;
net.divideParam.valRatio = 0.15;
net.divideParam.testRatio = 0.15;% 归一化
[X_norm, ps] = mapminmax(X); % 将数据归一化到[-1,1]
3. 训练与评估
% 训练网络
[net, tr] = train(net, X, T);% 评估性能
Y = net(X);
accuracy = mean(round(Y) == T); % 分类准确率
mse = perform(net, T, Y); % 均方误差
4. 可视化
view(net) % 可视化网络结构
plotperform(tr) % 绘制训练性能曲线
三、典型应用场景
四、实战案例:手写数字识别(MNIST)
% 加载数据
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...'nndatasets', 'DigitDataset');
digitData = imageDatastore(digitDatasetPath, ...'IncludeSubfolders', true, ...'LabelSource', 'foldernames');% 划分训练集和测试集
[imdsTrain, imdsTest] = splitEachLabel(digitData, 0.8, 'randomized');% 创建简单CNN
layers = [imageInputLayer([28 28 1])convolution2dLayer(5, 20)reluLayermaxPooling2dLayer(2, 'Stride', 2)convolution2dLayer(5, 50)reluLayermaxPooling2dLayer(2, 'Stride', 2)fullyConnectedLayer(500)reluLayerdropoutLayer(0.5)fullyConnectedLayer(10)softmaxLayerclassificationLayer
];% 设置训练参数
options = trainingOptions('sgdm', ...'InitialLearnRate', 0.001, ...'MaxEpochs', 10, ...'MiniBatchSize', 128, ...'Shuffle', 'every-epoch', ...'ValidationData', imdsTest, ...'ValidationFrequency', 30, ...'Verbose', false, ...'Plots', 'training-progress');% 训练网络
net = trainNetwork(imdsTrain, layers, options);% 评估性能
YPred = classify(net, imdsTest);
YTest = imdsTest.Labels;
accuracy = mean(YPred == YTest);
fprintf('测试集准确率: %.2f%%\n', accuracy*100);% 可视化预测结果
figure
idx = randperm(numel(YTest), 16);
for i = 1:16subplot(4,4,i)I = readimage(imdsTest, idx(i));imshow(I)title(sprintf('预测: %d', YPred(idx(i))));
end
五、高级技巧
迁移学习:
% 使用预训练ResNet-50
net = resnet50;
lgraph = layerGraph(net);
% 修改最后几层适应新任务
超参数优化:
% 使用hyperparameterOptimization
results = hyperparameterOptimization(fun, params, opts);
模型解释:
% 使用Deep Network Analyzer
analyzeNetwork(net);
GPU 加速:
% 设置GPU训练
options = trainingOptions('sgdm', 'ExecutionEnvironment', 'gpu');