欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 金融 > 从代码学习深度学习 - 预训练word2vec PyTorch版

从代码学习深度学习 - 预训练word2vec PyTorch版

2025/5/21 6:45:08 来源:https://blog.csdn.net/weixin_43887510/article/details/148098256  浏览:    关键词:从代码学习深度学习 - 预训练word2vec PyTorch版

文章目录

  • 前言
  • 辅助工具
    • 1. 绘图工具 (`utils_for_huitu.py`)
    • 2. 数据处理工具 (`utils_for_data.py`)
    • 3. 训练辅助工具 (`utils_for_train.py`)
  • 预训练 Word2Vec - 主流程
    • 1. 环境设置与数据加载
    • 2. 跳元模型 (Skip-gram Model)
      • 2.1. 嵌入层 (Embedding Layer)
      • 2.2. 定义前向传播
    • 3. 训练
      • 3.1. 二元交叉熵损失
      • 3.2. 初始化模型参数
      • 3.3. 定义训练阶段代码
      • 3.4. 开始训练
    • 4. 应用词嵌入
  • 总结


前言

词嵌入(Word Embeddings)是自然语言处理(NLP)领域中的基石技术之一。它们将词语从稀疏的、高维的独热编码(one-hot encoding)表示转换为稠密的、低维的向量表示。这些向量能够捕捉词语之间的语义和句法关系,使得相似的词在向量空间中距离更近。Word2Vec是其中一种非常流行且有效的词嵌入算法,由Google的Tomas Mikolov等人在2013年提出。它主要包含两种模型架构:CBOW(Continuous Bag-of-Words,连续词袋模型)和Skip-gram(跳字模型)。

本篇博客将聚焦于Skip-gram模型,并结合**负采样(Negative Sampling)**这一重要的优化技巧,通过PyTorch框架从零开始实现一个Word2Vec模型。我们将详细探讨数据预处理的每一个步骤,如何构建模型,如何进行训练,以及训练完成后如何应用得到的词向量来寻找相似词。通过深入代码细节,我们希望能帮助读者更好地理解Word2Vec的内部工作原理及其在PyTorch中的实现。

我们将依赖一系列辅助脚本来处理数据、可视化训练过程以及进行模型训练。让我们一步步揭开Word2Vec的神秘面纱。

完整代码:下载链接

辅助工具

在构建和训练Word2Vec模型之前,我们首先介绍一下项目中用到的一些辅助Python脚本。这些脚本提供了数据加载、预处理、可视化以及训练监控等常用功能。

1. 绘图工具 (utils_for_huitu.py)

这个脚本主要封装了使用matplotlib进行绘图的常用函数,特别是在Jupyter Notebook环境中,它包含了一个Animator类,可以动态地展示训练过程中的损失变化。

# 导入必要的包
import matplotlib.pyplot as plt  # 用于创建和操作 Matplotlib 图表
from matplotlib_inline import backend_inline  # 用于在Jupyter中设置Matplotlib输出格式
from IPython import display  # 用于后续动态显示(如 Animator)
import torch  # 导入PyTorch库,用于处理张量类型的图像
import numpy as np  # 导入NumPy,可能用于数据处理
import matplotlib as mpl  # 导入Matplotlib主模块,用于设置图像属性def set_figsize(figsize=(3.5, 2.5)):"""设置matplotlib图形的大小参数:figsize: tuple[float, float] - 图形大小,形状为 (宽度, 高度),单位为英寸输出:无返回值"""plt.rcParams['figure.figsize'] = figsize  # 设置图形默认大小def use_svg_display():"""使用 SVG 格式在 Jupyter 中显示绘图输入:无输出:无返回值"""backend_inline.set_matplotlib_formats('svg')  # 设置 Matplotlib 使用 SVG 格式def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""设置 Matplotlib 的轴  输入:axes: Matplotlib 的轴对象  # 输入参数:轴对象xlabel: x 轴标签  # 输入参数:x 轴标签ylabel: y 轴标签  # 输入参数:y 轴标签xlim: x 轴范围  # 输入参数:x 轴范围ylim: y 轴范围  # 输入参数:y 轴范围xscale: x 轴刻度类型  # 输入参数:x 轴刻度类型yscale: y 轴刻度类型  # 输入参数:y 轴刻度类型legend: 图例标签列表  # 输入参数:图例标签输出:无返回值  # 函数无显式返回值"""axes.set_xlabel(xlabel)  # 设置 x 轴标签axes.set_ylabel(ylabel)  # 设置 y 轴标签axes.set_xscale(xscale)  # 设置 x 轴刻度类型axes.set_yscale(yscale)  # 设置 y 轴刻度类型axes.set_xlim(xlim)  # 设置 x 轴范围axes.set_ylim(ylim)  # 设置 y 轴范围if legend:  # 检查是否提供了图例标签axes.legend(legend)  # 如果有图例,则设置图例axes.grid()  # 为轴添加网格线class Animator:"""在动画中绘制数据,仅针对一张图的情况"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):"""初始化 Animator 类 输入:xlabel: x 轴标签,默认为 None  # 输入参数:x 轴标签ylabel: y 轴标签,默认为 None  # 输入参数:y 轴标签legend: 图例标签列表,默认为 None  # 输入参数:图例标签xlim: x 轴范围,默认为 None  # 输入参数:x 轴范围ylim: y 轴范围,默认为 None  # 输入参数:y 轴范围xscale: x 轴刻度类型,默认为 'linear'  # 输入参数:x 轴刻度类型yscale: y 轴刻度类型,默认为 'linear'  # 输入参数:y 轴刻度类型fmts: 绘图格式元组,默认为 ('-', 'm--', 'g-.', 'r:')  # 输入参数:线条格式nrows: 子图行数,默认为 1  # 输入参数:子图行数ncols: 子图列数,默认为 1  # 输入参数:子图列数figsize: 图像大小元组,默认为 (3.5, 2.5)  # 输入参数:图像大小输出:无返回值  # 方法无显式返回值定义位置::numref:`sec_softmax_scratch`  # 指明定义的参考位置"""if legend is None:  # 检查 legend 是否为 Nonelegend = []  # 如果为 None,则初始化为空列表use_svg_display()  # 设置绘图显示为 SVG 格式self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)  # 创建绘图对象和子图if nrows * ncols == 1:  # 判断是否只有一个子图self.axes = [self.axes, ]  # 如果是单个子图,将 axes 转为列表self.config_axes = lambda: set_axes(  # 定义 lambda 函数配置坐标轴self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)  # 调用 set_axes 设置参数self.X, self.Y, self.fmts = None, None, fmts  # 初始化数据和格式属性def add(self, x, y):"""向图表中添加多个数据点  输入:x: x 轴数据点  # 输入参数:x 轴数据y: y 轴数据点  # 输入参数:y 轴数据输出:无返回值  # 方法无显式返回值"""if not hasattr(y, "__len__"):  # 检查 y 是否具有长度属性(是否可迭代)y = [y]  # 如果不可迭代,将 y 转为单元素列表n = len(y)  # 获取 y 的长度if not hasattr(x, "__len__"):  # 检查 x 是否具有长度属性x = [x] * n  # 如果不可迭代,将 x 扩展为与 y 同长度的列表if not self.X:  # 检查 self.X 是否已初始化self.X = [[] for _ in range(n)]  # 如果未初始化,为每条线创建空列表if not self.Y:  # 检查 self.Y 是否已初始化self.Y = [[] for _ in range(n)]  # 如果未初始化,为每条线创建空列表for i, (a, b) in enumerate(zip(x, y)):  # 遍历 x 和 y 的数据对if a is not None and b is not None:  # 检查数据点是否有效self.X[i].append(a)  # 将 x 数据点添加到对应列表self.Y[i].append(b)  # 将 y 数据点添加到对应列表self.axes[0].cla()  # 清除当前轴的内容for x, y, fmt in zip(self.X, self.Y, self.fmts):  # 遍历所有数据和格式self.axes[0].plot(x, y, fmt)  # 绘制每条线self.config_axes()  # 调用 lambda 函数配置坐标轴display.display(self.fig)  # 显示当前图形display.clear_output(wait=True)  # 标记当前输出为待清除,但由于 wait=True,它不会立即清除,而是等待下一次 display.display()。def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):"""绘制列表长度对的直方图,用于比较两组列表中元素长度的分布参数:legend: list[str] - 图例标签,形状为 (2,),分别对应xlist和ylist的标签xlabel: str - x轴标签ylabel: str - y轴标签xlist: list[list] - 第一组列表,形状为 (样本数量, 每个样本的元素数)ylist: list[list] - 第二组列表,形状为 (样本数量, 每个样本的元素数)输出:无返回值,但会显示生成的直方图"""set_figsize()  # 设置图形大小# plt.hist返回的三个值:# n: list[array] - 每个bin中的样本数量,形状为 (2, bin数量)# bins: array - bin的边界值,形状为 (bin数量+1,)# patches: list[list[Rectangle]] - 直方图的矩形对象,形状为 (2, bin数量)_, _, patches = plt.hist([[len(l) for l in xlist], [len(l) for l in ylist]])  # 绘制两组数据长度的直方图plt.xlabel(xlabel)  # 设置x轴标签plt.ylabel(ylabel)  # 设置y轴标签# 为第二组数据(ylist)的直方图添加斜线图案,以区分两组数据for patch in patches[1].patches:  # patches[1]是ylist对应的矩形对象列表patch.set_hatch('/')  # 设置填充图案为斜线plt.legend(legend)  # 添加图例

解读

  • set_figsizeuse_svg_display 用于基础的Matplotlib绘图设置。
  • set_axes 是一个通用的函数,用于配置图表的坐标轴标签、范围、刻度类型和图例。
  • Animator 类是实现动态绘图的关键。在训练循环中,我们可以周期性地调用其add方法,传入当前的训练轮次(或迭代次数)和对应的损失值(或其他指标)。Animator会清除旧的图像并重新绘制,从而在Jupyter Notebook中形成动画效果,直观地展示训练趋势。
  • show_list_len_pair_hist 函数用于绘制两个列表集合中,各子列表长度分布的直方图,方便进行数据分析和比较。

2. 数据处理工具 (utils_for_data.py)

这个脚本是Word2Vec数据预处理的核心,包含了从读取原始文本、构建词汇表、下采样、生成中心词-上下文词对、负采样到最终打包成PyTorch DataLoader的完整流程。

from collections import Counter  # 导入 Counter 类
from collections import Counter  # 用于词频统计
import torch  # PyTorch 核心库
from torch.utils import data  # PyTorch 数据加载工具
import numpy as np  # NumPy 用于数组操作
import random  # 导入随机模块,用于下采样和负采样
import math  # 导入数学函数模块,用于概率计算
import osdef count_corpus(tokens):"""统计词元的频率参数:tokens: 词元列表,可以是:- 一维列表,例如 ['a', 'b']- 二维列表,例如 [['a', 'b'], ['c']]返回值:Counter: Counter 对象,统计每个词元的出现次数"""# 如果输入为空列表,直接返回空计数器if not tokens:  # 等价于 len(tokens) == 0return Counter()# 检查输入是否为二维列表if isinstance(tokens[0], list):# 将二维列表展平为一维列表flattened_tokens = [token for sublist in tokens for token in sublist]else:# 如果是一维列表,直接使用原列表flattened_tokens = tokens# 使用 Counter 统计词频并返回return Counter(flattened_tokens)class Vocab:"""文本词表类,用于管理词元及其索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):"""初始化词表Args:tokens: 输入的词元列表,可以是1D或2D列表,默认为空列表min_freq: 词元最小出现频率,小于此频率的词元将被忽略,默认为0reserved_tokens: 预留的特殊词元列表(如'<pad>'),默认为空列表"""# 处理默认参数self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []# 统计词元频率并按频率降序排序counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化词表,'<unk>'为未知词元,索引为0self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 添加满足最小频率要求的词元到词表for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = 

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词