文章目录
- 前言
- 输入表示
- BERTEncoder类实现
- 预训练任务
- 掩蔽语言模型(Masked Language Modeling)
- 下一句预测(Next Sentence Prediction)
- 整合代码
- 总结
前言
在自然语言处理(NLP)的世界里,词嵌入技术是基石。从早期的 Word2Vec、GloVe 等上下文无关(context-independent)模型,到后来能够根据上下文动态调整词表示的 ELMo、GPT 等上下文敏感(context-sensitive)模型,我们见证了 NLP 领域表示学习的飞速发展。
上下文无关模型,如 Word2Vec,为每个词分配一个固定的向量,无法区分多义词。例如,“bank”在“river bank”(河岸)和“investment bank”(投资银行)中会被赋予完全相同的表示。
为了解决这个问题,ELMo 和 GPT 等模型应运而生。ELMo 使用双向 LSTM 来编码上下文,但其下游任务通常需要一个特定于任务的模型架构。而 GPT 使用强大的 Transformer 解码器,是任务无关的,但其自回归的特性使其只能“从左到右”地编码上下文,无法同时利用左右两侧的信息。
2018年,BERT(Bidirectional Encoder Representations from Transformers)横空出世,它集众家之长,革命性地改变了 NLP 领域。BERT 不仅实现了真正的双向上下文编码,而且其任务无关的设计使其能够以最小的架构改动,在众多 NLP 任务中取得顶尖(SOTA)的性能。
本篇文章将以 PyTorch 为工具,深入剖析 BERT 的内部结构和实现细节。我们将从输入表示开始,一步步构建 BERT 编码器,并实现其两个核心的预训练任务:掩蔽语言模型(MLM)和下一句预测(NSP)。让我们通过代码,揭开 BERT 的神秘面纱。
完整代码:下载链接
输入表示
BERT 的一个精妙之处在于其能够灵活处理单个文本和文本对。为了实现这一点,BERT 对输入序列进行了特殊格式化。
- 单个文本输入:格式为
[CLS] 文本序列A [SEP]
。 - 文本对输入:格式为
[CLS] 文本序列A [SEP] 文本序列B [SEP]
。
其中:
[CLS]
:一个特殊的分类标记。它不对应任何真实词元,但其在 BERT 输出中的最终表示被设计为聚合整个输入序列的信息,通常用于分类任务。[SEP]
:一个特殊的分隔标记,用于分隔不同的文本片段。
为了让模型能够区分文本对中的两个句子(例如,在问答任务中区分问题和上下文),BERT 引入了 片段嵌入(Segment Embeddings)。第一个句子的所有词元会加上片段嵌入 A,第二个句子的所有词元会加上片段嵌入 B。
下面的 get_tokens_and_segments
函数清晰地展示了这一过程。它接收一个或两个文本序列(已分词),并返回符合 BERT 格式的词元列表及其对应的片段索引。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BERT模型输入序列处理工具
用于获取输入序列的词元及其对应的片段索引
"""from typing import List, Tuple, Optionaldef get_tokens_and_segments(tokens_a: List[str], tokens_b: Optional[List[str]] = None) -> Tuple[List[str], List[int]]:"""获取输入序列的词元及其片段索引该函数用于处理BERT模型的输入序列,将单个或两个文本序列转换为包含特殊标记的词元列表,并生成对应的片段索引,用于区分不同的输入片段。参数:tokens_a (List[str]): 第一个输入序列的词元列表维度: [seq_len_a] - seq_len_a为第一个序列的长度tokens_b (Optional[List[str]], 可选): 第二个输入序列的词元列表,默认为None维度: [seq_len_b] - seq_len_b为第二个序列的长度返回:Tuple[List[str], List[int]]: 包含两个元素的元组- tokens (List[str]): 处理后的完整词元序列维度: [total_len] - total_len为最终序列总长度单序列时: total_len = seq_len_a + 2 (包含<cls>和<sep>)双序列时: total_len = seq_len_a + seq_len_b + 3 (包含<cls>和两个<sep>)- segments (List[int]): 对应的片段索引列表维度: [total_len] - 与tokens长度相同0表示第一个片段(包括<cls>和第一个<sep>)1表示第二个片段(包括第二个<sep>)示例:>>> tokens_a = ['hello', 'world']>>> tokens, segments = get_tokens_and_segments(tokens_a)>>> print(tokens) # ['<cls>', 'hello', 'world', '<sep>']>>> print(segments) # [0, 0, 0, 0]>>> tokens_b = ['good', 'morning']>>> tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)>>> print(tokens) # ['<cls>', 'hello', 'world', '<sep>', 'good', 'morning', '<sep>']>>> print(segments) # [0, 0, 0, 0, 1, 1, 1]"""# 构建第一个片段的词元序列# tokens维度: [seq_len_a + 2] - 包含<cls>标记 + tokens_a + <sep>标记tokens = ['<cls>'] + tokens_a + ['<sep>']# 构建第一个片段的索引序列,全部标记为0# segments维度: [seq_len_a + 2] - 与tokens长度对应segments = [0] * (len(tokens_a) + 2)# 如果存在第二个输入序列if tokens_b is not None:# 将第二个序列的词元添加到tokens中,并在末尾添加<sep>标记# tokens维度更新为: [seq_len_a + seq_len_b + 3]tokens += tokens_b + ['<sep>']# 为第二个序列生成片段索引,全部标记为1,并添加到segments中# segments维度更新为: [seq_len_a + seq_len_b + 3] - 与tokens长度对应segments += [1] * (len(tokens_b) + 1)return tokens, segments# 调用示例
if __name__ == "__main__":print("=" * 60)print("BERT输入序列处理示例")print("=" * 60)# 示例1: 单个序列处理print("\n【示例1】单个序列处理:")print("-" * 30)tokens_a = ['我', '喜欢', '自然', '语言', '处理']print(f"输入序列A: {tokens_a}")print(f"序列A长度: {len(tokens_a)}")tokens, segments = get_tokens_and_segments(tokens_a)print(f"处理后词元: {tokens}")print(f"片段索引: {segments}")print(f"最终长度: {len(tokens)} (原长度{len(tokens_a)} + 2个特殊标记)")# 示例2: 两个序列处理(句子对分类任务)print("\n【示例2】两个序列处理(句子对任务):")print("-" * 30)tokens_a = ['今天', '天气', '很好']tokens_b = ['适合', '外出', '游玩']print(f"输入序列A: {tokens_a}")print(f"输入序列B: {tokens_b}")print(f"序列A长度: {len(tokens_a)}, 序列B长度: {len(tokens_b)}")tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)print(f"处理后词元: {tokens}")print(f"片段索引: {segments}")print(f"最终长度: {len(tokens)} (A:{len(tokens_a)} + B:{len(tokens_b)} + 3个特殊标记)")# 示例3: 问答任务示例print("\n【示例3】问答任务示例:")print("-" * 30)question = ['什么', '是', 'BERT', '模型']context = ['BERT', '是', '一种', '预训练', '语言', '模型']print(f"问题: {question}")print(f"上下文: {context}")tokens, segments = get_tokens_and_segments(question, context)print(f"处理后词元: {tokens}")print(f"片段索引: {segments}")# 分析片段索引的含义print("\n【片段索引说明】:")print("-" * 30)print("索引0: 问题部分(包括<cls>和第一个<sep>)")print("索引1: 上下文部分(包括第二个<sep>)")# 示例4: 英文示例print("\n【示例4】英文文本处理:")print("-" * 30)tokens_a = ['hello', 'world']tokens_b = ['good', 'morning']print(f"English A: {tokens_a}")print(f"English B: {tokens_b}")tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)print(f"Processed tokens: {tokens}")print(f"Segment indices: {segments}")print("\n" + "=" * 60)print("示例运行完成!")
运行结果:
============================================================
BERT输入序列处理示例
============================================================【示例1】单个序列处理:
------------------------------
输入序列A: ['我', '喜欢', '自然', '语言', '处理']
序列A长度: 5
处理后词元: ['<cls>', '我', '喜欢', '自然', '语言', '处理', '<sep>']
片段索引: [0, 0, 0, 0, 0, 0, 0]
最终长度: 7 (原长度5 + 2个特殊标记)【示例2】两个序列处理(句子对任务):
------------------------------
输入序列A: ['今天', '天气', '很好']
输入序列B: ['适合', '外出', '游玩']
序列A长度: 3, 序列B长度: 3
处理后词元: ['<cls>', '今天', '天气', '很好', '<sep>', '适合', '外出', '游玩', '<sep>']
片段索引: [0, 0, 0, 0, 0, 1, 1, 1, 1]
最终长度: 9 (A:3 + B:3 + 3个特殊标记)【示例3】问答任务示例:
------------------------------
问题: ['什么', '是', 'BERT', '模型']
上下文: ['BERT', '是', '一种', '预训练', '语言', '模型']
处理后词元: ['<cls>', '什么', '是', 'BERT', '模型', '<sep>', 'BERT', '是', '一种', '预训练', '语言', '模型', '<sep>']
片段索引: [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]【片段索引说明】:
------------------------------
索引0: 问题部分(包括<cls>和第一个<sep>)
索引1: 上下文部分(包括第二个<sep>)【示例4】英文文本处理:
------------------------------
English A: ['hello', 'world']
English B: ['good', 'morning']
Processed tokens: ['<cls>', 'hello', 'world', '<sep>', 'good', 'morning', '<sep>']
Segment indices: [0, 0, 0, 0, 1, 1, 1]============================================================
示例运行完成!
除了片段嵌入,BERT 还使用了 位置嵌入(Position Embeddings) 来让模型感知到词元在序列中的顺序。与原始 Transformer 使用固定的正弦/余弦位置编码不同,BERT 使用的是可学习的位置嵌入。
最终,每个输入词元的表示是其 词元嵌入、片段嵌入 和 位置嵌入 三者之和。
BERTEncoder类实现
BERT 的核心是一个多层的双向 Transformer 编码器。接下来,我们将实现这个编码器。我们的 BERTEncoder
类将包含词嵌入层、片段嵌入层、可学习的位置嵌入参数,以及堆叠的多层 EncoderBlock
。
下面的代码块包含了构建 BERTEncoder
所需的所有组件,从底层的缩放点积注意力到完整的多头注意力和编码器块。代码注释非常详细,解释了每个模块的功能、参数和张量维度的变化。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BERT编码器的完整实现
包含缩放点积注意力、多头注意力、编码器块和BERT编码器的完整定义
"""import math
import torch
import torch.nn as nn
from typing import Optionaldef masked_softmax(X: torch.Tensor, valid_lens: Optional[torch.Tensor]) -> torch.Tensor:"""通过在最后一个轴上遮蔽元素来执行softmax操作参数:X (torch.Tensor): 输入张量,维度: [batch_size, seq_len, seq_len] 或 [batch_size*num_heads, seq_len, seq_len]valid_lens (Optional[torch.Tensor]): 有效长度,维度: [batch_size] 或 [batch_size, seq_len] 或 None返回:torch.Tensor: 经过遮蔽的softmax结果,维度与输入X相同"""if valid_lens is None:return nn.functional.softmax(X, dim=-1)else:shape = X.shapeif valid_lens.dim() == 1:valid_lens = torch.repeat_interleave(valid_lens, shape[1])else:valid_lens = valid_lens.reshape(-1)# 在最后的轴上,被遮蔽的元素使用一个非常大的负值替换,从而其softmax输出为0X = X.reshape(-1, shape[-1])for batch_idx, valid_len in enumerate(valid_lens):X[batch_idx, valid_len:] = -1e6return nn.functional.softmax(X, dim=-1).reshape(shape)def transpose_qkv(X: torch.Tensor, num_heads: int) -> torch.Tensor:"""为了多头注意力的并行计算而变换形状参数:X (torch.Tensor): 输入张量,维度: [batch_size, seq_len, num_hiddens]num_heads (int): 注意力头的数量返回:torch.Tensor: 变换后的张量,维度: [batch_size*num_heads, seq_len, num_hiddens/num_heads]"""# 输入X的形状: [batch_size, seq_len, num_hiddens]# 输出形状: [batch_size, seq_len, num_heads, num_hiddens/num_heads]X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出形状: [batch_size, num_heads, seq_len, num_hiddens/num_heads]X = X.permute(0, 2, 1, 3)# 最终输出形状: [batch_size*num_heads, seq_len, num_hiddens/num_heads]return X.reshape(-1, X.shape[2], X.shape[3])def transpose_output(X: torch.Tensor, num_heads: int) -> torch.Tensor:"""逆转transpose_qkv函数的操作参数:X (torch.Tensor): 输入张量,维度: [batch_size*num_heads, seq_len, num_hiddens/num_heads]num_heads (int): 注意力头的数量返回:torch.Tensor: 逆转换后的张量,维度: [batch_size, seq_len, num_hiddens]"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)class PositionWiseFFN(nn.Module):"""基于位置的前馈网络参数:ffn_num_input (int): 输入特征维度ffn_num_hiddens (int): 隐藏层特征维度ffn_num_outputs (int): 输出特征维度"""def __init__(self, ffn_num_input: int, ffn_num_hiddens: int, ffn_num_outputs: int, **kwargs):super