欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 建筑 > 零基础入门机器学习 -- 第十章深度学习入门

零基础入门机器学习 -- 第十章深度学习入门

2025/5/1 12:35:59 来源:https://blog.csdn.net/qq_41611586/article/details/145646475  浏览:    关键词:零基础入门机器学习 -- 第十章深度学习入门

学习内容

  • 为什么要学习深度学习?
  • 什么是神经网络?它如何像人类一样学习?
  • 怎么做?手把手训练一个AI,学会识别手写数字!
  • 如何验证?训练完成后,我们如何知道模型学得好不好?

🧠 你能让计算机学会“看”吗?

假设你是一名AI导师,你的任务是教会计算机认识手写数字。你希望:

  1. 你的AI能看到 “5” ,并说出:“我确信这就是5!”。
  2. 你的AI能看到 “3” ,并说出:“这是3,虽然有点歪,但我还是能认出来!”。

💡 问题来了

  • 普通计算机只能执行规则,但不能自己总结规律。
  • 人类学习是靠经验,看多了就能区分5和3。
  • 我们该如何让计算机像人类一样学习?

🎯 解决方案:深度学习

深度学习的核心思想是:

  • 让计算机自己学习,而不是编写规则
  • 用神经网络模拟人脑,让计算机像人一样思考

那么,让我们正式进入这个AI 训练任务,一起训练出一位 “AI 学生” ! 🚀


第一步:让计算机“看”——认识 MNIST 数据集

📌 什么是 MNIST?

MNIST 是手写数字数据集,包含 60,000 张手写数字(0~9)的图片。

为什么用它?
简单:适合初学者
数据标准化:所有图片都是 28×28 像素,格式一致
广泛使用:AI 研究人员都用它做基准测试

🖼️ 让我们看看数据

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt# 加载 MNIST 数据集
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 显示前 5 张手写数字
plt.figure(figsize=(10, 4))
for i in range(5):plt.subplot(1, 5, i+1)plt.imshow(train_images[i], cmap="gray")plt.title(f"Label: {train_labels[i]}")plt.axis("off")
plt.show()

示例输出:
在这里插入图片描述

📌 你看到的是:

  • 这些是 AI 需要学习的“教材”
  • 每张图片是 28×28 的像素矩阵
  • 计算机不会“看图”,但可以“看数据”。

💡 计算机眼中的“5”

[  0   0   0  ...   0   0   0 ]
[  0   0  10  ...   0   0   0 ]
[  0   0  80  ...  20   0   0 ]
...

每个像素值(0~255)表示亮度:

  • 0 = 黑色
  • 255 = 白色
  • 中间值 = 灰色

第二步:让计算机“思考”——构建神经网络

📌 什么是神经网络?

神经网络的目标是:

  • 让计算机从 MNIST 数据中学习。
  • 逐步学会区分 0-9 之间的数字。

比喻:

  • 神经网络 = 人脑
  • 神经元 = 人脑中的神经细胞
  • 每一层 = 人的大脑皮层(感知、分析、决策)

🛠️ 代码:搭建神经网络

model = keras.Sequential([keras.layers.Flatten(input_shape=(28, 28)),  # 1. 输入层keras.layers.Dense(128, activation='relu'),  # 2. 隐藏层keras.layers.Dense(10, activation='softmax') # 3. 输出层
])

💡 逐行解析:
1️⃣ Flatten(input_shape=(28,28)):把 28×28 的图像“摊平”1D 向量(784 维)。
2️⃣ Dense(128, activation='relu'):创建 128 个神经元,用于学习图片特征。
3️⃣ Dense(10, activation='softmax'):输出 10 个类别(0-9) 的概率。


第三步:让计算机“学习”——训练神经网络

📌 为什么要归一化?

train_images = train_images / 255.0
test_images = test_images / 255.0

💡 归一化的好处

  • 让数值在 0-1 之间,训练更快更稳定。
  • 让所有像素值有 相同的计算权重

🛠️ 代码:训练神经网络

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])model.fit(train_images, train_labels, epochs=5)

💡 代码解析
1️⃣ optimizer='adam':让模型更快优化
2️⃣ loss='sparse_categorical_crossentropy':衡量模型预测的好坏
3️⃣ epochs=5:让模型学习5轮

📌 训练时,计算机会自动调整“神经元连接的强度”,从而不断提高识别准确率


第四步:让计算机“考试”——测试模型

test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"测试准确率: {test_acc}")

示例输出:

测试准确率: 0.977400004863739

💡 测试的意义

  • 训练数据 ≠ 真实世界
  • 我们要在 测试数据(AI从没见过的图片) 上评估它的能力。

📌 如果测试准确率高(>90%),说明AI真的学会了!


第五步:让计算机“认字”——模型预测

import numpy as np
predictions = model.predict(test_images)# 显示 AI 预测的前 5 张图片
plt.figure(figsize=(10, 4))
for i in range(5):plt.subplot(1, 5, i+1)plt.imshow(test_images[i], cmap="gray")#用于返回 预测结果的索引(也就是 AI 认为的数字)plt.title(f"预测: {np.argmax(predictions[i])}")plt.axis("off")
plt.show()

示例输出:
在这里插入图片描述

💡 如果 AI 预测正确

  • 恭喜!你的 AI 真的学会了识别手写数字!
  • 你已经完成了第一个深度学习项目! 🎉

🎯 结果分析:训练效果如何?

当我们训练神经网络时,模型会随着训练轮次(epochs)不断学习,但怎么知道它是否真的在变聪明呢?🤔 这时候我们就需要训练曲线(Training Curve),来观察模型的学习进度。

import matplotlib.pyplot as plthistory = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='测试准确率')
plt.xlabel('训练轮次')
plt.ylabel('准确率')
plt.legend()
plt.show()

示例输出:
在这里插入图片描述

📌 代码解析
1️⃣ history = model.fit(…)
model.fit(…) 的返回值 history 记录了每一轮训练的准确率和损失。
2️⃣ history.history[‘accuracy’]
训练准确率,表示模型在训练集上的表现。
如果曲线不断上升,说明模型在学习。
3️⃣ history.history[‘val_accuracy’]
测试准确率,表示模型在未见过的测试集上的表现。
如果 训练准确率高,但测试准确率低,说明模型过拟合了。

📌 如果曲线平稳上升,说明模型正在学习!


🎯 你现在理解了吗?

这篇文章通过故事主线

  1. 为什么要训练 AI?(让计算机识别手写数字)
  2. 是什么?(神经网络如何学习)
  3. 怎么做?(构建、训练、测试 AI)
  4. 如何分析结果?(测试和可视化)

🚀 现在,你能自己训练一个 AI 了吗? 如果有问题,欢迎继续提问! 😊

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词