From e78aa611274f9a946ae7243fe68427b55d5ddd18 Mon Sep 17 00:00:00 2001
From: Zhi <5145158+zhiics@users.noreply.github.com>
Date: Thu, 1 Oct 2020 00:46:56 -0700
Subject: [PATCH] dynamic conv2d for cuda (#6598)

Support a dynamic (Any) batch dimension in the CUDA direct and winograd
conv2d kernels: give AnyNode a dtype so it can be visited and compared
structurally, substitute a size_var for a dynamic batch in the winograd
compute, skip FLOP accounting when N is symbolic, and enable the GPU
targets in the dynamic conv2d tests. Dynamic input height or width is
still rejected by the winograd kernels.
---
 include/tvm/tir/expr.h                       |  6 ++++--
 python/tvm/topi/cuda/conv2d_direct.py        |  4 +++-
 python/tvm/topi/cuda/conv2d_nhwc_winograd.py | 14 ++++++++++++--
 python/tvm/topi/cuda/conv2d_winograd.py      | 16 ++++++++++++++--
 src/tir/ir/expr.cc                           |  6 +++++-
 tests/python/relay/test_any.py               |  6 +++---
 6 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 9e6f440ee97f..eee0deecdc70 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -1085,9 +1085,11 @@ class Reduce : public PrimExpr {
 /*! \brief Any shape. */
 class AnyNode : public PrimExprNode {
  public:
-  void VisitAttrs(AttrVisitor* v) {}
+  void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); }
 
-  bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { return true; }
+  bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
+    return equal(dtype, other->dtype);
+  }
 
   void SHashReduce(SHashReducer hash_reduce) const {}
 
diff --git a/python/tvm/topi/cuda/conv2d_direct.py b/python/tvm/topi/cuda/conv2d_direct.py
index 2065ab9732fa..e1f3d82cb3e9 100644
--- a/python/tvm/topi/cuda/conv2d_direct.py
+++ b/python/tvm/topi/cuda/conv2d_direct.py
@@ -117,4 +117,6 @@ def schedule_direct_cuda(cfg, s, conv):
 
     N, CO, OH, OW = get_const_tuple(output.shape)
     _, KH, KW, CI = get_const_tuple(kernel.shape)
-    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
+
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
diff --git a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py
index cc0bbebc0c10..246437a26146 100644
--- a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py
+++ b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py
@@ -302,6 +302,15 @@ def nhwc_winograd_cuda(
 
     tile_size = _infer_tile_size(data, kernel)
     N, H, W, CI = get_const_tuple(data.shape)
+    if isinstance(N, tvm.tir.Any):
+        N = tvm.te.size_var("n")
+
+    if not isinstance(H, int) or not isinstance(W, int):
+        raise RuntimeError(
+            "cuda winograd nhwc conv2d doesn't support dynamic "
+            "input height or width."
+        )
+
     if isinstance(dilation, int):
         dilation_h = dilation_w = dilation
     else:
@@ -330,7 +339,7 @@ def nhwc_winograd_cuda(
     H = (H + pt + pb - KH) // HSTR + 1
     W = (W + pl + pr - KW) // WSTR + 1
     nH, nW = (H + m - 1) // m, (W + m - 1) // m
-    P = N * nH * nW
+    P = N * nH * nW if isinstance(N, int) else nH * nW
 
     # Determine whether the shape is available with tensorcore
     shape_judge = (
@@ -432,7 +441,8 @@ def nhwc_winograd_cuda(
         name="output",
         tag="conv2d_nhwc_winograd",
     )
-    cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
 
     return output
 
diff --git a/python/tvm/topi/cuda/conv2d_winograd.py b/python/tvm/topi/cuda/conv2d_winograd.py
index 69513d58b3b3..11502e134fd5 100644
--- a/python/tvm/topi/cuda/conv2d_winograd.py
+++ b/python/tvm/topi/cuda/conv2d_winograd.py
@@ -44,6 +44,15 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
     tile_size = _infer_tile_size(data, kernel)
 
     N, CI, H, W = get_const_tuple(data.shape)
+    if isinstance(N, tvm.tir.Any):
+        N = tvm.te.size_var("n")
+
+    if not isinstance(H, int) or not isinstance(W, int):
+        raise RuntimeError(
+            "cuda winograd conv2d doesn't support dynamic "
+            "input height or width."
+        )
+
     if isinstance(dilation, int):
         dilation_h = dilation_w = dilation
     else:
@@ -73,7 +82,8 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
     H = (H + pt + pb - KH) // HSTR + 1
     W = (W + pl + pr - KW) // WSTR + 1
     nH, nW = (H + m - 1) // m, (W + m - 1) // m
-    P = N * nH * nW
+
+    P = N * nH * nW if isinstance(N, int) else nH * nW
 
     # transform kernel
     if not pre_computed:
@@ -141,7 +151,9 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
         name="output",
         tag="conv2d_nchw_winograd",
     )
-    cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
+
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
 
     return output
 
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 687dfd630f1d..f648aca18e46 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -908,7 +908,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     });
 
 // Any
-Any::Any() { data_ = make_object<AnyNode>(); }
+Any::Any() {
+  auto n = make_object<AnyNode>();
+  n->dtype = DataType::Int(32);
+  data_ = std::move(n);
+}
 
 TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([]() { return Any(); });
 
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index c13f679e6108..b9b58a5a491a 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -444,7 +444,7 @@ def verify_any_conv2d(
 
 
 # TODO(@kevinthesun): Support dynamic input height and width.
-# TODO(@kevinthesun): Support gpu to enable gpu tests.
+@tvm.testing.uses_gpu
 def test_any_conv2d():
     verify_any_conv2d(
         (relay.Any(), 64, 224, 224),
@@ -501,7 +501,7 @@ def verify_any_conv2d_NCHWc(
 
 
 # TODO(@kevinthesun): Support dynamic input height and width.
-# TODO(@kevinthesun): Support gpu to enable gpu tests.
+@tvm.testing.uses_gpu
 def test_any_conv2d_NCHWc():
     verify_any_conv2d_NCHWc(
         (relay.Any(), 8, 224, 224, 8),
@@ -563,7 +563,7 @@ def verify_any_conv2d_transpose_nchw(
 
 
 # TODO(@kevinthesun): Support dynamic input height and width.
-# TODO(@kevinthesun): Support gpu to enable gpu tests.
+@tvm.testing.uses_gpu
 def test_any_conv2d_transpose_nchw():
     verify_any_conv2d_transpose_nchw(
         (relay.Any(), 64, 224, 224),
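
Usage sketch (illustrative, not part of the commit): with this change, a
conv2d whose batch dimension is relay.Any() should compile and run on CUDA
through the Relay VM. The shapes, batch sizes, and the tvm.gpu(0) /
relay.create_executor calls below follow the style of the test harness in
tests/python/relay/test_any.py and are assumptions, not code from this patch.

    import numpy as np
    import tvm
    from tvm import relay

    # Batch dimension is unknown at compile time; H/W must stay static,
    # since the winograd kernels reject dynamic height or width.
    data = relay.var("data", shape=(relay.Any(), 64, 224, 224), dtype="float32")
    kernel = relay.var("kernel", shape=(64, 64, 3, 3), dtype="float32")
    out = relay.nn.conv2d(data, kernel, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3))
    mod = tvm.IRModule.from_expr(relay.Function([data, kernel], out))

    # The VM executor handles dynamic shapes, unlike the graph runtime.
    ctx = tvm.gpu(0)
    ex = relay.create_executor("vm", mod=mod, ctx=ctx, target="cuda")

    kernel_np = np.random.uniform(size=(64, 64, 3, 3)).astype("float32")
    for n in (1, 2, 4):  # batch size chosen only at run time
        data_np = np.random.uniform(size=(n, 64, 224, 224)).astype("float32")
        res = ex.evaluate()(data_np, kernel_np)
        print(res.asnumpy().shape)  # (n, 64, 224, 224)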