From 00b269d8cb2fbf30c7d3bb0545b900f3beda1c5d Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Wed, 16 Sep 2020 20:32:20 -0700
Subject: [PATCH] [Frontend][Pytorch] Improve Pytorch frontend for object
 detection models (#6449)

* Improve Pytorch Frontend
* Add tests
* Fix pylint
* Improve data cast
* Use int64 for slice axis
* Fix lint
* fix roi_align(..., aligned=True)
* Minor fix
* Add e2e test
* Add asf header
* Minor change
* Use dynamic topk
* Improve test
* Rollback topk
* py format
* remove print
* More improve
* Fix test
* Improve addmm
* Fix test
* Fix format
* Fix format
* Fix test scatter

Co-authored-by: q.yao
---
 python/tvm/relay/frontend/pytorch.py          | 606 +++++++++++++-----
 tests/python/frontend/pytorch/test_forward.py |  87 ++-
 .../frontend/pytorch/test_object_detection.py | 139 ++++
 3 files changed, 688 insertions(+), 144 deletions(-)
 create mode 100644 tests/python/frontend/pytorch/test_object_detection.py

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 886729bff51f..c9320a9b2882 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -25,6 +25,7 @@
 import numpy as np

 import tvm
+from tvm.topi.util import get_const_tuple

 from .. import analysis as _analysis
 from .. import expr as _expr
@@ -184,8 +185,14 @@ def _impl(inputs, input_types):
         def _get_value(val, dtype):
             # dtype is a tvm dtype
             if isinstance(val, _expr.Expr):
-                return _op.cast(val, dtype)
-            return _create_typed_const(val, dtype)
+                try:
+                    ret = _infer_value(_op.cast(val, dtype), {}).asnumpy()
+                    ret = _expr.const(ret, dtype)
+                except Exception:
+                    ret = _op.cast(val, dtype)
+            else:
+                ret = _create_typed_const(val, dtype)
+            return ret

         def _get_type(val, inp_type):
             if isinstance(val, _expr.Expr):
@@ -205,9 +212,9 @@ def _get_type(val, inp_type):
                 dtype = "float32"
             else:
                 dtype = "int64"
-            start = _get_value(0, dtype)
+            start = _expr.const(0, dtype)
             stop = _get_value(inputs[0], dtype)
-            step = _get_value(1, dtype)
+            step = _expr.const(1, dtype)
         elif len(inputs) == 7:
             types = [_get_type(inputs[i], input_types[i]) for i in range(3)]
             if inputs[3] is not None:
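The rewritten `_get_value` above tries to fold a dynamic expression back into a compile-time constant before falling back to a runtime cast. A minimal sketch of the same try-infer-fallback pattern, using the public `infer_value` helper and a hypothetical `to_int64` function (not part of this patch):

    from tvm import relay
    from tvm.relay.frontend.common import infer_value

    def to_int64(val):
        # Prefer a constant so later passes can fold it away.
        try:
            folded = infer_value(relay.cast(val, "int64"), {}).asnumpy()
            return relay.const(folded, "int64")
        except Exception:
            # Value truly depends on runtime input; keep the cast in the graph.
            return relay.cast(val, "int64")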
@@ -282,38 +289,103 @@ def _impl(inputs, input_types):

 def _slice():
     def _impl(inputs, input_types):
+        axis_dtype = "int64"
+        index_size_limit = 2 ** 63 - 1
         data = inputs[0]
-        strides = []
-
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))

-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])

+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, axis_dtype), axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), axis_dtype))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != axis_dtype:
+                begin = _op.cast(begin, axis_dtype)
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+                except Exception:
+                    target_end = inputs[3]
             else:
                 target_end = inputs[3]

-            end[dim] = min(end[dim], target_end)
+        if isinstance(target_end, int) and target_end >= index_size_limit:
+            # Quick path for original data.
+            if (
+                isinstance(begin, _expr.Constant)
+                and begin.data.asnumpy().tolist()[dim] == 0
+                and stride == 1
+            ):
+                return data
+            target_end = dshape[dim]
+
+        # Process end
+        if isinstance(target_end, int):
+            if isinstance(end, list):
+                end[dim] = target_end
+            else:
+                all_static = True
+                for i, shape_dim in enumerate(dshape):
+                    if i != dim and isinstance(shape_dim, tvm.tir.Any):
+                        all_static = False
+
+                if all_static:
+                    end = list(get_const_tuple(dshape))
+                    end[dim] = target_end
+                else:
+                    target_end = _expr.const(target_end)
+                    end = _op.scatter(
+                        end,
+                        _op.expand_dims(_expr.const(dim), axis=0),
+                        _op.expand_dims(target_end, axis=0),
+                        axis=0,
+                    )
+        else:
+            end = _op.cast(_op.shape_of(data), axis_dtype)
+            if not isinstance(target_end, tvm.tir.Any):
+                ttype = _infer_type(target_end).checked_type.dtype
+                if str(ttype) != axis_dtype:
+                    target_end = _op.cast(target_end, axis_dtype)
+                end = _op.scatter(
+                    end,
+                    _op.expand_dims(_expr.const(dim), axis=0),
+                    _op.expand_dims(target_end, axis=0),
+                    axis=0,
+                )
+
+        if not isinstance(end, list):
+            etype = _infer_type(end).checked_type.dtype
+            if str(etype) != axis_dtype:
+                end = _op.cast(end, axis_dtype)

-        strides = [1] * len(end)
+        strides = [1] * ndim
         strides[dim] = int(inputs[4])

         return _op.transform.strided_slice(
@@ -380,16 +452,23 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
         axis = int(inputs[2])
         is_ascend = not bool(inputs[3])
         sort = bool(inputs[4])

+        if isinstance(inputs[1], _expr.Expr):
+            try:
+                k = _infer_value(inputs[1], {}).asnumpy().tolist()
+            except Exception:
+                k = inputs[1]
+        else:
+            k = inputs[1]
+
         if not sort:
             msg = "Currently supports only sorted output for topk operator."
             raise AssertionError(msg)

-        outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both")
+        outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both", dtype="int64")

         return outs[0], outs[1]
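Two details in these hunks are easy to miss. First, TorchScript encodes an open-ended slice like `x[begin:]` with `end = 2**63 - 1`, which is why the converter short-circuits and returns `data` untouched when `begin == 0` and `stride == 1`. Second, `topk` now requests `int64` indices so every op that later consumes index tensors sees one consistent dtype. A rough sketch of the sentinel check, assuming `2**63 - 1` is the only sentinel PyTorch emits:

    INDEX_SIZE_LIMIT = 2 ** 63 - 1  # torch's encoding of an open-ended slice end

    def is_full_slice(begin, end, stride):
        # x[0::1] with an open end is the whole tensor; no strided_slice needed.
        return begin == 0 and stride == 1 and end >= INDEX_SIZE_LIMIT

    assert is_full_slice(0, 2 ** 63 - 1, 1)
    assert not is_full_slice(1, 2 ** 63 - 1, 1)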
@@ -407,7 +486,7 @@ def _impl(inputs, input_types):
 def _repeat():
     def _impl(inputs, input_types):
         data = inputs[0]
-        reps = _get_dims(inputs[1])
+        reps = inputs[1]
         return _op.transform.tile(data, reps=reps)

     return _impl
@@ -455,36 +534,73 @@ def _impl(inputs, input_types):
     return _impl


-def _ones():
+def _full_impl(data, fill_value, dtype):
+    size = []
+    need_reshape = False
+    new_shape = []
+    for dim in data:
+        if isinstance(dim, _expr.Expr):
+            if isinstance(dim, _expr.Constant):
+                dim = int(dim.data.asnumpy())
+                if isinstance(size, list):
+                    size.append(dim)
+                new_shape.append(dim)
+            else:
+                try:
+                    dim = int(_infer_value(dim, {}).asnumpy())
+                    if isinstance(size, list):
+                        size.append(dim)
+                    new_shape.append(dim)
+                except Exception:
+                    size = None
+                    need_reshape = True
+                    new_shape.append(0)
+        else:
+            if isinstance(size, list):
+                size.append(dim)
+            new_shape.append(dim)
+
+    if size is None:
+        tmp = []
+        for dim in data:
+            tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64"))
+        size = _op.concatenate(tmp, axis=0)
+
+    out = _op.full(_expr.const(fill_value), size, dtype=dtype)
+    if need_reshape:
+        out = _op.reshape(out, new_shape)
+    return out
+
+
+def _ones(default_dtype):
     def _impl(inputs, input_types):
         data = inputs[0]

         import torch

-        if isinstance(data, _expr.Expr):
-            shape = _infer_shape(data)
-        elif isinstance(data, list):
-            shape = data
-        elif isinstance(data, (torch.Tensor, np.ndarray)):
-            shape = data.shape
-        else:
+        if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)):
             msg = "Data type %s could not be parsed in ones op" % (type(data))
             raise AssertionError(msg)

-        dtype = _convert_dtype_value(inputs[1])
-
-        return _op.full(_expr.const(1), shape, dtype=dtype)
+        if inputs[1] is not None:
+            dtype = _convert_dtype_value(inputs[1])
+        else:
+            dtype = default_dtype
+        return _full_impl(data, 1, dtype)

     return _impl


-def _ones_like():
+def _ones_like(default_dtype):
    def _impl(inputs, input_types):
         data = inputs[0]
         out = _op.ones_like(data)

         # If the input and the output datatype is different, do a cast
-        dtype = _convert_dtype_value(inputs[1])
+        if inputs[1] is not None:
+            dtype = _convert_dtype_value(inputs[1])
+        else:
+            dtype = default_dtype
         if input_types[0] != dtype:
             out = _op.cast(out, dtype)

@@ -493,36 +609,35 @@ def _impl(inputs, input_types):
     return _impl


-def _zeros():
+def _zeros(default_dtype):
     def _impl(inputs, input_types):
         data = inputs[0]

         import torch

-        if isinstance(data, _expr.Expr):
-            shape = _infer_shape(data)
-        elif isinstance(data, list):
-            shape = data
-        elif isinstance(data, (torch.Tensor, np.ndarray)):
-            shape = data.shape
-        else:
+        if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)):
             msg = "Data type %s could not be parsed in zeros op" % (type(data))
             raise AssertionError(msg)

-        dtype = _convert_dtype_value(inputs[1])
-
-        return _op.full(_expr.const(0), shape, dtype=dtype)
+        if inputs[1] is not None:
+            dtype = _convert_dtype_value(inputs[1])
+        else:
+            dtype = default_dtype
+        return _full_impl(data, 0, dtype)

     return _impl


-def _zeros_like():
+def _zeros_like(default_dtype):
     def _impl(inputs, input_types):
         data = inputs[0]
         out = _op.zeros_like(data)

         # If the input and the output datatype is different, do a cast
-        dtype = _convert_dtype_value(inputs[1])
+        if inputs[1] is not None:
+            dtype = _convert_dtype_value(inputs[1])
+        else:
+            dtype = default_dtype
         if input_types[0] not in dtype:
             out = _op.cast(out, dtype)
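`_full_impl` walks the requested shape once: dims it can fold become Python ints, and any dim it cannot infer forces the whole shape to be materialized as a 1-D int64 tensor, with a trailing `reshape` to re-pin the static dims (0 in a relay reshape means "keep this dim"). A hedged sketch of what `aten::ones` with a dynamic leading dim lowers to, where `batch` is a hypothetical scalar expression standing in for the runtime size:

    import numpy as np
    from tvm import relay

    batch = relay.var("batch", shape=(), dtype="int64")  # assumed dynamic dim
    size = relay.concatenate(
        [relay.expand_dims(batch, 0), relay.const(np.array([4]), "int64")],
        axis=0,
    )
    ones = relay.full(relay.const(1.0), size, dtype="float32")
    ones = relay.reshape(ones, [0, 4])  # 0 keeps the dynamic dim; 4 re-pins the static one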
@@ -534,18 +649,12 @@ def _impl(inputs, input_types):
 def _full(default_dtype):
     def _impl(inputs, input_types):
         data = inputs[0]
-
         fill_value = inputs[1]
+
         import torch

-        if isinstance(data, _expr.Expr):
-            shape = _infer_shape(data)
-        elif isinstance(data, list):
-            shape = data
-        elif isinstance(data, (torch.Tensor, np.ndarray)):
-            shape = data.shape
-        else:
-            msg = "Data type %s could not be parsed in zeros op" % (type(data))
+        if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)):
+            msg = "Data type %s could not be parsed in full op" % (type(data))
             raise AssertionError(msg)

         if inputs[2] is not None:  # dtype given
@@ -554,12 +663,12 @@ def _impl(inputs, input_types):
             # if dtype is None, torch uses a global default set by torch.set_default_tensor_type()
             dtype = default_dtype

-        return _op.full(_expr.const(fill_value), shape, dtype=dtype)
+        return _full_impl(data, fill_value, dtype)

     return _impl


-def _full_like():
+def _full_like(default_dtype):
     def _impl(inputs, input_types):
         data = inputs[0]
         fill_value = inputs[1]
@@ -567,7 +676,11 @@ def _impl(inputs, input_types):
         out = _op.full_like(data, _expr.const(fill_value))

         # If the input and the output datatype is different, do a cast
-        dtype = _convert_dtype_value(inputs[2])
+        if inputs[2] is not None:  # dtype given
+            dtype = _convert_dtype_value(inputs[2])
+        else:
+            # if dtype is None, torch uses a global default set by torch.set_default_tensor_type()
+            dtype = default_dtype
         if input_types[0] not in dtype:
             out = _op.cast(out, dtype)

@@ -1100,55 +1213,60 @@ def _impl(inputs, input_types):
 def _flatten():
     def _impl(inputs, input_types):
         data = inputs[0]
-        start_dim = inputs[1] if len(inputs) > 0 else 0
-        end_dim = inputs[2] if len(inputs) > 1 else -1
-
-        if start_dim == 0 and end_dim == -1:
-            return _op.transform.reshape(data, (-1,))
-        if start_dim == 1 and end_dim == -1:
-            return _op.nn.batch_flatten(data)
-
-        raise NotImplementedError("Only support 1d flatten or batch flatten")
+        start = int(inputs[1])
+        end = int(inputs[2])
+        dshape = get_const_tuple(_infer_shape(data))
+        ndim = len(dshape)
+        if end < 0:
+            end += ndim
+        new_shape = [0] * start
+
+        new_shape.append(-1)
+        squeeze_axes = []
+        for i in range(start + 1, end + 1):
+            new_shape.append(1)
+            squeeze_axes.append(i)
+        for _ in range(end + 1, ndim):
+            new_shape.append(0)
+        out = _op.reshape(data, new_shape)
+        if squeeze_axes:
+            out = _op.squeeze(out, axis=squeeze_axes)
+        return out

     return _impl


-def _dense():
+def _addmm():
     def _impl(inputs, input_types):
-        use_bias = isinstance(inputs[0], _expr.Expr)
-
-        data = inputs[1]
+        input_mat = inputs[0]
+        mat1 = inputs[1]
         data_type = input_types[1]
-        weight = inputs[2]
+        mat2 = inputs[2]

         beta = inputs[3]
         alpha = inputs[4]

         if not isinstance(alpha, _expr.Expr) and alpha != 1:
             alpha = _create_typed_const(alpha, data_type)
-            data *= alpha
+            mat1 *= alpha

         if not isinstance(beta, _expr.Expr) and beta != 1:
             beta = _create_typed_const(beta, data_type)
-            weight *= beta
+            mat2 *= beta

-        weight_out = _op.transform.transpose(weight, axes=[1, 0])
+        transposed_mat2 = _op.transform.transpose(mat2, axes=[1, 0])

-        units = _infer_shape(weight_out)[0]
-        dense_out = _op.nn.dense(data, weight_out, units=units)
+        units = _infer_shape(transposed_mat2)[0]
+        dense_out = _op.nn.dense(mat1, transposed_mat2, units=units)

-        if use_bias:
-            bias = inputs[0]
-            return _op.nn.bias_add(dense_out, bias)
-        else:
-            return dense_out
+        return dense_out + input_mat

     return _impl


 def _size(prelude):
     def _impl_dynamic(inp, axis):
-        shape_dynamic = _op.shape_of(inp)
+        shape_dynamic = _op.shape_of(inp, dtype="int32")
         if axis is not None:
             return _op.take(shape_dynamic, _expr.const(axis), 0)
         return shape_dynamic
@@ -1164,8 +1282,8 @@ def _impl(inputs, input_types):
             return _impl_dynamic(inputs[0], axis)

         if axis is not None:
-            return shape[axis]
-        return shape
+            return _expr.const(shape[axis])
+        return _expr.const(shape)

     return _impl
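A concrete trace makes the new `_flatten` easier to follow: flattening dims 1..2 of a (2, 3, 4, 5) tensor builds `new_shape = [0, -1, 1, 0]` (0 copies a dim, -1 absorbs the flattened extent, each 1 is a placeholder for a merged dim), giving (2, 12, 1, 5), and then squeezes the placeholder axes to get (2, 12, 5). The same arithmetic checked in plain PyTorch:

    import torch

    x = torch.randn(2, 3, 4, 5)
    # relay side: reshape(x, [0, -1, 1, 0]) -> (2, 12, 1, 5); squeeze(axis=[2]) -> (2, 12, 5)
    assert torch.flatten(x, 1, 2).shape == (2, 12, 5)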
@@ -1220,12 +1338,34 @@ def _impl(inputs, input_types):
 def _reshape():
     def _impl(inputs, input_types):
         data = inputs[0]
-        if _is_int_seq(inputs[1]):
-            new_shape = inputs[1]
+        new_shape = inputs[1]
+
+        tmp_shape = []
+        is_dyn = False
+        for s in new_shape:
+            if isinstance(s, _expr.Constant):
+                tmp_shape.append(int(s.data.asnumpy()))
+            elif isinstance(s, _expr.Expr):
+                try:
+                    dim = int(_infer_value(s, {}).asnumpy())
+                    tmp_shape.append(dim)
+                except Exception:
+                    is_dyn = True
+                    tmp_shape.append(s)
+            else:
+                tmp_shape.append(s)
+
+        if is_dyn:
+            new_shape = []
+            for i, s in enumerate(tmp_shape):
+                if not isinstance(s, _expr.Expr):
+                    s = _expr.const(s, "int64")
+                else:
+                    s = _op.cast(s, "int64")
+                new_shape.append(_op.expand_dims(s, axis=0))
+            new_shape = _op.concatenate(new_shape, axis=0)
         else:
-            assert isinstance(inputs[1], list)
-            infer_res = [_infer_value(_wrap_const(size), {}) for size in inputs[1]]
-            new_shape = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
+            new_shape = tmp_shape
         return _op.transform.reshape(data, new_shape)

     return _impl
@@ -1577,12 +1717,11 @@ def _impl(inputs, input_types):
 def _expand():
     def _impl(inputs, input_types):
         data_in = inputs[0]
-        if isinstance(data_in, _expr.Expr):
-            shape = list(_infer_shape(data_in))
+        shape = list(_infer_shape(data_in))

         ndims = len(shape)
         sizes = inputs[1]
-        out = inputs[0]
+        out = data_in

         out_dims = len(sizes)
         if ndims < out_dims:
@@ -1590,14 +1729,11 @@ def _impl(inputs, input_types):
             out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis)
             shape = [1] * num_newaxis + shape

-        for i in range(ndims):
-            if sizes[i] == -1 or sizes[i] == shape[i]:
-                continue
-            data = list()
-            for temp in range(sizes[i]):
-                data.append(out)
-
-            out = _op.tensor.concatenate(data, i)
+        for i in range(out_dims):
+            if sizes[i] != -1 and shape[i] == 1:
+                if not isinstance(sizes[i], int):
+                    sizes[i] = int(_infer_value(sizes[i], {}).asnumpy())
+                out = _op.repeat(out, sizes[i], axis=i)

         return out

@@ -1652,10 +1788,18 @@ def _impl(inputs, input_types):
         # group into tuple of 2 ints
         paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)]

+        const_paddings = []
+        for pad in paddings:
+            const_paddings.append([])
+            for p in pad:
+                if not isinstance(p, int):
+                    p = int(_infer_value(p, {}).asnumpy())
+                const_paddings[-1].append(p)
+
         if mode == "constant":
-            return _op.nn.pad(data, paddings, pad_value=inputs[2], pad_mode=mode)
+            return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode)
         else:
-            return _op.nn.pad(data, paddings, pad_mode=mode)
+            return _op.nn.pad(data, const_paddings, pad_mode=mode)

     return _impl
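The rewritten `_expand` replaces the old concatenate loop with a single `repeat` per axis, and it only fires on axes whose input extent is 1. That matches `torch.expand`'s broadcasting rule: a non-1 dim can only be kept as-is or passed as -1. For instance:

    import torch

    x = torch.randn(1, 3)
    assert x.expand(4, 3).shape == (4, 3)   # relay: repeat(x, repeats=4, axis=0)
    assert x.expand(-1, 3).shape == (1, 3)  # -1 keeps the original extent; no repeat emitted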
@@ -1673,36 +1817,40 @@ def _impl(inputs, input_types):
 def _to():
     def _impl(inputs, input_types):
         data = inputs[0]
-        if inputs[3] in ["cpu", "cuda"]:
-            return data
+        dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) else inputs[2]

         # special handling for aten::to(data, 6, _, _, _) case
         # 6 means dtype = float
         # this happens when converting upsampling with scale factor
-        cast_func = {6: float, 7: float, 3: int, 4: int}
-        cast_func_expr = {
-            6: lambda x: _op.cast(x, "float32"),
-            7: lambda x: _op.cast(x, "float64"),
-            3: lambda x: _op.cast(x, "int32"),
-            4: lambda x: _op.cast(x, "int64"),
+        cast_map = {
+            6: "float32",
+            7: "float64",
+            3: "int32",
+            4: "int64",
         }
-        if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
-            return cast_func[inputs[1]](data)
-        elif inputs[1] in cast_func_expr and isinstance(data, _expr.Expr):
-            return cast_func_expr[inputs[1]](data)
-        return data
+
+        cast_func = {6: float, 7: float, 3: int, 4: int}
+
+        ret = data
+        if isinstance(data, _expr.Expr):
+            actual_dtype = str(_infer_type(data).checked_type.dtype)
+            if dtype in cast_map and cast_map[dtype] != actual_dtype:
+                ret = _op.cast(data, cast_map[dtype])
+        elif dtype in cast_map:
+            ret = cast_func[dtype](data)
+
+        return ret

     return _impl


 def _upsample(method, prelude):
     def _impl(inputs, input_types):
-        if isinstance(inputs[1], _expr.Var):
-            out_size = _infer_shape(inputs[1])
-        elif _is_int_seq(inputs[1]):
-            out_size = inputs[1]
-        elif isinstance(inputs[1], list):
-            infer_res = [_infer_value(size, {}) for size in inputs[1]]
-            out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
+        out_size = []
+        for size in inputs[1]:
+            if not isinstance(size, int):
+                out_size.append(int(_infer_value(size, {}).asnumpy()))
+            else:
+                out_size.append(size)

         data = inputs[0]

@@ -1772,11 +1920,12 @@ def _impl(inputs, input_types):
 def _expand_as():
     def _impl(inputs, input_types):
-        # TODO: maybe fix this
-        # This assumes expand_as can be removed because TVM has broadcast op
-        msg = "aten::expand_as(...) found, assume it is part of broadcast op"
-        logging.warning(msg)
-        return inputs[0]
+        target = inputs[1]
+        t0 = _infer_type(inputs[0]).checked_type.dtype
+        t1 = _infer_type(inputs[1]).checked_type.dtype
+        if str(t0) != str(t1):
+            target = _op.cast(target, t0)
+        return _op.broadcast_to_like(inputs[0], target)

     return _impl

@@ -2047,6 +2196,148 @@ def _impl(inputs, input_types):
     return _impl


+def _roi_align(prelude):
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        boxes = inputs[1]
+
+        output_size = (inputs[3], inputs[4])
+        spatial_scale = inputs[2]
+        sample_ratio = inputs[5]
+        aligned = False if len(inputs) < 7 else inputs[6]
+
+        if aligned:
+            boxes -= _expr.const(0.5 / spatial_scale)
+
+        return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio)
+
+    return _impl
+
+
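The `aligned=True` branch mirrors torchvision's "aligned" ROI Align, which shifts boxes by half a pixel after scaling so bin centers line up with feature-map sample points. Relay's `roi_align` has no such flag here, so the converter bakes the offset into the box coordinates instead; the subtraction happens in input-image coordinates, hence the division by `spatial_scale`:

    # torchvision computes: rois * spatial_scale - 0.5   (when aligned=True)
    # relay computes:       (rois - 0.5 / spatial_scale) * spatial_scale
    # which is algebraically the same thing:
    spatial_scale = 0.25
    roi = 10.0
    assert (roi - 0.5 / spatial_scale) * spatial_scale == roi * spatial_scale - 0.5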
+def _unbind():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        dim = int(inputs[1])
+        ishapes = _infer_shape(data)
+        if dim >= len(ishapes):
+            msg = "Please check input dim, it shouldn't be greater than or equal to rank."
+            raise AttributeError(msg)
+
+        selections = ishapes[dim]
+        res_split = _op.split(data, selections, dim)
+        # squeeze each split piece to get same shape as aten::unbind
+        # TODO (yongwww): add new op to avoid the squeeze overhead
+        ret = []
+        for i in range(selections):
+            ret.append(_op.transform.squeeze(res_split[i], axis=[dim]))
+        ret = _expr.TupleWrapper(_expr.Tuple(ret), selections)
+        return ret
+
+    return _impl
+
+
+def _shape_as_tensor(prelude):
+    def _impl(inputs, input_types):
+        is_symbolic_shape = False
+        input_shape = _infer_shape(inputs[0], prelude.mod)
+        for axis in input_shape:
+            if not isinstance(axis, (int, tvm.tir.IntImm)):
+                is_symbolic_shape = True
+                break
+
+        if is_symbolic_shape:
+            ret = _op.shape_of(inputs[0], dtype="int64")
+        else:
+            ret = _expr.const(np.array(input_shape), dtype="int64")
+
+        return ret
+
+    return _impl
+
+
+def _logical_and():
+    def _impl(inputs, input_types):
+        lhs = _op.cast(inputs[0], "bool")
+        rhs = _op.cast(inputs[1], "bool")
+
+        return _op.logical_and(lhs, rhs)
+
+    return _impl
+
+
+def _nonzero(is_numpy_style):
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        ret = _op.transform.argwhere(data)
+
+        if is_numpy_style or (len(inputs) > 1 and inputs[1]):
+            # TODO(kevinthesun): Support this by adding unbind op
+            # ret = _unbind()([ret, 0], None)
+            raise RuntimeError("as_tuple is not supported yet for nonzero.")
+        return ret
+
+    return _impl
+
+
+def _scatter():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        axis = int(inputs[1])
+        index = inputs[2]
+        src = inputs[3]
+        return _op.transform.scatter(data, index, src, axis)
+
+    return _impl
+
+
+def _scalar_tensor():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        cast_map = {
+            6: "float32",
+            7: "float64",
+            3: "int32",
+            4: "int64",
+        }
+        type_key = inputs[1]
+        if isinstance(data, _expr.Constant):
+            data = data.data.asnumpy().tolist()
+        return _expr.const(data, cast_map[type_key])
+
+    return _impl
+
+
+def _interpolate():
+    def _impl(inputs, input_types):
+        if isinstance(inputs[1], _expr.Expr):
+            out_size = inputs[1]
+        elif isinstance(inputs[1], list):
+            try:
+                infer_res = [_infer_value(size, {}) for size in inputs[1]]
+                out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
+            except Exception:
+                h = _op.expand_dims(inputs[1][0], axis=0)
+                w = _op.expand_dims(inputs[1][1], axis=0)
+                out_size = _op.concatenate([h, w], axis=0)
+
+        data = inputs[0]
+        align_corners = inputs[4]
+        method = inputs[3]
+        if method.startswith("nearest"):
+            method = "nearest_neighbor"
+
+        if method == "nearest_neighbor":
+            coord_trans = "asymmetric"
+        elif align_corners:
+            coord_trans = "align_corners"
+        else:
+            coord_trans = "half_pixel"
+
+        return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
+
+    return _impl
+
+
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
     import torch
@@ -2087,6 +2378,14 @@ def _pytorch_result_type(dtypes, non_tensor_inputs):

 def _pytorch_promote_types(inputs, dtypes):
     """This promotes TVM inputs with TVM dtypes passed like PyTorch would"""
+    actual_dtypes = []
+    for i, inp in enumerate(inputs):
+        if isinstance(inp, _expr.Expr):
+            idt = _infer_type(inp).checked_type.dtype
+            actual_dtypes.append(idt)
+        else:
+            actual_dtypes.append(dtypes[i])
+    dtypes = actual_dtypes
     tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)]
     non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)]
     result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs)
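`_pytorch_promote_types` previously trusted the dtypes recorded in the TorchScript graph. With converters now emitting intermediate casts (and `int64` topk indices), an expression's actual dtype can differ from what the graph recorded, so re-reading the checked type is safer. A minimal illustration of reading the actual dtype, using the public `infer_type` helper:

    from tvm import relay
    from tvm.relay.frontend.common import infer_type

    x = relay.var("x", shape=(2, 2), dtype="float32")
    y = relay.cast(x, "int64")
    assert infer_type(y).checked_type.dtype == "int64"  # not the recorded float32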
@@ -2154,6 +2453,8 @@ def _convert_data_type(input_type, default_dtype=None):
         return "qint32"
     elif input_type in ["bool", "torch.bool"]:
         return "bool"
+    elif input_type in ["str"]:
+        return "str"
     else:
         raise NotImplementedError("input_type {} is not handled yet".format(input_type))
     return "float32"  # Never reached
@@ -2210,12 +2511,12 @@ def _get_convert_map(prelude, default_dtype):
         "aten::floor_divide": _elemwise("floor_divide"),
         "aten::addcdiv": _addcdiv(),
         "aten::addcmul": _addcmul(),
-        "aten::ones": _ones(),
-        "aten::ones_like": _ones_like(),
-        "aten::zeros": _zeros(),
-        "aten::zeros_like": _zeros_like(),
+        "aten::ones": _ones(default_dtype),
+        "aten::ones_like": _ones_like(default_dtype),
+        "aten::zeros": _zeros(default_dtype),
+        "aten::zeros_like": _zeros_like(default_dtype),
         "aten::full": _full(default_dtype),
-        "aten::full_like": _full_like(),
+        "aten::full_like": _full_like(default_dtype),
         "aten::linspace": _linspace(),
         "aten::reciprocal": _reciprocal(),
         "aten::repeat": _repeat(),
@@ -2263,7 +2564,7 @@ def _get_convert_map(prelude, default_dtype):
         "aten::transpose_": _transpose(prelude),
         "aten::t": _transpose(prelude),
         "aten::flatten": _flatten(),
-        "aten::addmm": _dense(),
+        "aten::addmm": _addmm(),
         "aten::size": _size(prelude),
         "aten::view": _view(),
         "aten::reshape": _reshape(),
@@ -2364,6 +2665,15 @@ def _get_convert_map(prelude, default_dtype):
         "aten::index": _index(),
         "torchvision::nms": _nms(prelude),
         "aten::logsumexp": _logsumexp(),
+        "torchvision::roi_align": _roi_align(prelude),
+        "aten::unbind": _unbind(),
+        "aten::__and__": _logical_and(),
+        "aten::_shape_as_tensor": _shape_as_tensor(prelude),
+        "aten::nonzero": _nonzero(False),
+        "aten::nonzero_numpy": _nonzero(True),
+        "aten::scatter": _scatter(),
+        "aten::scalar_tensor": _scalar_tensor(),
+        "aten::__interpolate": _interpolate(),
     }
     return convert_map
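Registration is the only wiring a new converter needs: `_get_convert_map` keys are the op names exactly as they appear in the traced TorchScript graph (e.g. `torchvision::roi_align`), and each maps to an `_impl` closure. To see which names a given model actually requires, one can inspect the traced graph directly; a rough sketch, with the specific module and the printed set purely illustrative:

    import torch

    traced = torch.jit.trace(torch.nn.Conv2d(3, 8, 3), torch.randn(1, 3, 8, 8))
    # Each aten::/torchvision:: kind here must have an entry in the convert map.
    print({n.kind() for n in traced.graph.nodes()})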
@@ -2512,7 +2822,7 @@ def _get_constant(node):
             # TODO(t-vi): When is this needed?
             return tensor.item()
         return _wrap_const(tensor.numpy())
-    elif ty == "DeviceObjType":
+    elif ty in ["DeviceObjType", "StringType"]:
         return node.s(attr_name)
     elif ty == "FunctionType":
         return None
@@ -2799,7 +3109,17 @@ def get_input(index):
     def get_var(name, val):
         if val:
             checked_type = _infer_type_with_prelude(val, prelude)
-            return _expr.var(name, type_annotation=checked_type)
+            if hasattr(checked_type, "shape"):
+                shape = get_const_tuple(checked_type.shape)
+                actual_shape = []
+                for dim in shape:
+                    if isinstance(dim, int) and dim == 0:
+                        actual_shape.append(Any())
+                    else:
+                        actual_shape.append(dim)
+                return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)
+            else:
+                return _expr.var(name, type_annotation=checked_type)
         return _expr.var(name)

     loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype)
@@ -2816,7 +3136,7 @@ def get_var(name, val):
         var
         for var in _get_free_vars_from_block(body_block)
         if var in outputs
-        and not isinstance(outputs[var], (_expr.Constant, int, float))
+        and not isinstance(outputs[var], (_expr.Constant, int, float, str))
         and outputs[var]
     ]

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index fe7be5b3fa8a..e8a8507158a3 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1679,6 +1679,36 @@ def _gen_rand_inputs(num_boxes):
     verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores], targets)


+def test_forward_roi_align():
+    """ROI align"""
+    torch.set_grad_enabled(False)
+
+    class ROIAlign(Module):
+        def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1):
+            super().__init__()
+            self.spatial_scale = spatial_scale
+            self.sampling_ratio = sampling_ratio
+            self.output_sizes = output_sizes
+
+        def forward(self, *args):
+            return torchvision.ops.roi_align(
+                args[0],
+                args[1],
+                self.output_sizes,
+                self.spatial_scale,
+                self.sampling_ratio,
+            )
+
+    in_data = torch.Tensor(np.random.uniform(size=(1, 8, 100, 100)))
+    in_boxes = torch.Tensor(np.random.uniform(0.0, 100.0, size=(35, 4)))
+    in_batch = torch.zeros((35, 1), dtype=torch.float)
+    in_boxes = torch.cat([in_batch, in_boxes], dim=1)
+
+    verify_model(ROIAlign(7), [in_data, in_boxes])
+    verify_model(ROIAlign((10, 10), 0.7, 5), [in_data, in_boxes])
+    verify_model(ROIAlign(15, 0.9, 3), [in_data, in_boxes])
+
+
 @tvm.testing.uses_gpu
 def test_conv3d():
     for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), (1, 32, 13, 7, 7)]:
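The `get_var` change above turns any zero-sized dim reported by type inference into `Any()`, relay's symbolic dimension, so loop-carried tensors whose extent changes across iterations (typical in the while-loops these detection models trace to, e.g. a growing box list) get a relaxed type instead of a hard 0. Constructing such a var directly looks like:

    from tvm import relay

    # A tensor whose leading extent is unknown at compile time.
    boxes = relay.var("boxes", shape=(relay.Any(), 4), dtype="float32")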
@@ -3025,6 +3055,57 @@ def forward(self, x):
     verify_script_model(Stack(), [(8, 8, 8)], _get_default_vm_targets())


+def test_forward_unbind():
+    class Unbind(torch.nn.Module):
+        def __init__(self, axis=0):
+            super().__init__()
+            self.axis = axis
+
+        def forward(self, x):
+            return torch.unbind(x, self.axis)
+
+    inp = torch.randn(8, 8, 8)
+    verify_model(Unbind(0), input_data=inp)
+    verify_model(Unbind(1), input_data=inp)
+    verify_model(Unbind(2), input_data=inp)
+
+
+def test_forward_nonzero():
+    class Nonzero(Module):
+        def __init__(self, as_tuple=False):
+            super().__init__()
+            self.as_tuple = as_tuple
+
+        def forward(self, data):
+            return torch.nonzero(data, as_tuple=self.as_tuple)
+
+    inp = torch.Tensor(np.array([[0, 1, 0], [2, 0, 9], [-1, -1, 0]]).astype("float32"))
+    verify_trace_model(Nonzero(), [inp], ["llvm"])
+
+
+def test_forward_scatter():
+    class Scatter(Module):
+        def __init__(self, dim=0):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, data, index, src):
+            return torch.scatter(data, dim=self.dim, index=index, src=src)
+
+    in_data = torch.zeros(3, 5)
+    in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
+    in_src = torch.rand(2, 5)
+    # TODO: add scatter gpu schedule to enable gpu test.
+    verify_trace_model(Scatter(), [in_data, in_index, in_src], ["llvm"])
+
+    in_data = torch.zeros(2, 4)
+    in_index = torch.tensor([[2], [3]])
+    in_src = torch.rand(2, 1)
+
+    # TODO: add scatter gpu schedule to enable gpu test.
+    verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"])
+
+
 def test_forward_pretrained_bert_base_uncased():
     ######################################################################
     # This is an example how to run BERT models using TVM
@@ -3264,6 +3345,7 @@ def test_forward_pretrained_bert_base_uncased():
     test_upsample()
     test_forward_upsample3d()
     test_forward_nms()
+    test_forward_roi_align()
     test_to()
     test_flatten()
     test_type_as()
@@ -3285,6 +3367,9 @@ def test_forward_pretrained_bert_base_uncased():
     test_logsumexp()
     test_stack()
     test_stack_dynamic()
+    test_forward_unbind()
+    test_forward_nonzero()
+    test_forward_scatter()

     # Model tests
     test_resnet18()
@@ -3314,7 +3399,7 @@ def test_forward_pretrained_bert_base_uncased():
     test_simple_rnn()

     # More complex recurrent models
-    from lstm_test import test_custom_lstm
+    from test_lstm import test_custom_lstm

     test_custom_lstm()
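For reference, `torch.scatter` with `dim=0` writes `src[i][j]` to `out[index[i][j]][j]` (and to `out[i][index[i][j]]` when `dim=1`), which is the semantics the relay `scatter` op under test implements. Working the first test case with a deterministic `src` instead of `torch.rand`:

    import torch

    src = torch.arange(1, 11, dtype=torch.float).reshape(2, 5)
    index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
    out = torch.zeros(3, 5).scatter(0, index, src)
    # index picks the destination row per column:
    # tensor([[ 1.,  7.,  8.,  4.,  5.],
    #         [ 0.,  2.,  0.,  9.,  0.],
    #         [ 6.,  0.,  3.,  0., 10.]])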
diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py
new file mode 100644
index 000000000000..f5197494a345
--- /dev/null
+++ b/tests/python/frontend/pytorch/test_object_detection.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, invalid-name, unused-argument
+"""Test torchvision faster rcnn and mask rcnn models"""
+import numpy as np
+import torch
+import torchvision
+import cv2
+
+import tvm
+
+from tvm import relay
+from tvm.runtime.vm import VirtualMachine
+from tvm.contrib.download import download
+
+
+in_size = 300
+
+
+def process_image(img):
+    img = cv2.imread(img).astype("float32")
+    img = cv2.resize(img, (in_size, in_size))
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img = torch.from_numpy(img / 255.0).permute(2, 0, 1).float()
+    img = torch.unsqueeze(img, axis=0)
+
+    return img
+
+
+def do_trace(model, inp, in_size=in_size):
+    model_trace = torch.jit.trace(model, inp)
+    model_trace.eval()
+    return model_trace
+
+
+def dict_to_tuple(out_dict):
+    if "masks" in out_dict.keys():
+        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
+    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]
+
+
+class TraceWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, inp):
+        out = self.model(inp)
+        return dict_to_tuple(out[0])
+
+
+def generate_jit_model(index):
+    model_funcs = [
+        torchvision.models.detection.fasterrcnn_resnet50_fpn,
+        torchvision.models.detection.maskrcnn_resnet50_fpn,
+    ]
+
+    model_func = model_funcs[index]
+    model = TraceWrapper(model_func(pretrained=True))
+
+    model.eval()
+    inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))
+
+    with torch.no_grad():
+        out = model(inp)
+
+    script_module = do_trace(model, inp)
+    script_out = script_module(inp)
+
+    assert len(out[0]) > 0 and len(script_out[0]) > 0
+    return script_module
+
+
+def test_detection_models():
+    img = "test_street_small.jpg"
+    img_url = (
+        "https://raw.githubusercontent.com/dmlc/web-data/"
+        "master/gluoncv/detection/street_small.jpg"
+    )
+    download(img_url, img)
+
+    input_shape = (1, 3, in_size, in_size)
+    target = "llvm"
+    input_name = "input0"
+    shape_list = [(input_name, input_shape)]
+    score_threshold = 0.9
+
+    scripted_model = generate_jit_model(1)
+    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
+        vm_exec = relay.vm.compile(mod, target=target, params=params)
+
+    ctx = tvm.cpu()
+    vm = VirtualMachine(vm_exec, ctx)
+    data = process_image(img)
+    pt_res = scripted_model(data)
+    data = data.detach().numpy()
+    vm.set_input("main", **{input_name: data})
+    tvm_res = vm.run()
+
+    # Note: due to accumulated numerical error, we can't directly compare results
+    # with pytorch output. Some boxes might have a quite tiny difference in score
+    # and the order can become different. We just measure how many valid boxes
+    # there are for the input image.
+    pt_scores = pt_res[1].detach().numpy().tolist()
+    tvm_scores = tvm_res[1].asnumpy().tolist()
+    num_pt_valid_scores = num_tvm_valid_scores = 0
+
+    for score in pt_scores:
+        if score >= score_threshold:
+            num_pt_valid_scores += 1
+        else:
+            break
+
+    for score in tvm_scores:
+        if score >= score_threshold:
+            num_tvm_valid_scores += 1
+        else:
+            break
+
+    assert num_pt_valid_scores == num_tvm_valid_scores, (
+        "Output mismatch: Under score threshold {}, Pytorch has {} valid "
+        "boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores)
+    )
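One non-obvious piece of this test is `TraceWrapper`: `torch.jit.trace` only handles tensors and (nested) tuples of tensors as outputs, while the torchvision detection models return a list with one dict per image, so the wrapper unpacks `out[0]` into a plain tuple before tracing. The same trick applies to any dict-returning model; a hedged sketch with a hypothetical model returning keys "a" and "b":

    import torch

    class DictToTuple(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            out = self.model(x)
            # Tuple outputs trace cleanly; dict outputs need strict=False or a wrapper.
            return out["a"], out["b"]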