Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 21, 2023
1 parent 063171e commit 1995a7c
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions tools/pnnx/tests/test_pnnx_fuse_conv3d_batchnorm3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ def __init__(self):
else:
self.conv_4 = nn.Conv3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2,1), groups=2, bias=False, padding_mode='zeros')
self.bn_4 = nn.BatchNorm3d(num_features=32)
self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
self.bn_5 = nn.BatchNorm3d(num_features=32)
self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
self.bn_6 = nn.BatchNorm3d(num_features=28)
#self.conv_7 = nn.Conv3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular')
#self.bn_7 = nn.BatchNorm3d(num_features=24)
if version.parse(torch.__version__) >= version.parse('1.10'):
self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
self.bn_5 = nn.BatchNorm3d(num_features=32)
self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
self.bn_6 = nn.BatchNorm3d(num_features=28)
#self.conv_7 = nn.Conv3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular')
#self.bn_7 = nn.BatchNorm3d(num_features=24)

def forward(self, x):
x = self.conv_0(x)
Expand All @@ -55,6 +56,9 @@ def forward(self, x):
x = self.bn_3(x)
x = self.conv_4(x)
x = self.bn_4(x)
if version.parse(torch.__version__) < version.parse('1.10'):
return x

x = self.conv_5(x)
x = self.bn_5(x)
x = self.conv_6(x)
Expand Down

0 comments on commit 1995a7c

Please sign in to comment.