Skip to content

Commit

Permalink
Making ASPP-Layer in DeepLab more generic (#2174)
Browse files Browse the repository at this point in the history
At the moment in the ASPP-Layer the number of output channels are predefined as a constant,
which is good for DeepLab but not necessairly in other projects, where another out-channel Nr. is required.

Also the number of "atrous rates" is fixed to three, which also could be sometimes more or less depending on the notwork-arch.
Again these fixed values may make sense in DeepLab-Model but not necessarily in other type of models.

This pull-req. contains the needed changes to make ASPP-Layer generic.
  • Loading branch information
ArashJavan authored May 4, 2020
1 parent 1a40d9c commit bd27e94
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions torchvision/models/segmentation/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,18 @@ def forward(self, x):


class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates):
def __init__(self, in_channels, atrous_rates, out_channels=256):
super(ASPP, self).__init__()
out_channels = 256
modules = []
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU()))

rate1, rate2, rate3 = tuple(atrous_rates)
modules.append(ASPPConv(in_channels, out_channels, rate1))
modules.append(ASPPConv(in_channels, out_channels, rate2))
modules.append(ASPPConv(in_channels, out_channels, rate3))
rates = tuple(atrous_rates)
for rate in rates:
modules.append(ASPPConv(in_channels, out_channels, rate))

modules.append(ASPPPooling(in_channels, out_channels))

self.convs = nn.ModuleList(modules)
Expand Down

0 comments on commit bd27e94

Please sign in to comment.