xfail => pytest.raises; fix a unittest (apache#67)
junrushao authored and YuchenJin committed Nov 17, 2022
1 parent 6a0a2d1 commit 7f5ae9c
Showing 2 changed files with 153 additions and 133 deletions.
76 changes: 44 additions & 32 deletions python/tvm/relax/block_builder.py
@@ -85,7 +85,7 @@ class BlockBuilder(Object):
gv0 = bb.emit_output(lv1)
bb.emit_func_output(gv0)
mod = bb.get()
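A complete minimal sketch of the same pattern (illustrative only; it assumes the dataflow scope API and uses tvm.topi for the op body):
.. code-block:: python

    bb = rx.BlockBuilder()
    n = tir.Var("n", "int64")
    type_anno = rx.DynTensorType(1, "float32")
    x = rx.Var("x", [n], type_anno)
    y = rx.Var("y", [n], type_anno)
    with bb.function("func", [x, y]):
        with bb.dataflow():
            # emit_te lowers the topi computation and emits a call_tir
            lv0 = bb.emit_te(topi.add, x, y)
            gv0 = bb.emit_output(lv0)
        bb.emit_func_output(gv0)
    mod = bb.get()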
BlockBuilder can also be used to construct neural networks with the nn.Module API
.. code-block:: python
@@ -115,7 +115,7 @@ class BlockBuilder(Object):
"""

_current = None

@staticmethod
def current():
"""Returns the current BlockBuilder."""
@@ -125,41 +125,39 @@ def __init__(self):
self._blocks = []
self._context_mod = tvm.IRModule()
# a boolean flag that tracks if emit_func_output has been called
-self._is_emit_func_output_called = False;
+self._is_emit_func_output_called = False
self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate)

def _begin_dataflow_block(self) -> None:
_ffi_api.BlockBuilderBeginDataflowBlock(self)

def _begin_binding_block(self) -> None:
_ffi_api.BlockBuilderBeginBindingBlock(self)

def _end_block(self) -> BindingBlock:
return _ffi_api.BlockBuilderEndBlock(self)

def _enter_function_scope(self, name, params):
if BlockBuilder.current() is not None:
raise RuntimeError("BlockBuilder does not allow nested functions.")
BlockBuilder._current = self
self._func_name = name
self._func_params = params
self._begin_binding_block()

def _exit_function_scope(self, exc_type, exc_val, exc_tb):
if exc_type is None:
if not self._is_emit_func_output_called:
raise RuntimeError("emit_func_output must be called in a relax function.")

self._is_emit_func_output_called = False
BlockBuilder._current = None

-def _convert_te_arg(self,
-    te_args: Any
-) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
+def _convert_te_arg(self, te_args: Any) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
"""Helper function to convert Relax expressions to te tensor.
In the common case, te_args is a Relax expression and is converted into a te tensor.
If te_args is a nested or recursive datatype (i.e. list, dict, tvm.ir.Map, tvm.ir.Array),
we recurse into it and convert any value of type Relax expression into a te tensor.
Common values of type int, float, and str are preserved.
Parameters
@@ -185,7 +183,9 @@ def _convert_te_arg_helper(arg):
return tuple([_convert_te_arg_helper(x) for x in arg])
elif isinstance(arg, (dict, tvm.ir.Map)):
for key in arg:
-assert isinstance(key, str), "emit_te only supports dict with string as the key currently"
+assert isinstance(
+    key, str
+), "emit_te only supports dict with string as the key currently"
return {k: _convert_te_arg_helper(arg[k]) for k in arg}
elif isinstance(arg, (int, float, str)):
return arg
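# Illustrative behavior (an assumption, not part of this diff): given Relax
# vars x and y, _convert_te_arg([x, {"B": y}, 3]) returns the same structure
# with Expr leaves replaced by te tensors, i.e. [te_x, {"B": te_y}, 3],
# together with the list of tensors created along the way, [te_x, te_y].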
@@ -213,9 +213,9 @@ def _populate_used_vars(expr):
diff = used_vars - bound_vars
return list(diff)

-def function(self,
-    name: str,
-    params: Optional[Union[Var, Tuple, List[Var]]] = None) -> FunctionScope:
+def function(
+    self, name: str, params: Optional[Union[Var, Tuple, List[Var]]] = None
+) -> FunctionScope:
"""Annotate a Relax function.
Parameters
Expand All @@ -239,8 +239,12 @@ def function(self,
elif isinstance(params, (list, tuple)):
for param in params:
if not isinstance(param, rx.Var):
raise TypeError("each element of function parameters must be of type tvm.relax.Var,\
but got: {}".format(type(param)))
raise TypeError(
"each element of function parameters must be of type tvm.relax.Var,\
but got: {}".format(
type(param)
)
)

name = self.get_unique_name(name)
return FunctionScope(self, name, params)
@@ -297,12 +301,12 @@ def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
type_anno = rx.DynTensorType(2, "float32")
x = rx.Var("x", [n, m], type_anno)
y = rx.Var("y", [n, m], type_anno)
def te_func(args, args_dict, msg):
A = args[0]
B = args_dict["B"]
return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
with bb.function("rx_func", [x, y]):
out = bb.emit_te(te_func, [x], {"B": y}, msg="hello")
bb.emit_func_output(out)
@@ -389,8 +393,10 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"]) -> Ten
te_args = te_arg_list + te_kwarg_list

te_out = func(*new_args, **new_kwargs)
-assert (isinstance(te_out, tvm.te.tensor.Tensor) or \
-    (isinstance(te_out, (tuple, list) and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)))), "only support te.tensor or tuple/list of te.tensor as function output"
+assert isinstance(te_out, tvm.te.tensor.Tensor) or (
+    isinstance(te_out, (tuple, list))
+    and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)
+), "only support te.tensor or tuple/list of te.tensor as function output"
outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out)
unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)

@@ -402,15 +408,18 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"]) -> Ten
self._context_mod[gvar] = tir_func

call_args = [x.op.value for x in te_args]
-output_shape = outs[0].shape if isinstance(te_out, tvm.te.tensor.Tensor) else Tuple([ShapeExpr(x.shape) for x in outs])
+output_shape = (
+    outs[0].shape
+    if isinstance(te_out, tvm.te.tensor.Tensor)
+    else Tuple([ShapeExpr(x.shape) for x in outs])
+)
# add arguments for extra parameters from unbound var
-if (len(unbound_tir_vars) > 0):
+if len(unbound_tir_vars) > 0:
call = call_tir(output_shape, gvar, call_args, tir_vars=ShapeExpr(unbound_tir_vars))
else:
call = call_tir(output_shape, gvar, call_args)
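# Illustrative case, mirroring the docstring above: if te_func yields a
# shape such as (n + 1,) where n is not bound by any input shape, n reaches
# the generated PrimFunc through the tir_vars argument of call_tir.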
return _ffi_api.BlockBuilderEmit(self, call)


def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
"""Emit a MatchShape.
@@ -446,16 +455,18 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
output = Tuple(output)
return _ffi_api.BlockBuilderEmitOutput(self, output)

-def emit_func_output(self,
-    output: Union[Expr, Tuple, List[Expr]],
-    params: Optional[Union[Var, Tuple, List[Var]]] = None) -> None:
+def emit_func_output(
+    self,
+    output: Union[Expr, Tuple, List[Expr]],
+    params: Optional[Union[Var, Tuple, List[Var]]] = None,
+) -> None:
"""Emit output for the function.
Parameters
----------
output : Expr | Tuple | List[Expr]
The output of the current block/function.
params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
The parameters of the function to be built.
If params is None, the parameters were already supplied when entering the function scope.
@@ -470,7 +481,9 @@ def emit_func_output(self,
self._is_emit_func_output_called = True

if self._func_params is not None and params is not None:
raise RuntimeError("function parameters have been initialized in the function with scope.")
raise RuntimeError(
"function parameters have been initialized in the function with scope."
)

if self._func_params is None and params is None:
raise RuntimeError("Relax function must have parameter.")
@@ -484,7 +497,7 @@ def emit_func_output(self,
if isinstance(output, (list, tuple)):
output = Tuple(output)
self._func_ret = output

block = self._end_block()
if len(block.bindings) > 0:
self._blocks.append(block)
@@ -521,7 +534,6 @@ def get(self) -> tvm.IRModule:
"""
return self._context_mod


def get_unique_name(self, name_prefix: str) -> str:
"""Generate a unique name with a specified prefix.
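The second changed file, the unit-test update named in the commit title, is not shown above. As a sketch of the xfail => pytest.raises rewrite the title describes (the function and test names here are hypothetical, not from this commit):

.. code-block:: python

    import pytest

    def divide(a, b):
        return a / b

    # Before: an xfail-marked test is reported as "expected to fail"
    # no matter which exception causes the failure.
    @pytest.mark.xfail
    def test_divide_by_zero_xfail():
        divide(1, 0)

    # After: the expected exception is asserted at the exact statement, so
    # an unrelated error can no longer masquerade as the intended failure.
    def test_divide_by_zero_raises():
        with pytest.raises(ZeroDivisionError):
            divide(1, 0)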
