Skip to content

Commit

Permalink
dynamic conv2d for cuda (apache#6598)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and Tushar Dey committed Oct 15, 2020
1 parent e5ef9fc commit 736a844
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 11 deletions.
6 changes: 4 additions & 2 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1085,9 +1085,11 @@ class Reduce : public PrimExpr {
/*! \brief Any shape. */
class AnyNode : public PrimExprNode {
public:
void VisitAttrs(AttrVisitor* v) {}
void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); }

bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { return true; }
bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const {}

Expand Down
4 changes: 3 additions & 1 deletion python/tvm/topi/cuda/conv2d_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,6 @@ def schedule_direct_cuda(cfg, s, conv):

N, CO, OH, OW = get_const_tuple(output.shape)
_, KH, KW, CI = get_const_tuple(kernel.shape)
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)

if isinstance(N, int):
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
14 changes: 12 additions & 2 deletions python/tvm/topi/cuda/conv2d_nhwc_winograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ def nhwc_winograd_cuda(
tile_size = _infer_tile_size(data, kernel)
N, H, W, CI = get_const_tuple(data.shape)

if isinstance(N, tvm.tir.Any):
N = tvm.te.size_var("n")

if not isinstance(H, int) or not isinstance(W, int):
raise RuntimeError(
"cuda winograd nhwc conv2d doesn't support dynamic \
input height or width."
)

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
Expand Down Expand Up @@ -330,7 +339,7 @@ def nhwc_winograd_cuda(
H = (H + pt + pb - KH) // HSTR + 1
W = (W + pl + pr - KW) // WSTR + 1
nH, nW = (H + m - 1) // m, (W + m - 1) // m
P = N * nH * nW
P = N * nH * nW if isinstance(N, int) else nH * nW

# Determine whether the shape is available with tensorcore
shape_judge = (
Expand Down Expand Up @@ -432,7 +441,8 @@ def nhwc_winograd_cuda(
name="output",
tag="conv2d_nhwc_winograd",
)
cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
if isinstance(N, int):
cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
return output


Expand Down
16 changes: 14 additions & 2 deletions python/tvm/topi/cuda/conv2d_winograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_

N, CI, H, W = get_const_tuple(data.shape)

if isinstance(N, tvm.tir.Any):
N = tvm.te.size_var("n")

if not isinstance(H, int) or not isinstance(W, int):
raise RuntimeError(
"cuda winograd conv2d doesn't support dynamic input\
height or width."
)

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
Expand Down Expand Up @@ -73,7 +82,8 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
H = (H + pt + pb - KH) // HSTR + 1
W = (W + pl + pr - KW) // WSTR + 1
nH, nW = (H + m - 1) // m, (W + m - 1) // m
P = N * nH * nW

P = N * nH * nW if isinstance(N, int) else nH * nW

# transform kernel
if not pre_computed:
Expand Down Expand Up @@ -141,7 +151,9 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
name="output",
tag="conv2d_nchw_winograd",
)
cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)

if isinstance(N, int):
cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)

return output

Expand Down
6 changes: 5 additions & 1 deletion src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

// Any
Any::Any() { data_ = make_object<AnyNode>(); }
Any::Any() {
auto n = make_object<AnyNode>();
n->dtype = DataType::Int(32);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([]() { return Any(); });

Expand Down
6 changes: 3 additions & 3 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def verify_any_conv2d(


# TODO(@kevinthesun): Support dynamic input height and width.
# TODO(@kevinthesun): Support gpu to enable gpu tests.
@tvm.testing.uses_gpu
def test_any_conv2d():
verify_any_conv2d(
(relay.Any(), 64, 224, 224),
Expand Down Expand Up @@ -501,7 +501,7 @@ def verify_any_conv2d_NCHWc(


# TODO(@kevinthesun): Support dynamic input height and width.
# TODO(@kevinthesun): Support gpu to enable gpu tests.
@tvm.testing.uses_gpu
def test_any_conv2d_NCHWc():
verify_any_conv2d_NCHWc(
(relay.Any(), 8, 224, 224, 8),
Expand Down Expand Up @@ -563,7 +563,7 @@ def verify_any_conv2d_transpose_nchw(


# TODO(@kevinthesun): Support dynamic input height and width.
# TODO(@kevinthesun): Support gpu to enable gpu tests.
@tvm.testing.uses_gpu
def test_any_conv2d_transpose_nchw():
verify_any_conv2d_transpose_nchw(
(relay.Any(), 64, 224, 224),
Expand Down

0 comments on commit 736a844

Please sign in to comment.