🧮 Cross-Attention Formulas Explained
1.📌 定义
Cross-Attention(交叉注意力)常用于 Transformer 解码器、Encoder-Decoder 结构中,用于将一个序列(如目标语言)对另一个序列(如源语言)的表示进行对齐和关注。
2.⚙️ Formula Structure
We have:
• The query vectors come from the Decoder's current input (Q)
• The key and value vectors come from the Encoder's output (K and V); a minimal sketch of this wiring follows below
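As a minimal sketch (not part of the original text, and with illustrative tensor names and sizes), the same Q/K/V wiring can be expressed with PyTorch's built-in nn.MultiheadAttention:

import torch
import torch.nn as nn

# Illustrative tensors: decoder states supply Q, encoder states supply K and V
decoder_states = torch.randn(8, 10, 512)   # [batch, tgt_len, d_model]
encoder_states = torch.randn(8, 20, 512)   # [batch, src_len, d_model]

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
out, attn_weights = mha(query=decoder_states, key=encoder_states, value=encoder_states)
print(out.shape)  # torch.Size([8, 10, 512])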
3.🔢 Formulas
Core computation of Cross-Attention:
Input dimensions:
- Query matrix: $Q \in \mathbb{R}^{T_q \times d}$
- Key matrix: $K \in \mathbb{R}^{T_k \times d}$
- Value matrix: $V \in \mathbb{R}^{T_k \times d}$
Compute the attention weights:

$$\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V$$

- $Q$ is the projected output of the Decoder
- $K$, $V$ are linear transformations of the Encoder output
- $d_k$ is the dimension of the key vectors (usually $d / h$, where $h$ is the number of heads); a sketch of this computation in code follows below
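As an illustrative sketch of this formula (the sizes T_q = 10, T_k = 20 and d_k = 64 are assumptions chosen for the example), the scaled dot-product step can be written directly with tensor operations:

import torch
import torch.nn.functional as F

Q = torch.randn(10, 64)   # queries from the decoder, [T_q, d_k]
K = torch.randn(20, 64)   # keys from the encoder, [T_k, d_k]
V = torch.randn(20, 64)   # values from the encoder, [T_k, d_k]

scores = Q @ K.T / (K.size(-1) ** 0.5)   # [T_q, T_k], scaled by sqrt(d_k)
weights = F.softmax(scores, dim=-1)      # each query gets a distribution over encoder positions
output = weights @ V                     # [T_q, d_k], weighted sum of encoder values
print(output.shape)                      # torch.Size([10, 64])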
Multi-Head Attention:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

where each head is computed as:

$$\text{head}_i = \text{Attention}(Q W^Q_i, K W^K_i, V W^V_i)$$

- $W^Q_i, W^K_i, W^V_i \in \mathbb{R}^{d \times d_h}$ are the per-head projection matrices
- $W^O \in \mathbb{R}^{h d_h \times d}$ is the output projection matrix (a worked dimension example follows below)
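As a concrete illustration (the numbers here are chosen for the example, not taken from the article): with $d = 512$ and $h = 8$ heads, each head works with $d_h = d / h = 64$, so $W^Q_i, W^K_i, W^V_i \in \mathbb{R}^{512 \times 64}$, each $\text{head}_i \in \mathbb{R}^{T_q \times 64}$, their concatenation lies in $\mathbb{R}^{T_q \times 512}$, and $W^O \in \mathbb{R}^{512 \times 512}$ projects it back to the model dimension.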
4.📝 Usage Scenarios
• Encoder-Decoder: every Decoder layer uses Cross-Attention to attend over the Encoder's output
• Multimodal: text Queries attend over image Keys/Values for cross-modal fusion (see the sketch below)
• Retrieval augmentation: Queries are aligned and matched against the Keys/Values of an external knowledge base
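A hedged sketch of the multimodal case (the batch size, sequence lengths, and the 196-patch image grid are assumptions for illustration): text token embeddings act as queries, while image patch features supply the keys and values.

import torch
import torch.nn as nn

text_tokens = torch.randn(4, 32, 512)      # [batch, text_len, d] - queries from the text side
image_patches = torch.randn(4, 196, 512)   # [batch, num_patches, d] - keys/values from a vision encoder

cross_modal_attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
fused, _ = cross_modal_attn(query=text_tokens, key=image_patches, value=image_patches)
print(fused.shape)  # torch.Size([4, 32, 512]) - one image-aware vector per text token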
5. Code Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        query: [batch_size, tgt_len, embed_dim] - from decoder
        key:   [batch_size, src_len, embed_dim] - from encoder
        value: [batch_size, src_len, embed_dim] - from encoder
        mask:  [batch_size, tgt_len, src_len] (optional), 0 marks positions to ignore
        """
        B, T_q, _ = query.size()
        T_k = key.size(1)

        # Project Q, K, V
        Q = self.q_proj(query)   # [B, T_q, embed_dim]
        K = self.k_proj(key)     # [B, T_k, embed_dim]
        V = self.v_proj(value)   # [B, T_k, embed_dim]

        # Split into heads
        Q = Q.view(B, T_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, T_q, head_dim]
        K = K.view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, T_k, head_dim]
        V = V.view(B, T_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, T_k, head_dim]

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, heads, T_q, T_k]
        if mask is not None:
            # Broadcast the mask over the head dimension: [B, 1, T_q, T_k]
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)  # [B, heads, T_q, T_k]
        attn = self.dropout(attn)

        context = torch.matmul(attn, V)  # [B, heads, T_q, head_dim]
        context = context.transpose(1, 2).contiguous().view(B, T_q, self.embed_dim)  # [B, T_q, embed_dim]
        output = self.out_proj(context)  # [B, T_q, embed_dim]
        return output
6. Usage Example
decoder_query = torch.randn(8, 10, 512)       # from the decoder
encoder_key_value = torch.randn(8, 20, 512)   # from the encoder

cross_attn = CrossAttention(embed_dim=512, num_heads=8)
out = cross_attn(decoder_query, encoder_key_value, encoder_key_value)
print(out.shape) # [8, 10, 512]
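As an additional sketch (not in the original), an optional mask can exclude padded encoder positions; the mask uses the [batch_size, tgt_len, src_len] layout described in forward, and masking the last 5 positions is an assumption chosen for illustration.

# Suppose the last 5 encoder positions are padding: a 0 in the mask means "do not attend here"
mask = torch.ones(8, 10, 20)
mask[:, :, 15:] = 0

masked_out = cross_attn(decoder_query, encoder_key_value, encoder_key_value, mask=mask)
print(masked_out.shape)  # [8, 10, 512]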