fix bugs in sum, add warning to topk
grimoire committed Mar 13, 2021
1 parent 818e759 commit 167e8f0
Showing 2 changed files with 15 additions and 6 deletions.
15 changes: 9 additions & 6 deletions torch2trt_dynamic/converters/sum.py
@@ -1,24 +1,27 @@
 from torch2trt_dynamic.torch2trt_dynamic import *
 from torch2trt_dynamic.module_test import add_module_test
 from .unary import UnaryModule
 
 
 @tensorrt_converter('torch.sum')
 @tensorrt_converter('torch.Tensor.sum')
 def convert_sum(ctx):
     input = ctx.method_args[0]
     dim = get_arg(ctx, 'dim', pos=1, default=tuple(range(1, input.ndim)))
     keepdim = get_arg(ctx, 'keepdim', pos=2, default=False)
-    input_trt= trt_(ctx.network, input)
+    if dim < 0:
+        dim = input.dim() + dim
+    input_trt = trt_(ctx.network, input)
     output = ctx.method_return
-    layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.SUM, torch_dim_to_trt_axes(dim), keepdim)
+    layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.SUM,
+                                   torch_dim_to_trt_axes(dim), keepdim)
     output._trt = layer.get_output(0)
 
 
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
 def test_sum_reduce_all():
     return UnaryModule(lambda x: torch.sum(x))
 
 
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
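The sum fix is negative-axis normalization: torch.sum accepts negative dims such as dim=-1, but the reduce-axis bitmask built by torch_dim_to_trt_axes needs non-negative indices. A minimal standalone sketch of the same normalization, where normalize_dim is an illustrative helper (not part of torch2trt_dynamic), extended to cover the tuple default that convert_sum falls back to:

import torch


def normalize_dim(dim, ndim):
    # Map negative axes to positive ones, e.g. -1 -> ndim - 1,
    # mirroring the converter's `dim = input.dim() + dim` branch.
    if isinstance(dim, int):
        return dim + ndim if dim < 0 else dim
    # Tuple case: convert_sum defaults to tuple(range(1, input.ndim)),
    # so each entry is normalized the same way.
    return tuple(d + ndim if d < 0 else d for d in dim)


x = torch.randn(1, 3, 3)
assert normalize_dim(-1, x.dim()) == 2
assert normalize_dim((1, -1), x.dim()) == (1, 2)
assert torch.equal(torch.sum(x, dim=-1),
                   torch.sum(x, dim=normalize_dim(-1, x.dim())))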
6 changes: 6 additions & 0 deletions torch2trt_dynamic/converters/topk.py
@@ -15,6 +15,12 @@ def convert_topk(ctx):
     axis = len(input.shape) - 1
     if axis < 0:
         axis = len(input.shape) + axis
+
+    if k > 3840:
+        print("warning: topk = " + str(k) +
+              " > 3840 is not allowed in TensorRT, use 3840 instead.")
+        k = 3840
+
     largest = get_arg(ctx, 'largest', pos=3, default=True)
     topkOp = trt.TopKOperation.MAX if largest else trt.TopKOperation.MIN
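For context, TensorRT's ITopKLayer only supports K values up to 3840, which is what the new guard clamps against. A minimal sketch of the same clamp applied at the PyTorch level, so the truncation is visible before any network is built; safe_topk and TRT_TOPK_MAX_K are illustrative names, not part of torch2trt_dynamic:

import torch

# Upper bound on K documented for TensorRT's ITopKLayer.
TRT_TOPK_MAX_K = 3840


def safe_topk(x, k, dim=-1, largest=True):
    # Clamp k the same way convert_topk does, but surface the
    # truncation to the caller instead of only printing at build time.
    if k > TRT_TOPK_MAX_K:
        print('warning: topk k={} exceeds the TensorRT limit of {}; '
              'clamping.'.format(k, TRT_TOPK_MAX_K))
        k = TRT_TOPK_MAX_K
    return torch.topk(x, k, dim=dim, largest=largest)


values, indices = safe_topk(torch.randn(1, 5000), k=4096)
assert values.shape[-1] == TRT_TOPK_MAX_K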
