欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 新车 > 以nlp为例,区分BatchNorm、LayerNorm、GroupNorm、RMSNorm

以nlp为例,区分BatchNorm、LayerNorm、GroupNorm、RMSNorm

2025/5/5 17:56:08 来源:https://blog.csdn.net/fanjinglian_/article/details/144264591  浏览:    关键词:以nlp为例,区分BatchNorm、LayerNorm、GroupNorm、RMSNorm

        以nlp中一个小批次数据,详细区分BatchNorm、LayerNorm、GroupNorm、RMSNorm。这几种归一化的不同。如下表格,从计算范围、统计量、计算复杂度以及应用场景等方面的差异给出。

方法计算范围统计量计算复杂度应用场景
BatchNorm跨所有句子的同一维度使用批次统计量O(batch_size * seq_len)适合 CNN,需要较大 batch size
LayerNorm单个 Token 的所有维度使用单 Token 的统计量O(embedding_dim)适合 Transformer,独立于 batch size
GroupNorm单个 Token 的维度组使用组内统计量O(embedding_dim / num_groups)适合小 batch size 场景
RMSNorm单个 Token 的所有维度,但简化计算只使用 RMS 值O(embedding_dim),但运算更简单适合需要高效计算的场景

假设小批量数据是: 

句子1: "我来自加拿大,那是一个美丽的国家" (10个token)
句子2: "你来自加拿大的哪个城市?" (8个token)

假设各个句子的嵌入式数值如下:

import torchdef create_sample_data():"""创建示例数据:两个中文句子的词嵌入句子1: "我/来自/加拿大/,/那是/一个/美丽的/国家"句子2: "你/来自/加拿大/的/哪个/城市/?""""# 创建两个示例句子的嵌入# batch_size=2, max_seq_len=8 (用padding补齐), embedding_dim=10data = torch.tensor([# 句子1的词嵌入 (8个token)[[2.1, -1.5, 0.8, 3.2, -0.4, 1.7, -2.3, 0.5, 1.9, -1.1],  # "我"[1.5, 2.2, -0.7, 1.8, 2.5, -1.2, 1.6, -0.8, 2.0, 1.4],   # "来自"[2.8, -1.9, 1.5, -0.6, 2.1, 1.8, -1.4, 2.2, -0.5, 1.7],  # "加拿大"[0.5, 1.2, -1.8, 2.4, -0.9, 1.5, -2.0, 0.7, 1.6, -1.3],  # ","[1.9, -2.1, 0.6, 1.7, -1.5, 2.3, -0.8, 1.4, -1.9, 2.5],  # "那是"[2.2, 1.6, -1.1, 2.0, -0.3, 1.9, -1.7, 0.9, 2.4, -0.6],  # "一个"[1.7, -1.4, 2.3, -0.5, 1.8, -2.2, 0.8, 1.5, -1.2, 2.1],  # "美丽的"[2.4, -0.8, 1.6, -1.3, 2.7, -0.4, 1.2, -1.8, 2.3, -0.7], # "国家"],# 句子2的词嵌入 (7个token + 1个padding)[[1.8, -2.0, 1.2, 2.8, -0.9, 2.1, -1.8, 0.3, 2.2, -0.7],  # "你"[1.4, 2.5, -0.8, 1.9, 2.3, -1.5, 1.7, -0.6, 2.1, 1.2],   # "来自"[2.6, -1.7, 1.4, -0.5, 2.2, 1.6, -1.3, 2.4, -0.4, 1.8],  # "加拿大"[0.7, 1.3, -1.6, 2.2, -1.0, 1.4, -1.9, 0.8, 1.5, -1.4],  # "的"[2.0, -1.8, 0.9, 1.6, -1.2, 2.5, -0.7, 1.3, -1.7, 2.3],  # "哪个"[2.3, 1.5, -1.2, 1.8, -0.5, 1.7, -1.6, 1.0, 2.2, -0.8],  # "城市"[1.6, -1.3, 2.1, -0.4, 1.9, -2.0, 0.6, 1.7, -1.1, 2.0],  # "?"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],      # padding]], dtype=torch.float32)return data

这段代码生成了一个 2×8×10 的张量,表示 2 个句子,每个句子有 8个 Token,每个 Token 的嵌入维度为 10。下面是各种norm的实现。具体每种的详细讲解请看:

RMSNorm 、GroupNorm、LayerNorm、BatchNorm

提示:关注形状。

 # 创建掩码来处理不同长度的句子mask = torch.zeros(batch_size, max_seq_len)mask[0, :8] = 1  # 第一个句子长度为8mask[1, :7] = 1  # 第二个句子长度为7return embeddings, maskdef batch_norm(x, eps=1e-5):"""BatchNorm实现Args:x: shape [batch_size, seq_len, embedding_dim]eps: 数值稳定性常数Returns:normalized: shape [batch_size, seq_len, embedding_dim]"""# 在batch和seq_len维度上计算均值和方差# mean shape: [1, 1, embedding_dim]mean = x.mean(dim=(0, 1), keepdim=True)# var shape: [1, 1, embedding_dim]var = x.var(dim=(0, 1), unbiased=False, keepdim=True)# 归一化normalized = (x - mean) / torch.sqrt(var + eps)return normalizeddef layer_norm(x, eps=1e-5):"""LayerNorm实现Args:x: shape [batch_size, seq_len, embedding_dim]eps: 数值稳定性常数Returns:normalized: shape [batch_size, seq_len, embedding_dim]"""# 在最后一个维度(embedding_dim)上计算均值和方差# mean shape: [batch_size, seq_len, 1]mean = x.mean(dim=-1, keepdim=True)# var shape: [batch_size, seq_len, 1]var = x.var(dim=-1, unbiased=False, keepdim=True)# 归一化normalized = (x - mean) / torch.sqrt(var + eps)return normalizeddef group_norm(x, num_groups=2, eps=1e-5):"""GroupNorm实现Args:x: shape [batch_size, seq_len, embedding_dim]num_groups: 分组数eps: 数值稳定性常数Returns:normalized: shape [batch_size, seq_len, embedding_dim]"""batch_size, seq_len, embedding_dim = x.shape# 重塑张量以进行分组归一化# 将embedding_dim分成num_groups组x = x.reshape(batch_size, seq_len, num_groups, embedding_dim // num_groups)# 在seq_len和每组内计算均值和方差# mean shape: [batch_size, 1, num_groups, 1]mean = x.mean(dim=(1, 3), keepdim=True)# var shape: [batch_size, 1, num_groups, 1]var = x.var(dim=(1, 3), unbiased=False, keepdim=True)# 归一化normalized = (x - mean) / torch.sqrt(var + eps)# 重塑回原始形状normalized = normalized.reshape(batch_size, seq_len, embedding_dim)return normalizeddef rms_norm(x, eps=1e-5):"""RMSNorm实现Args:x: shape [batch_size, seq_len, embedding_dim]eps: 数值稳定性常数Returns:normalized: shape [batch_size, seq_len, embedding_dim]"""# 计算RMS (Root Mean Square)# rms shape: [batch_size, seq_len, 1]rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)# 归一化 (只除以RMS,不减均值)normalized = x / rmsreturn normalized

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词