Determining the frequency components #40

Open
yueshuheng opened this issue Nov 28, 2022 · 4 comments

Comments

@yueshuheng

I'd like to ask: how were these frequency components determined?

@Frank-Star-fn

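The authors propose a two-step criterion for selecting the frequency components in the MCA module. The main idea is:

Step 1: compute the channel-attention result for each frequency component separately;
Step 2: based on those results, select the Top-k best-performing frequency components.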
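As a rough illustration of the two steps, here is a minimal plain-Python sketch. It assumes each candidate frequency is identified by its 2D DCT index (u, v) and that a scoring function is available (the hypothetical `evaluate` argument, e.g. the validation accuracy of a network that uses only that one frequency). The 7x7 candidate grid and `dummy_score` are placeholders, not results from the paper.

```python
from typing import Callable, Dict, List, Tuple

Freq = Tuple[int, int]  # (u, v) index of a 2D DCT frequency component


def two_step_select(candidates: List[Freq],
                    evaluate: Callable[[Freq], float],
                    k: int) -> List[Freq]:
    """Hypothetical helper: score each frequency on its own, then keep the top k."""
    # Step 1: evaluate each single-frequency channel attention separately
    scores: Dict[Freq, float] = {f: evaluate(f) for f in candidates}
    # Step 2: rank best-first (sorted is stable, so ties keep candidate order)
    ranked = sorted(candidates, key=lambda f: scores[f], reverse=True)
    return ranked[:k]


def dummy_score(f: Freq) -> float:
    # Placeholder only: pretend lower frequencies perform better
    return -float(f[0] + f[1])


grid = [(u, v) for u in range(7) for v in range(7)]
print(two_step_select(grid, dummy_score, k=3))  # → [(0, 0), (0, 1), (1, 0)]
```

In practice step 1 is the expensive part, since every candidate needs its own evaluation; step 2 is just a sort.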

You can refer to this article:
https://zhuanlan.zhihu.com/p/339215696
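For reference, each candidate "frequency component" corresponds to a 2D DCT basis function. The short pure-Python check below copies the `build_filter` formula from `MultiSpectralDCTLayer` (quoted later in this thread) and verifies two properties: the 1D basis vectors are orthonormal, and the (0, 0) component reduces to global average pooling scaled by sqrt(H*W). The 7x7 size and the toy feature map are arbitrary choices for illustration.

```python
import math


def build_filter(pos: int, freq: int, POS: int) -> float:
    # Same formula as MultiSpectralDCTLayer.build_filter (orthonormal 1D DCT-II basis)
    result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
    return result if freq == 0 else result * math.sqrt(2)


def inner(f1: int, f2: int, n: int) -> float:
    # Dot product of two 1D basis vectors of length n
    return sum(build_filter(p, f1, n) * build_filter(p, f2, n) for p in range(n))


n = 7
# The Gram matrix should be the identity: diagonal 1, off-diagonal 0
orthonormal = all(
    abs(inner(a, b, n) - (1.0 if a == b else 0.0)) < 1e-9
    for a in range(n) for b in range(n)
)
print(orthonormal)  # True

# At (u, v) = (0, 0) every 2D filter value is 1/sqrt(H*W), so the layer's
# sum(x * weight) equals global average pooling times sqrt(H*W).
H, W = 7, 7
x = [[float(i * W + j) for j in range(W)] for i in range(H)]  # arbitrary toy feature map
dc = sum(x[i][j] * build_filter(i, 0, H) * build_filter(j, 0, W)
         for i in range(H) for j in range(W))
gap = sum(map(sum, x)) / (H * W)
print(math.isclose(dc, gap * math.sqrt(H * W)))  # True
```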

@Zhongrocky

Hi, this work is amazing! However, while debugging I found that the MultiSpectralDCTLayer module appears to be empty. What could be the reason?

@cfzd
Owner

cfzd commented Jun 6, 2023

@Zhongrocky
It is not empty; it just doesn't have any learnable parameters. The implementation of MultiSpectralDCTLayer is here:

FcaNet/model/layer.py, lines 65 to 117 at commit aa5fb63:

# Imports needed to run this excerpt on its own
import math

import torch
import torch.nn as nn


class MultiSpectralDCTLayer(nn.Module):
    """
    Generate dct filters
    """
    def __init__(self, height, width, mapper_x, mapper_y, channel):
        super(MultiSpectralDCTLayer, self).__init__()

        assert len(mapper_x) == len(mapper_y)
        assert channel % len(mapper_x) == 0

        self.num_freq = len(mapper_x)

        # fixed DCT init (a buffer: saved in state_dict, but not a learnable parameter)
        self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))

        # fixed random init
        # self.register_buffer('weight', torch.rand(channel, height, width))

        # learnable DCT init
        # (register_parameter expects an nn.Parameter, so the tensor would need wrapping in nn.Parameter(...))
        # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))

        # learnable random init
        # self.register_parameter('weight', torch.rand(channel, height, width))

        # num_freq, h, w

    def forward(self, x):
        assert len(x.shape) == 4, 'x must be 4-dimensional, but got ' + str(len(x.shape))
        # n, c, h, w = x.shape

        # weight each channel by its DCT basis, then sum over the spatial dims
        x = x * self.weight

        result = torch.sum(x, dim=[2, 3])
        return result

    def build_filter(self, pos, freq, POS):
        # orthonormal 1D DCT-II basis value at position `pos` for frequency `freq`
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
        if freq == 0:
            return result
        else:
            return result * math.sqrt(2)

    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
        dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)

        # each group of c_part channels gets the 2D DCT basis of one (u_x, v_y) frequency
        c_part = channel // len(mapper_x)

        for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
            for t_x in range(tile_size_x):
                for t_y in range(tile_size_y):
                    dct_filter[i * c_part: (i + 1) * c_part, t_x, t_y] = self.build_filter(t_x, u_x, tile_size_x) * self.build_filter(t_y, v_y, tile_size_y)

        return dct_filter
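A minimal sketch of why a layer like this can look empty while debugging, using a hypothetical toy module `FixedFilter` in place of the real layer: a tensor registered with `register_buffer` is saved in `state_dict()` but is not returned by `parameters()`, and the default `nn.Module` repr only lists child modules, so the module prints as `FixedFilter()`. That this is exactly what the debugger view showed is an assumption.

```python
import torch
import torch.nn as nn


class FixedFilter(nn.Module):
    """Hypothetical toy module: stores a fixed (non-learnable) filter as a buffer."""

    def __init__(self, channel: int, h: int, w: int):
        super().__init__()
        # Buffers are saved in state_dict() and follow .to()/.cuda(),
        # but are not returned by .parameters(), so an optimizer never updates them.
        self.register_buffer('weight', torch.ones(channel, h, w))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same pattern as MultiSpectralDCTLayer.forward: multiply, then sum over H and W
        return torch.sum(x * self.weight, dim=[2, 3])


m = FixedFilter(channel=4, h=7, w=7)
print(m)                                        # FixedFilter()  (no child modules to list)
print(len(list(m.parameters())))                # 0
print([name for name, _ in m.named_buffers()])  # ['weight']
print(m(torch.randn(2, 4, 7, 7)).shape)         # torch.Size([2, 4])
```

To inspect the DCT filters of the real layer, looking at `named_buffers()` or `state_dict()` rather than `parameters()` should show the `weight` tensor.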

@Zhongrocky

@cfzd Thank you very much. Got it.
