-
Notifications
You must be signed in to change notification settings - Fork 29
/
model.py
31 lines (26 loc) · 1.09 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class AODnet(nn.Module):
    """Light-weight CNN that maps a hazy image to a restored image.

    The network estimates a per-pixel map ``k`` with the same shape as the
    input. It then combines ``k`` with the input as ``relu(k * x - k + b)``,
    i.e. the AOD-Net style formulation J = K(x)·I(x) − K(x) + b.

    Args:
        b: Constant bias added in the final combination step. It is stored as
            a plain attribute (not an ``nn.Parameter``), so it is not trained
            and is not part of ``state_dict``. Defaults to 1, the value that
            was previously hard-coded.
    """

    def __init__(self, b: float = 1):
        super().__init__()
        # Every conv uses stride 1 with padding (kernel_size - 1) // 2, so the
        # spatial size of the input is preserved through the whole network.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
        # Fed with cat(x1, x2): 3 + 3 = 6 channels.
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, padding=2)
        # Fed with cat(x2, x3): 3 + 3 = 6 channels.
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, padding=3)
        # Fed with cat(x1, x2, x3, x4): 4 * 3 = 12 channels; produces the k map.
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, padding=1)
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Restore a hazy image.

        Args:
            x: Hazy input tensor. ``conv1`` requires 3 input channels; the
                layout is presumably (N, 3, H, W) — confirm against callers.

        Returns:
            Tensor with the same shape as ``x``, equal to
            ``relu(k * x - k + b)`` where ``k`` is the estimated map.

        Raises:
            ValueError: If the estimated ``k`` map's shape differs from ``x``.
                This is a subclass of ``Exception``, so existing broad
                handlers still catch it.
        """
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        x3 = F.relu(self.conv3(torch.cat((x1, x2), 1)))
        x4 = F.relu(self.conv4(torch.cat((x2, x3), 1)))
        k = F.relu(self.conv5(torch.cat((x1, x2, x3, x4), 1)))
        # Guard before the element-wise combination below, which would
        # otherwise silently broadcast mismatched shapes.
        if k.size() != x.size():
            raise ValueError("k, haze image are different size!")
        output = k * x - k + self.b
        return F.relu(output)