-
Notifications
You must be signed in to change notification settings - Fork 38
/
dot_ops.py
53 lines (44 loc) · 2.04 KB
/
dot_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# -*- coding: utf-8 -*-
import torch
import torch. nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
class Gaussian(nn.Module):
    """Channel-wise bank of fixed 2-D Gaussian filters.

    Every input channel is convolved independently (grouped convolution
    with ``groups=in_channels``) with one Gaussian kernel per entry of
    ``sigmalist``. The output therefore has
    ``len(sigmalist) * in_channels`` channels, ordered so that the
    responses of input channel ``c`` occupy output channels
    ``c * len(sigmalist) ... (c + 1) * len(sigmalist) - 1`` in the order
    of ``sigmalist``.

    Args:
        in_channels: Number of channels of the input maps.
        sigmalist: Sequence of standard deviations, one kernel per entry.
        kernel_size: Side length of the square kernel. The Gaussian is
            centred at index ``kernel_size // 2``.
        stride: Stride forwarded to the underlying ``nn.Conv2d``.
        padding: Padding forwarded to the underlying ``nn.Conv2d``.
        froze: When true, every parameter of the module gets
            ``requires_grad = False`` so the kernels stay fixed.
    """

    def __init__(self, in_channels, sigmalist, kernel_size=64, stride=1, padding=0, froze=True):
        super(Gaussian, self).__init__()
        out_channels = len(sigmalist) * in_channels

        # Centre of the Gaussian along each axis.
        mu = kernel_size // 2

        windows = []
        for sigma in sigmalist:
            # 1-D Gaussian sampled at the integer taps, normalised to sum 1.
            gauss = torch.Tensor(
                [math.exp(-(x - mu) ** 2 / float(2 * sigma ** 2)) for x in range(kernel_size)]
            )
            gauss /= gauss.sum()
            # Outer product of the 1-D window with itself -> separable 2-D kernel.
            _1D_window = gauss.unsqueeze(1)
            _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
            # One copy of the kernel per input channel: (in_channels, 1, k, k).
            windows.append(_2D_window.expand(in_channels, 1, kernel_size, kernel_size).contiguous())

        # (n_sigma, in_channels, 1, k, k) -> (in_channels, n_sigma, 1, k, k)
        kernels = torch.stack(windows).permute(1, 0, 2, 3, 4)
        # With groups=in_channels each filter sees exactly one input channel,
        # so the weight must be (out_channels, 1, k, k). Reshaping to
        # (out_channels, in_channels, k, k) only works when in_channels == 1.
        weight = kernels.reshape(out_channels, 1, kernel_size, kernel_size)

        self.gkernel = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=False,
        )
        self.gkernel.weight = torch.nn.Parameter(weight)

        if froze:
            self.frozePara()

    def forward(self, dotmaps):
        """Apply the Gaussian filter bank to ``dotmaps`` and return the result."""
        gaussianmaps = self.gkernel(dotmaps)
        return gaussianmaps

    def frozePara(self):
        """Disable gradient computation for every parameter of this module."""
        for para in self.parameters():
            para.requires_grad = False
class SumPool2d(nn.Module):
    """Sliding-window sum over a 2-D map, with stride 1.

    Implemented as an average pool multiplied by the window area. The
    input is zero-padded by half the kernel size on each side, and the
    padded zeros are counted in the average (``count_include_pad=True``),
    so the product equals the sum of the in-bounds values under the window.

    Args:
        kernel_size: Either an ``int`` for a square window or a
            ``(height, width)`` pair.
    """

    def __init__(self, kernel_size):
        super(SumPool2d, self).__init__()
        # Normalise to an explicit (height, width) pair so padding and area
        # are computed the same way for int and tuple kernel sizes.
        if isinstance(kernel_size, int):
            kernel_h = kernel_w = kernel_size
        else:
            kernel_h, kernel_w = kernel_size

        self.avgpool = nn.AvgPool2d(
            (kernel_h, kernel_w),
            stride=1,
            padding=(kernel_h // 2, kernel_w // 2),
            count_include_pad=True,
        )
        # Number of elements under the window; turns the mean back into a sum.
        self.area = kernel_h * kernel_w

    def forward(self, dotmap):
        """Return the windowed sum of ``dotmap``."""
        return self.avgpool(dotmap) * self.area