diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py
index 0ed8a32eeb89..b5336cd92c51 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -85,7 +85,7 @@ class BlockBuilder(Object):
                 gv0 = bb.emit_output(lv1)
             bb.emit_func_output(gv0)
         mod = bb.get()
-
+
     BlockBuilder can also be used to contruct neural networks with nn.Module API
 
     .. code-block:: python
@@ -115,7 +115,7 @@ class BlockBuilder(Object):
     """
 
     _current = None
-
+
     @staticmethod
     def current():
         """Returns the current BlockBuilder."""
@@ -125,18 +125,18 @@ def __init__(self):
         self._blocks = []
         self._context_mod = tvm.IRModule()
         # a boolean flag that tracks if emit_func_output has been called
-        self._is_emit_func_output_called = False;
+        self._is_emit_func_output_called = False
         self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate)
 
     def _begin_dataflow_block(self) -> None:
         _ffi_api.BlockBuilderBeginDataflowBlock(self)
-
+
     def _begin_binding_block(self) -> None:
         _ffi_api.BlockBuilderBeginBindingBlock(self)
 
     def _end_block(self) -> BindingBlock:
         return _ffi_api.BlockBuilderEndBlock(self)
-
+
     def _enter_function_scope(self, name, params):
         if BlockBuilder.current() is not None:
             raise RuntimeError("BlockBuilder does not allow nested functions.")
@@ -144,22 +144,20 @@ def _enter_function_scope(self, name, params):
         self._func_name = name
         self._func_params = params
         self._begin_binding_block()
-
+
     def _exit_function_scope(self, exc_type, exc_val, exc_tb):
         if exc_type is None:
             if not self._is_emit_func_output_called:
                 raise RuntimeError("emit_func_output must be called in a relax function.")
-
+
         self._is_emit_func_output_called = False
         BlockBuilder._current = None
 
-    def _convert_te_arg(self,
-                        te_args: Any
-                        ) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
+    def _convert_te_arg(self, te_args: Any) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
         """Helper function to convert Relax expressions to te tensor.
 
         In the common case, the type of te_args is a Relax expression and is converted into a te tensor.
-        If te_args is a nested or recursive datatype (i.e list, dict, tvm.ir.Map, tvm.ir.Array),
-        we recursive and convert any value of type Relax expression into a te tensor.
+        If te_args is a nested or recursive datatype (i.e., list, dict, tvm.ir.Map, tvm.ir.Array),
+        we recurse and convert any value of type Relax expression into a te tensor.
         Common values of type int, float, and str are preserved.
 
         Parameters
@@ -185,7 +183,9 @@ def _convert_te_arg_helper(arg):
                 return tuple([_convert_te_arg_helper(x) for x in arg])
             elif isinstance(arg, (dict, tvm.ir.Map)):
                 for key in arg:
-                    assert isinstance(key, str), "emit_te only supports dict with string as the key currently"
+                    assert isinstance(
+                        key, str
+                    ), "emit_te only supports dict with string as the key currently"
                 return {k: _convert_te_arg_helper(arg[k]) for k in arg}
             elif isinstance(arg, (int, float, str)):
                 return arg
@@ -213,9 +213,9 @@ def _populate_used_vars(expr):
         diff = used_vars - bound_vars
         return list(diff)
 
-    def function(self,
-                 name: str,
-                 params: Optional[Union[Var, Tuple, List[Var]]] = None) -> FunctionScope:
+    def function(
+        self, name: str, params: Optional[Union[Var, Tuple, List[Var]]] = None
+    ) -> FunctionScope:
         """Annotate a Relax function.
 
         Parameters
@@ -239,8 +239,12 @@ def function(self,
         elif isinstance(params, (list, tuple)):
             for param in params:
                 if not isinstance(param, rx.Var):
-                    raise TypeError("each element of function parameters must be of type tvm.relax.Var,\
-                                    but got: {}".format(type(param)))
+                    raise TypeError(
+                        "each element of function parameters must be of type tvm.relax.Var,\
+                    but got: {}".format(
+                            type(param)
+                        )
+                    )
 
         name = self.get_unique_name(name)
         return FunctionScope(self, name, params)
@@ -297,12 +301,12 @@ def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
             type_anno = rx.DynTensorType(2, "float32")
             x = rx.Var("x", [n, m], type_anno)
             y = rx.Var("y", [n, m], type_anno)
-
+
             def te_func(args, args_dict, msg):
                 A = args[0]
                 B = args_dict["B"]
                 return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
-
+
             with bb.function([x, y], "rx_func"):
                 out = bb.emit_te(te_func, [x], {"B": y}, msg="hello")
                 bb.emit_func_output(out)
@@ -389,8 +393,10 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"]) -> Ten
         te_args = te_arg_list + te_kwarg_list
 
         te_out = func(*new_args, **new_kwargs)
-        assert (isinstance(te_out, tvm.te.tensor.Tensor) or \
-            (isinstance(te_out, (tuple, list) and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)))), "only support te.tensor or tuple/list of te.tensor as function output"
+        assert isinstance(te_out, tvm.te.tensor.Tensor) or (
+            isinstance(te_out, (tuple, list))
+            and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)
+        ), "only support te.tensor or tuple/list of te.tensor as function output"
 
         outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out)
         unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)
@@ -402,15 +408,18 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"]) -> Ten
         self._context_mod[gvar] = tir_func
 
         call_args = [x.op.value for x in te_args]
-        output_shape = outs[0].shape if isinstance(te_out, tvm.te.tensor.Tensor) else Tuple([ShapeExpr(x.shape) for x in outs])
+        output_shape = (
+            outs[0].shape
+            if isinstance(te_out, tvm.te.tensor.Tensor)
+            else Tuple([ShapeExpr(x.shape) for x in outs])
+        )
         # add arguments for extra parameters from unbound var
-        if (len(unbound_tir_vars) > 0):
+        if len(unbound_tir_vars) > 0:
             call = call_tir(output_shape, gvar, call_args, tir_vars=ShapeExpr(unbound_tir_vars))
         else:
             call = call_tir(output_shape, gvar, call_args)
         return _ffi_api.BlockBuilderEmit(self, call)
 
-
     def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
         """Emit a MatchShape.
 
@@ -446,16 +455,18 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
             output = Tuple(output)
         return _ffi_api.BlockBuilderEmitOutput(self, output)
 
-    def emit_func_output(self,
-                         output: Union[Expr, Tuple, List[Expr]],
-                         params: Optional[Union[Var, Tuple, List[Var]]] = None) -> None:
+    def emit_func_output(
+        self,
+        output: Union[Expr, Tuple, List[Expr]],
+        params: Optional[Union[Var, Tuple, List[Var]]] = None,
+    ) -> None:
         """Emit output for the function.
 
         Parameters
         ----------
         output : Expr | Tuple | List[Expr]
             The output of the current block/function.
-
+
         params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
             The parameters of the function to be built.
             If params is None, it means the params have been initialized in the function with scope.
@@ -470,7 +481,9 @@ def emit_func_output(self,
         self._is_emit_func_output_called = True
 
         if self._func_params is not None and params is not None:
-            raise RuntimeError("function parameters have been initialized in the function with scope.")
+            raise RuntimeError(
+                "function parameters have been initialized in the function with scope."
+            )
 
         if self._func_params is None and params is None:
             raise RuntimeError("Relax function must have parameter.")
@@ -484,7 +497,7 @@ def emit_func_output(self,
         if isinstance(output, (list, tuple)):
             output = Tuple(output)
         self._func_ret = output
-
+
         block = self._end_block()
         if len(block.bindings) > 0:
             self._blocks.append(block)
@@ -521,7 +534,6 @@ def get(self) -> tvm.IRModule:
         """
         return self._context_mod
 
-
     def get_unique_name(self, name_prefix: str) -> str:
         """Generate a unique name with a specified prefix.
 
diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py
index 8ce7b5e5c35e..ca0f7a1b958a 100644
--- a/tests/python/relax/test_parser.py
+++ b/tests/python/relax/test_parser.py
@@ -116,12 +116,13 @@ def f(x: Tensor[_, "float32"]):
     check_call(value, "relax.shape_of", [f.params[0]])
 
 
-@pytest.mark.xfail
 def test_dim_var_intro_fail():
-    @R.function
-    def f(x: Tensor[_, _]):
-        y: Tensor[(n, m), "float32"] = x
-        return y
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x: Tensor[_, _]):
+            y: Tensor[(n, m), "float32"] = x
+            return y
 
 
 def test_if():
@@ -165,61 +166,66 @@ def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
 # TODO: figure out if-else binding type and shape
 
 
-@pytest.mark.xfail
 def test_var_redefine_fail():
-    @R.function
-    def f(x, y):
-        z = add(x, y)
-        y = z
-        return y
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x, y):
+            z = add(x, y)
+            y = z
+            return y
 
 
-@pytest.mark.xfail
 def test_var_redefine_fail_if():
-    @R.function
-    def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
-        y = x
-        if cond:
-            w = add(x, x)
-            y = multiply(w, w)
-        else:
-            w = multiply(x, x)
-            y = add(w, w)
-        return y
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+            y = x
+            if cond:
+                w = add(x, x)
+                y = multiply(w, w)
+            else:
+                w = multiply(x, x)
+                y = add(w, w)
+            return y
 
 
 @pytest.mark.xfail
 def test_var_if_scoping_fail():
-    @R.function
-    def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
-        if cond:
-            w = add(x, x)
-            y = multiply(w, w)
-        else:
-            w = multiply(x, x)
-            y = add(w, w)
-        return w
+    # TODO: fix this
+    with pytest.raises(tvm.error.DiagnosticError):
+        @R.function
+        def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+            if cond:
+                w = add(x, x)
+                y = multiply(w, w)
+            else:
+                w = multiply(x, x)
+                y = add(w, w)
+            return w
 
 
-@pytest.mark.xfail
 def test_if_mismatch_var_fail():
-    @R.function
-    def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
-        if cond:
-            w = add(x, x)
-            y = multiply(w, w)
-        else:
-            w = multiply(x, x)
-            z = add(w, w)
-        return z
+    with pytest.raises(tvm.error.DiagnosticError):
+        @R.function
+        def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+            if cond:
+                w = add(x, x)
+                y = multiply(w, w)
+            else:
+                w = multiply(x, x)
+                z = add(w, w)
+            return z
 
 
-@pytest.mark.xfail
 def test_unassigned_call_fail():
-    @R.function
-    def f(x: Tensor[_, _]):
-        add(x, x)
-        return x
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x: Tensor[_, _]):
+            add(x, x)
+            return x
 
 
 def test_tuple():
@@ -330,79 +336,81 @@ def f(x: Tensor[_, _]):
 @pytest.mark.xfail
 def test_dataflow_scope_fail():
-    @R.function
-    def f(x: Tensor[_, _]):
-        with relax.dataflow():
-            y = add(x, x)
-            z = multiply(y, x)
-            w = subtract(z, x)
-            relax.output(y, w)
-        t = divide(y, z)
-        return t
+    with pytest.raises(tvm.error.DiagnosticError):
+        # TODO
+        @R.function
+        def f(x: Tensor[_, _]):
+            with relax.dataflow():
+                y = add(x, x)
+                z = multiply(y, x)
+                w = subtract(z, x)
+                relax.output(y, w)
+            t = divide(y, z)
+            return t
 
 
-@pytest.mark.xfail
 def test_dataflow_syntax_fail_pattern():
-    @R.function
-    def f(x: Tensor[_, _]):
-        with relax.dataflow() as df:
-            y = add(x, x)
-            z = multiply(y, x)
-            w = subtract(z, x)
-            relax.output(y, z)
-        t = divide(y, z)
-        return t
+    with pytest.raises(tvm.error.DiagnosticError):
+        @R.function
+        def f(x: Tensor[_, _]):
+            with relax.dataflow() as df:
+                y = add(x, x)
+                z = multiply(y, x)
+                w = subtract(z, x)
+                relax.output(y, z)
+            t = divide(y, z)
+            return t
 
 
-@pytest.mark.xfail
 def test_dataflow_syntax_fail_params():
-    @R.function
-    def f(x: Tensor[_, _]):
-        with relax.dataflow(x) as df:
-            y = add(x, x)
-            z = multiply(y, x)
-            w = subtract(z, x)
-            relax.output(y, w)
-        t = divide(y, z)
-        return t
+    with pytest.raises(tvm.error.DiagnosticError):
+        @R.function
+        def f(x: Tensor[_, _]):
+            with relax.dataflow(x) as df:
+                y = add(x, x)
+                z = multiply(y, x)
+                w = subtract(z, x)
+                relax.output(y, w)
+            t = divide(y, z)
+            return t
 
 
-@pytest.mark.xfail
 def test_dataflow_unbound_outputs():
-    @R.function
-    def f(x: Tensor[_, _]):
-        with relax.dataflow():
-            y = add(x, x)
-            z = multiply(y, x)
-            w = subtract(z, x)
-            relax.output(x, y, w, q)
-        t = divide(y, z)
-        return t
+    with pytest.raises(tvm.error.DiagnosticError):
+        @R.function
+        def f(x: Tensor[_, _]):
+            with relax.dataflow():
+                y = add(x, x)
+                z = multiply(y, x)
+                w = subtract(z, x)
+                relax.output(x, y, w, q)
+            t = divide(y, z)
+            return t
 
 
-@pytest.mark.xfail
 def test_invalid_special_op_dataflow():
-    @R.function
-    def f(x: Tensor):
-        y = add(x, x)
-        z = relax.dataflow()
-        return z
+    with pytest.raises(tvm.error.DiagnosticError):
+        @R.function
+        def f(x: Tensor):
+            y = add(x, x)
+            z = relax.dataflow()
+            return z
 
 
-@pytest.mark.xfail
 def test_invalid_special_op_output():
-    @R.function
-    def f(x: Tensor):
-        y = add(x, x)
-        z = relax.output(y)
-        return z
+    with pytest.raises(tvm.error.DiagnosticError):
+        @R.function
+        def f(x: Tensor):
+            y = add(x, x)
+            z = relax.output(y)
+            return z
 
 
-@pytest.mark.xfail
 def test_func_no_return_fail():
-    @R.function
-    def f(x: Tensor[_, _]):
-        y = add(x, x)
+    with pytest.raises(tvm.error.DiagnosticError):
+        @R.function
+        def f(x: Tensor[_, _]):
+            y = add(x, x)
 
 
 def test_inline_tir():
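
For reference, the pattern the updated parser tests adopt (replacing @pytest.mark.xfail with an explicit pytest.raises check on tvm.error.DiagnosticError) looks like the minimal sketch below. It mirrors test_unassigned_call_fail from the diff; the tvm.script import line is an assumption about the test module's existing header, which this diff does not show.

    import pytest
    import tvm
    from tvm.script import relax as R  # assumed import; the diff only shows the @R.function usage


    def test_unassigned_call_fail():
        # Wrapping the decorated function in pytest.raises asserts that the Relax
        # parser reports a DiagnosticError for a call whose result is never bound,
        # instead of relying on @pytest.mark.xfail to tolerate any failure.
        # Names such as Tensor, _, and add are resolved by the Relax script parser,
        # not evaluated as ordinary Python identifiers.
        with pytest.raises(tvm.error.DiagnosticError):

            @R.function
            def f(x: Tensor[_, _]):
                add(x, x)
                return x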