欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > IT业 > Cross-Attention:注意力机制详解-代码实现和公式详解《二》

Cross-Attention:注意力机制详解-代码实现和公式详解《二》

2025/8/10 15:55:48 来源:https://blog.csdn.net/guoguozgw/article/details/148432179  浏览:    关键词:Cross-Attention:注意力机制详解-代码实现和公式详解《二》

🧮 Cross-Attention 公式详解

1.📌 定义

Cross-Attention(交叉注意力)常用于 Transformer 解码器、Encoder-Decoder 结构中,用于将一个序列(如目标语言)对另一个序列(如源语言)的表示进行对齐和关注。

2.⚙️ 公式结构

我们有:
• 查询向量来自 Decoder 的当前输入(Q)
• 键和值向量来自 Encoder 的输出(K 和 V)

3.🔢 公式

Cross-Attention 核心计算过程

  1. 输入维度:

    • 查询矩阵: Q ∈ R T q × d Q \in \mathbb{R}^{T_q \times d} QRTq×d
    • 键矩阵: K ∈ R T k × d K \in \mathbb{R}^{T_k \times d} KRTk×d
    • 值矩阵: V ∈ R T k × d V \in \mathbb{R}^{T_k \times d} VRTk×d
  2. 计算注意力权重:

Attention ( Q , K , V ) = softmax ( Q K ⊤ d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dk QK)V

  • Q Q Q 是 Decoder 的投影输出
  • K , V K, V K,V 是 Encoder 的输出进行线性变换后的结果
  • d k d_k dk 是键向量的维度(通常是 d / h d / h d/h h h h 是头数)
  1. 多头注意力(Multi-Head Attention):

MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中每个 head 的计算:

head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i) headi=Attention(QWiQ,KWiK,VWiV)

  • W i Q , W i K , W i V ∈ R d × d h W^Q_i, W^K_i, W^V_i \in \mathbb{R}^{d \times d_h} WiQ,WiK,WiVRd×dh 是每个头独立的投影矩阵
  • W O ∈ R h d h × d W^O \in \mathbb{R}^{hd_h \times d} WORhdh×d 是输出投影矩阵

4.📝 使用场景说明

•	Encoder-Decoder:Decoder 的每一层利用 Cross-Attention 获取对 Encoder 输出的关注
•	多模态:文本 Query + 图像 Key/Value,实现跨模态融合
•	检索增强:Query 结合外部知识库的 Key/Value 进行对齐匹配

5.代码实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass CrossAttention(nn.Module):def __init__(self, embed_dim, num_heads, dropout=0.1):super(CrossAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsassert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"# Linear projection for Q, K, Vself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)# Output projectionself.out_proj = nn.Linear(embed_dim, embed_dim)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, mask=None):"""query: [batch_size, tgt_len, embed_dim] - from decoderkey:   [batch_size, src_len, embed_dim] - from encodervalue: [batch_size, src_len, embed_dim] - from encodermask:  [batch_size, tgt_len, src_len] (optional)"""B, T_q, _ = query.size()T_k = key.size(1)# Project Q, K, VQ = self.q_proj(query)  # [B, T_q, embed_dim]K = self.k_proj(key)    # [B, T_k, embed_dim]V = self.v_proj(value)  # [B, T_k, embed_dim]# Split into headsQ = Q.view(B, T_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, T_q, head_dim]K = K.view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, T_k, head_dim]V = V.view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, T_k, head_dim]# Scaled Dot-Product Attentionscores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, heads, T_q, T_k]if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attn = F.softmax(scores, dim=-1)  # [B, heads, T_q, T_k]attn = self.dropout(attn)context = torch.matmul(attn, V)  # [B, heads, T_q, head_dim]context = context.transpose(1, 2).contiguous().view(B, T_q, self.embed_dim)  # [B, T_q, embed_dim]output = self.out_proj(context)  # [B, T_q, embed_dim]return output

6.使用实例

decoder_query = torch.randn(8, 10, 512)  # 来自 decoder
encoder_key_value = torch.randn(8, 20, 512)  # 来自 encodercross_attn = CrossAttention(embed_dim=512, num_heads=8)
out = cross_attn(decoder_query, encoder_key_value, encoder_key_value)
print(out.shape)  # [8, 10, 512]

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词