[TORCH] Implement avg_pool1d (#7694)
* [TORCH] Implement avg_pool1d

* [TORCH] Unify creation of avg_pooling operations

* [TORCH] Add tests for avg pooling with padding

* [TORCH] Make format checks happy with unified avg_pool
cgerum authored Mar 23, 2021
1 parent 4c66fb2 commit f09f02e
Showing 2 changed files with 72 additions and 40 deletions.
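Before the diffs, a minimal sketch of how the new conversion path is exercised end to end. This is not part of the commit; it assumes a TVM build that contains this change, and the input name "input0" is arbitrary.

import torch
from tvm import relay

# Tracing AvgPool1d produces an aten::avg_pool1d node in the TorchScript graph.
pool = torch.nn.AvgPool1d(kernel_size=5, stride=2, padding=2).eval()
inp = torch.rand(1, 3, 10)
traced = torch.jit.trace(pool, inp)

# from_pytorch now dispatches aten::avg_pool1d to the converter returned by make_avg_pool(1).
mod, params = relay.frontend.from_pytorch(traced, [("input0", (1, 3, 10))])
print(mod)  # Relay module containing an nn.avg_pool1d call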
84 changes: 46 additions & 38 deletions python/tvm/relay/frontend/pytorch.py
@@ -1353,47 +1353,54 @@ def softplus(self, inputs, input_types):
         beta = _expr.const(float(inputs[1]), dtype=dtype)
         return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta
 
-    def avg_pool2d(self, inputs, input_types):
-        data = inputs[0]
-
-        pool_size = self.convert_const_list(inputs[1])
-        strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
-        padding = inputs[3]
-        ceil_mode = int(inputs[4])
-        count_include_pad = int(inputs[5])
-
-        def func(x):
-            return _op.nn.avg_pool2d(
-                x,
-                pool_size=pool_size,
-                strides=strides,
-                padding=padding,
-                ceil_mode=ceil_mode,
-                count_include_pad=count_include_pad,
-            )
+    def make_avg_pool(self, dim):
+        def avg_pool(inputs, input_types):
+            data = inputs[0]
 
-        if self.is_quantized_tensor(data):
-            return qnn_torch.apply_with_upcast(data, func)
+            pool_size = self.convert_const_list(inputs[1])
+            strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
+            padding = inputs[3]
+            ceil_mode = int(inputs[4])
+            count_include_pad = int(inputs[5])
 
-        return func(data)
+            def func(x):
+                if dim == 1:
+                    return _op.nn.avg_pool1d(
+                        x,
+                        pool_size=pool_size,
+                        strides=strides,
+                        padding=padding,
+                        ceil_mode=ceil_mode,
+                        count_include_pad=count_include_pad,
+                    )
+                elif dim == 2:
+                    return _op.nn.avg_pool2d(
+                        x,
+                        pool_size=pool_size,
+                        strides=strides,
+                        padding=padding,
+                        ceil_mode=ceil_mode,
+                        count_include_pad=count_include_pad,
+                    )
+                elif dim == 3:
+                    return _op.nn.avg_pool3d(
+                        x,
+                        pool_size=pool_size,
+                        strides=strides,
+                        padding=padding,
+                        ceil_mode=ceil_mode,
+                        count_include_pad=count_include_pad,
+                    )
+                else:
+                    msg = "Average Pooling dimension should be between 1 and 3"
+                    raise RuntimeError(msg)
 
-    def avg_pool3d(self, inputs, input_types):
-        data = inputs[0]
+            if self.is_quantized_tensor(data):
+                return qnn_torch.apply_with_upcast(data, func)
 
-        pool_size = inputs[1]
-        strides = inputs[2] if inputs[2] else pool_size
-        padding = inputs[3]
-        ceil_mode = int(inputs[4])
-        count_include_pad = int(inputs[5])
+            return func(data)
 
-        return _op.nn.avg_pool3d(
-            data,
-            pool_size=pool_size,
-            strides=strides,
-            padding=padding,
-            ceil_mode=ceil_mode,
-            count_include_pad=count_include_pad,
-        )
+        return avg_pool
 
     def linear(self, inputs, input_types):
         # https://pytorch.org/docs/stable/nn.functional.html#linear
@@ -2350,8 +2357,9 @@ def create_convert_map(self):
             "aten::log_softmax": self.log_softmax,
             "aten::sigmoid": self.sigmoid,
             "aten::softplus": self.softplus,
-            "aten::avg_pool2d": self.avg_pool2d,
-            "aten::avg_pool3d": self.avg_pool3d,
+            "aten::avg_pool1d": self.make_avg_pool(1),
+            "aten::avg_pool2d": self.make_avg_pool(2),
+            "aten::avg_pool3d": self.make_avg_pool(3),
            "aten::linear": self.linear,
            "aten::dropout": self.dropout,
            "aten::dropout_": self.dropout,
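The refactor above collapses the per-dimension avg_pool2d/avg_pool3d methods into one factory that closes over dim, so the convert_map table simply requests the dimensionality it needs. A standalone sketch of the same closure pattern (illustrative only, not TVM code; all names here are made up):

def make_handler(dim):
    def handler(inputs):
        # aten::avg_poolNd inputs arrive in this order:
        # data, kernel_size, stride, padding, ceil_mode, count_include_pad
        _, kernel, stride, padding, ceil_mode, count_include_pad = inputs[:6]
        return (
            f"avg_pool{dim}d",
            kernel,
            stride or kernel,  # PyTorch defaults stride to kernel_size
            padding,
            bool(ceil_mode),
            bool(count_include_pad),
        )

    return handler

table = {f"aten::avg_pool{d}d": make_handler(d) for d in (1, 2, 3)}
print(table["aten::avg_pool1d"]([None, [5], [2], [2], False, True]))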
28 changes: 26 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
@@ -809,7 +809,24 @@ def forward(self, *args):
 
 
 @tvm.testing.uses_gpu
-def test_forward_avgpool():
+def test_forward_avgpool1d():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10]
+
+    class AvgPool1D2(Module):
+        def forward(self, *args):
+            return torch.nn.functional.avg_pool1d(args[0], kernel_size=[10])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.AvgPool1d(kernel_size=[10]).eval(), input_data=input_data)
+    verify_model(AvgPool1D2().float().eval(), input_data=input_data)
+    verify_model(
+        torch.nn.AvgPool1d(kernel_size=[5], stride=2, padding=2).eval(), input_data=input_data
+    )
+
+
+@tvm.testing.uses_gpu
+def test_forward_avgpool2d():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
 
@@ -820,6 +837,9 @@ def forward(self, *args):
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data)
     verify_model(AvgPool2D2().float().eval(), input_data=input_data)
+    verify_model(
+        torch.nn.AvgPool2d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data
+    )
 
 
 @tvm.testing.uses_gpu
@@ -834,6 +854,9 @@ def forward(self, *args):
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data)
     verify_model(AvgPool3D1().float().eval(), input_data=input_data)
+    verify_model(
+        torch.nn.AvgPool3d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data
+    )
 
 
 @tvm.testing.uses_gpu
@@ -3838,7 +3861,8 @@ def test_fn(is_sorted, return_inverse, return_counts):
     test_forward_logsoftmax()
     test_forward_sigmoid()
     test_forward_dense()
-    test_forward_avgpool()
+    test_forward_avgpool1d()
+    test_forward_avgpool2d()
     test_forward_avgpool3d()
     test_forward_dropout()
     test_forward_slice()
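The padded variants added above are the tests called out in the commit message. Roughly, verify_model compiles the converted module and compares it numerically against PyTorch; a hedged sketch of that check for the 1-D case (assumptions: an llvm target, a recent TVM where tvm.contrib.graph_executor is available, and the arbitrary input name "input0"):

import numpy as np
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor

pool = torch.nn.AvgPool1d(kernel_size=5, stride=2, padding=2).eval()
inp = torch.rand(1, 3, 10)
traced = torch.jit.trace(pool, inp)

mod, params = relay.frontend.from_pytorch(traced, [("input0", (1, 3, 10))])
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

# Run the compiled module on CPU and compare against the PyTorch reference output.
runtime = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
runtime.set_input("input0", tvm.nd.array(inp.numpy()))
runtime.run()
np.testing.assert_allclose(
    runtime.get_output(0).numpy(), pool(inp).numpy(), rtol=1e-5, atol=1e-5
)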
