
Attention Free Transformer (AFT) - 2021 Paper Notes




Name:

Attention Free Transformer (AFT)

Source:

[2105.14103] An Attention Free Transformer

Related work:

Approximating the dot product, sparse/local attention, context compression, eliminating dot-product attention, MLPs for vision

Key ideas:

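The core operation of AFT-full, as defined in the paper and implemented by the code below: given projections $Q$, $K$, $V$ of the input and a learned pairwise position bias $w_{t,t'}$, the output at position $t$ is

$$
Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})}
$$

where $\sigma$ is the sigmoid, $\odot$ and the division are element-wise, and $T$ is the sequence length. Positions interact only through the scalar biases $w_{t,t'}$ and element-wise weighting, so no dot-product attention map is ever formed.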

Contributions:

  • Proposes a completely new alternative to the attention mechanism that does away with dot-product attention entirely.

  • AFT has linear memory complexity with respect to both the input length and the feature dimension, making it well suited to large inputs and large models.

  • The AFT-local and AFT-conv variants introduce locality and spatial weight sharing, further improving the model's efficiency and performance (a sketch of the AFT-local windowing follows the code block below).

Code:

# ---------------------------------------
# Paper: An Attention Free Transformer (arXiv 2021)
# ---------------------------------------
import torch
from torch import nn
from torch.nn import init


class AFT_FULL(nn.Module):

    def __init__(self, d_model, n=49, simple=False):
        super(AFT_FULL, self).__init__()
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model, d_model)
        if simple:
            # "simple" variant: fixed zero position biases (no learned pairwise term)
            self.position_biases = torch.zeros((n, n))
        else:
            # learned pairwise position biases w_{t,t'}
            self.position_biases = nn.Parameter(torch.ones((n, n)))
        self.d_model = d_model
        self.n = n
        self.sigmoid = nn.Sigmoid()

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, input):
        bs, n, dim = input.shape

        q = self.fc_q(input)                      # bs, n, dim
        k = self.fc_k(input).view(1, bs, n, dim)  # 1, bs, n, dim
        v = self.fc_v(input).view(1, bs, n, dim)  # 1, bs, n, dim

        # weighted sum over positions t': exp(K + w) acts as the (non dot-product) weight
        numerator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)) * v, dim=2)  # n, bs, dim
        denominator = torch.sum(torch.exp(k + self.position_biases.view(n, 1, -1, 1)), dim=2)    # n, bs, dim

        out = numerator / denominator                  # n, bs, dim
        out = self.sigmoid(q) * out.permute(1, 0, 2)   # bs, n, dim

        return out


# input: B N C, output: B N C
if __name__ == '__main__':
    block = AFT_FULL(d_model=512, n=64).cuda()
    input = torch.rand(64, 64, 512).cuda()
    output = block(input)
    print(input.size(), output.size())
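As a rough illustration (not part of the reference code above; the helper name and window size are made up for this sketch), the AFT-local variant described in the paper keeps the learned bias w[t, t'] only when |t - t'| < s and uses 0 outside that window. A minimal sketch of that masking:

import torch

def local_position_bias(w, s):
    # AFT-local (sketch): keep the learned bias w[t, t'] only when |t - t'| < s,
    # zero it outside the window. w is an (n, n) tensor such as
    # AFT_FULL.position_biases; s is the window size (illustrative value below).
    n = w.shape[0]
    idx = torch.arange(n, device=w.device)
    mask = (idx[None, :] - idx[:, None]).abs() < s  # (n, n) boolean window mask
    return w * mask

# e.g. w_local = local_position_bias(block.position_biases, s=8)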
