
文章目录
- 1、Global Filter
- 2、代码实现
paper:Global Filter Networks for Image Classification
Code:https://github.com/raoyongming/GFNet
1、Global Filter
自注意力机制和纯 MLP 模型在视觉任务中展现出潜力,但计算复杂度高,难以扩展到高分辨率特征。而局部自注意力机制虽有效,但引入了人为选择和限制感受野。对此,论文首先分析了傅里叶变换,指出其是分析图像频谱信息的重要工具,具有对数线性复杂度,能够高效地处理全局信息。并基于此提出一种 全局滤波器(Global Filter)。
GlobalFilter 的基本思想是利用傅里叶变换将空间特征转换为频率域,学习空间位置的长期依赖关系。使用可学习的全局滤波器对频率域特征进行逐元素乘法,捕获不同位置之间的交互。最后通过傅里叶逆变换将特征映射回空间域。
对于输入X,Global Filter 的实现过程:
- 傅里叶变换:对特征图进行二维傅里叶变换,将空间信息转换为频率域。
- 滤波:使用可学习的全局滤波器对频率域特征进行逐元素乘法,模拟不同频率成分的交互。
- 逆变换:对滤波后的频率域特征进行二维傅里叶逆变换,将特征映射回空间域,即为最终输出。
与现有的滤波器相比,Global Filter 具有以下优势:
- 高效:傅里叶变换和逆变换具有对数线性复杂度,比自注意力和 MLP 更高效。
- 全局信息:能够有效地捕获全局空间信息,避免局部自注意力机制的局限性。
- 灵活性:可通过调整滤波器设计,控制模型对不同频率成分的关注程度。
Global Filter 结构图:

2、代码实现
import torch
import math
from torch import nn
from einops.einops import rearrangeclass GlobalFilter(nn.Module):def __init__(self, dim, h=14, w=8):super().__init__()self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)self.w = wself.h = hdef forward(self, x, spatial_size=None):B, N, C = x.shapeif spatial_size is None:a = b = int(math.sqrt(N))else:a, b = spatial_sizex = x.view(B, a, b, C)x = x.to(torch.float32)x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')weight = torch.view_as_complex(self.complex_weight)x = x * weightx = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')x = x.reshape(B, N, C)return xif __name__ == '__main__':H, W = 14, 14x = torch.randn(4, 384, 14, 14).cuda()x = rearrange(x, 'b c h w -> b (h w) c')model = GlobalFilter(384, h=H, w=H//2 + 1).cuda()out = model(x)out = rearrange(out, 'b (h w) c -> b c h w', h=H, w=W)print(out.shape)