diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c8531c88465a..b8857e598dc4 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -711,7 +711,7 @@ class RangeNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); }; -/*! \brief Range constainer */ +/*! \brief Range container */ class Range : public ObjectRef { public: /*! @@ -736,7 +736,7 @@ class Range : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); }; -// implementataions +// implementations inline const Type& RelayExprNode::checked_type() const { ICHECK(checked_type_.defined()) << "internal error: the type checker has " << "not populated the checked_type " diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 9cd2bed65739..6c2c6dd5fc86 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -270,7 +270,7 @@ class IterVarNode : public Object { IterVarType iter_type; /*! * \brief additional tag on the iteration variable, - * set this if this is binded already to a known thread tag. + * set this if this is bound already to a known thread tag. */ String thread_tag; /*! diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index f0f2245e7f62..36c425cb85e2 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -16,17 +16,21 @@ # under the License. """Common expressions data structures in the IR.""" from numbers import Number +from typing import Callable, Optional import tvm._ffi -from ..runtime import Scriptable, const, convert +from ..runtime import Object, Scriptable, const, convert from . import _ffi_api -from .base import Node +from .base import Node, Span +from .type import Type class BaseExpr(Node): """Base class of all the expressions.""" + span: Optional[Span] + class PrimExpr(BaseExpr): """Base class of all primitive expressions. @@ -35,6 +39,8 @@ class PrimExpr(BaseExpr): optimizations and integer analysis. """ + dtype: str + class RelayExpr(BaseExpr): """Base class of all non-primitive expressions.""" @@ -67,10 +73,12 @@ class GlobalVar(RelayExpr): The name of the variable. """ - def __init__(self, name_hint, type_annot=None): + name_hint: str + + def __init__(self, name_hint: str, type_annot: Optional[Type] = None): self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot) - def __call__(self, *args): + def __call__(self, *args: RelayExpr) -> BaseExpr: """Call the global variable. Parameters @@ -94,7 +102,9 @@ def __call__(self, *args): arg_types = [type(x) for x in args] raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for types {arg_types}") - def astext(self, show_meta_data=True, annotate=None): + def astext( + self, show_meta_data: bool = True, annotate: Optional[Callable[[Object], str]] = None + ) -> str: """Get the text format of the expression. Parameters @@ -140,7 +150,7 @@ class Range(Node, Scriptable): The end value of the range. span : Optional[Span] - The location of this itervar in the source code. + The location of this node in the source code. Note ---- @@ -148,14 +158,22 @@ class Range(Node, Scriptable): if the end argument is not None. Otherwise, it creates `[0, begin)`. """ - def __init__(self, begin, end=None, span=None): + min: PrimExpr + extent: PrimExpr + span: Optional[Span] + + def __init__( + self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: Optional[Span] = None + ) -> None: if end is None: end = convert(begin) begin = const(0, dtype=end.dtype, span=span) self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span) @staticmethod - def from_min_extent(min_value, extent, span=None): + def from_min_extent( + min_value: PrimExpr, extent: PrimExpr, span: Optional[Span] = None + ) -> "Range": """Construct a Range by min and extent. This constructs a range in [min_value, min_value + extent) @@ -169,7 +187,7 @@ def from_min_extent(min_value, extent, span=None): The extent of the range. span : Optional[Span] - The location of this itervar in the source code. + The location of this node in the source code. Returns ------- @@ -178,8 +196,8 @@ def from_min_extent(min_value, extent, span=None): """ return _ffi_api.Range_from_min_extent(min_value, extent, span) - def __eq__(self, other): + def __eq__(self, other: Object) -> bool: return tvm.ir.structural_equal(self, other) - def __ne__(self, other): + def __ne__(self, other: Object) -> bool: return not self.__eq__(other) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index f93e39ee0fbd..fad9fca083a1 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -27,7 +27,7 @@ assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ -from typing import Optional, Union +from typing import List, Optional, Union import tvm._ffi import tvm.ir._ffi_api @@ -38,9 +38,10 @@ from . import _ffi_api from . import generic as _generic +from .buffer import Buffer, DataProducer -def div_ambiguity_error(): +def div_ambiguity_error() -> RuntimeError: return RuntimeError( "TVM supports multiple types of integer divisions, " + "please call div, indexdiv/indexmod, floordiv/floormod " @@ -69,111 +70,111 @@ class ExprOp(object): # TODO(tkonolige): use inspect to add source information to these objects - def __add__(self, other): + def __add__(self, other: PrimExpr) -> PrimExpr: return _generic.add(self, other) - def __radd__(self, other): + def __radd__(self, other: PrimExpr) -> PrimExpr: return _generic.add(other, self) - def __sub__(self, other): + def __sub__(self, other: PrimExpr) -> PrimExpr: return _generic.subtract(self, other) - def __rsub__(self, other): + def __rsub__(self, other: PrimExpr) -> PrimExpr: return _generic.subtract(other, self) - def __mul__(self, other): + def __mul__(self, other: PrimExpr) -> PrimExpr: return _generic.multiply(self, other) - def __rmul__(self, other): + def __rmul__(self, other: PrimExpr) -> PrimExpr: return _generic.multiply(other, self) - def __div__(self, other): + def __div__(self, other: PrimExpr) -> PrimExpr: if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rdiv__(self, other): + def __rdiv__(self, other: PrimExpr) -> PrimExpr: if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __truediv__(self, other): + def __truediv__(self, other: PrimExpr) -> PrimExpr: if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rtruediv__(self, other): + def __rtruediv__(self, other: PrimExpr) -> PrimExpr: if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __floordiv__(self, other): + def __floordiv__(self, other: PrimExpr) -> PrimExpr: return _generic.floordiv(self, other) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: PrimExpr) -> PrimExpr: return _generic.floordiv(other, self, None) - def __mod__(self, other): + def __mod__(self, other: PrimExpr) -> PrimExpr: return _ffi_api._OpFloorMod(self, other, None) # type: ignore - def __rmod__(self, other): + def __rmod__(self, other: PrimExpr) -> PrimExpr: return _ffi_api._OpFloorMod(other, self, None) # type: ignore - def __neg__(self): + def __neg__(self) -> PrimExpr: neg_one = const(-1, self.dtype) # type: ignore return self.__mul__(neg_one) - def __lshift__(self, other): + def __lshift__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.left_shift(self, other, None) # type: ignore - def __rlshift__(self, other): + def __rlshift__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.left_shift(other, self, None) # type: ignore - def __rshift__(self, other): + def __rshift__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.right_shift(self, other, None) # type: ignore - def __rrshift__(self, other): + def __rrshift__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.right_shift(other, self, None) # type: ignore - def __and__(self, other): + def __and__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.bitwise_and(self, other, None) # type: ignore - def __rand__(self, other): + def __rand__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.bitwise_and(other, self, None) # type: ignore - def __or__(self, other): + def __or__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.bitwise_or(self, other, None) # type: ignore - def __ror__(self, other): + def __ror__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.bitwise_or(other, self, None) # type: ignore - def __xor__(self, other): + def __xor__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.bitwise_xor(self, other, None) # type: ignore - def __rxor__(self, other): + def __rxor__(self, other: PrimExpr) -> PrimExpr: return _ffi_api.bitwise_xor(other, self, None) # type: ignore - def __invert__(self): + def __invert__(self) -> PrimExpr: if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") return _ffi_api.bitwise_not(self, None) # type: ignore - def __lt__(self, other): + def __lt__(self, other: PrimExpr) -> PrimExpr: return _ffi_api._OpLT(self, other, None) # type: ignore - def __le__(self, other): + def __le__(self, other: PrimExpr) -> PrimExpr: return _ffi_api._OpLE(self, other, None) # type: ignore - def __eq__(self, other): + def __eq__(self, other: PrimExpr) -> PrimExpr: return EqualOp(self, other) - def __ne__(self, other): + def __ne__(self, other: PrimExpr) -> PrimExpr: return NotEqualOp(self, other) - def __gt__(self, other): + def __gt__(self, other: PrimExpr) -> PrimExpr: return _ffi_api._OpGT(self, other, None) # type: ignore - def __ge__(self, other): + def __ge__(self, other: PrimExpr) -> PrimExpr: return _ffi_api._OpGE(self, other, None) # type: ignore def __nonzero__(self): @@ -182,10 +183,10 @@ def __nonzero__(self): + "use tvm.tir.all / tvm.tir.any instead" ) - def __bool__(self): + def __bool__(self) -> bool: return self.__nonzero__() - def equal(self, other, span=None): + def equal(self, other: PrimExpr, span: Optional[Span] = None) -> bool: """Build an equal check expression with other expr. Parameters @@ -203,7 +204,7 @@ def equal(self, other, span=None): """ return _ffi_api._OpEQ(self, other, span) # type: ignore - def astype(self, dtype: str, span: Optional[Span] = None): + def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: """Cast the expression to other type. Parameters @@ -243,18 +244,18 @@ class EqualOp(ObjectGeneric, ExprOp): # This class is not manipulated by C++. So use python's identity check function is sufficient same_as = object.__eq__ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None): self.a = a self.b = b self.span = span - def __nonzero__(self): + def __nonzero__(self) -> bool: return self.a.same_as(self.b) - def __bool__(self): + def __bool__(self) -> bool: return self.__nonzero__() - def asobject(self): + def asobject(self) -> PrimExpr: """Convert object.""" return _ffi_api._OpEQ(self.a, self.b, self.span) # type: ignore @@ -280,18 +281,18 @@ class NotEqualOp(ObjectGeneric, ExprOp): # This class is not manipulated by C++. So use python's identity check function is sufficient same_as = object.__eq__ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.a = a self.b = b self.span = span - def __nonzero__(self): + def __nonzero__(self) -> bool: return not self.a.same_as(self.b) - def __bool__(self): + def __bool__(self) -> bool: return self.__nonzero__() - def asobject(self): + def asobject(self) -> PrimExpr: """Convert object.""" return _ffi_api._OpNE(self.a, self.b, self.span) # type: ignore @@ -309,11 +310,11 @@ class IntImmEnum(ObjectGeneric): The location of the cast in the source. """ - def __init__(self, value, span=None): + def __init__(self, value: int, span: Optional[Span] = None) -> None: self.value = value self.span = span - def asobject(self): + def asobject(self) -> "IntImm": """Convert object.""" return IntImm("int32", self.value, self.span) # type: ignore @@ -331,11 +332,13 @@ class ConstExpr(PrimExprWithOp): class BinaryOpExpr(PrimExprWithOp): - pass + a: PrimExpr + b: PrimExpr class CmpExpr(PrimExprWithOp): - pass + a: PrimExpr + b: PrimExpr class LogicalExpr(PrimExprWithOp): @@ -351,14 +354,17 @@ class Var(PrimExprWithOp): name : str The name - dtype : Union[str, tvm.irType] + dtype : Union[str, ir.Type] The data type span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None): + name_hint: str + type_annotation: ir.Type + + def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) # type: ignore @@ -372,15 +378,15 @@ class SizeVar(Var): name : str The name - dtype : int + dtype : Union[str, ir.Type] The data type span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ # pylint: disable=super-init-not-called - def __init__(self, name, dtype, span=None): + def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) # type: ignore @@ -405,7 +411,7 @@ class IterVar(Object, ExprOp, Scriptable): The thread type tag. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. See Also -------- @@ -423,7 +429,19 @@ class IterVar(Object, ExprOp, Scriptable): Parallelized = 7 Tensorized = 8 - def __init__(self, dom, var, iter_type, thread_tag="", span=None): + dom: ir.Range + var: Var + iter_type: int + thread_tag: str + + def __init__( + self, + dom: ir.Range, + var: Union[Var, str], + iter_type: int, + thread_tag: str = "", + span: Optional[Span] = None, + ) -> None: if dom is not None: if isinstance(dom, (list, tuple)): if len(dom) != 2: @@ -464,10 +482,22 @@ class CommReducer(Object, Scriptable): The identity elements. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, lhs, rhs, result, identity_element, span=None): + lhs: List[Var] + rhs: List[Var] + result: List[PrimExpr] + identity_element: List[PrimExpr] + + def __init__( + self, + lhs: List[Var], + rhs: List[Var], + result: List[PrimExpr], + identity_element: List[PrimExpr], + span: Optional[Span] = None, + ) -> None: self.__init_handle_by_constructor__( _ffi_api.CommReducer, lhs, rhs, result, identity_element, span # type: ignore ) @@ -498,10 +528,27 @@ class Reduce(PrimExprWithOp): The initial value for output. This can be an int, float or ProducerLoad span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, combiner, src, rdom, condition, value_index, init=None, span=None): + combiner: CommReducer + source: List[PrimExpr] + init: List[PrimExpr] + axis: List[IterVar] + condition: PrimExpr + value_index: int + + def __init__( + self, + combiner: CommReducer, + src: List[PrimExpr], + rdom: List[IterVar], + condition: PrimExpr, + value_index: int, + init: Optional[List[PrimExpr]] = None, + span: Optional[Span] = None, + ) -> None: + init = [] if init is None else init self.__init_handle_by_constructor__( _ffi_api.Reduce, combiner, src, rdom, condition, value_index, init, span # type: ignore ) @@ -520,15 +567,17 @@ class FloatImm(ConstExpr): The constant value. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, dtype, value, span=None): + value: float + + def __init__(self, dtype: str, value: float, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__( tvm.ir._ffi_api.FloatImm, dtype, value, span # type: ignore ) - def __float__(self): + def __float__(self) -> float: return self.value @@ -545,30 +594,32 @@ class IntImm(ConstExpr): The constant value. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, dtype, value, span=None): + value: int + + def __init__(self, dtype: str, value: int, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__( tvm.ir._ffi_api.IntImm, dtype, value, span # type: ignore ) - def __hash__(self): + def __hash__(self) -> int: return self.value - def __int__(self): + def __int__(self) -> int: return self.value - def __nonzero__(self): + def __nonzero__(self) -> bool: return self.value != 0 - def __eq__(self, other): + def __eq__(self, other: PrimExpr) -> PrimExpr: return _ffi_api._OpEQ(self, other, None) # type: ignore - def __ne__(self, other): + def __ne__(self, other: PrimExpr) -> PrimExpr: return _ffi_api._OpNE(self, other, None) # type: ignore - def __bool__(self): + def __bool__(self) -> bool: return self.__nonzero__() @@ -582,23 +633,25 @@ class StringImm(ConstExpr): The value of the function. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, value, span=None): + value: str + + def __init__(self, value: str, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) # type: ignore - def __eq__(self, other): + def __eq__(self, other: PrimExpr) -> bool: if isinstance(other, ConstExpr): return self.value == other.value return self.value == other - def __ne__(self, other): + def __ne__(self, other: PrimExpr) -> bool: if isinstance(other, ConstExpr): return self.value != other.value return self.value != other - def __hash__(self): + def __hash__(self) -> int: return PrimExpr.__hash__(self) @@ -615,10 +668,12 @@ class Cast(PrimExprWithOp): The value of the function. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, dtype, value, span=None): + value: PrimExpr + + def __init__(self, dtype, value, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) # type: ignore @@ -635,10 +690,10 @@ class Add(BinaryOpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) # type: ignore @@ -655,10 +710,10 @@ class Sub(BinaryOpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) # type: ignore @@ -675,10 +730,10 @@ class Mul(BinaryOpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) # type: ignore @@ -695,10 +750,10 @@ class Div(BinaryOpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) # type: ignore @@ -715,10 +770,10 @@ class Mod(BinaryOpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) # type: ignore @@ -735,10 +790,10 @@ class FloorDiv(BinaryOpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore @@ -755,10 +810,10 @@ class FloorMod(BinaryOpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) # type: ignore @@ -775,10 +830,10 @@ class Min(BinaryOpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) # type: ignore @@ -795,10 +850,10 @@ class Max(BinaryOpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) # type: ignore @@ -815,10 +870,10 @@ class EQ(CmpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) # type: ignore @@ -835,10 +890,10 @@ class NE(CmpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) # type: ignore @@ -855,10 +910,10 @@ class LT(CmpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) # type: ignore @@ -875,10 +930,10 @@ class LE(CmpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore @@ -895,10 +950,10 @@ class GT(CmpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) # type: ignore @@ -915,10 +970,10 @@ class GE(CmpExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) # type: ignore @@ -935,10 +990,10 @@ class And(LogicalExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) # type: ignore @@ -955,10 +1010,13 @@ class Or(LogicalExpr): The right hand operand. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, b, span=None): + a: PrimExpr + b: PrimExpr + + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) # type: ignore @@ -972,10 +1030,12 @@ class Not(LogicalExpr): The input value span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, a, span=None): + a: PrimExpr + + def __init__(self, a: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Not, a, span) # type: ignore @@ -1002,10 +1062,20 @@ class Select(PrimExprWithOp): The value to take when condition is false. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, condition, true_value, false_value, span=None): + condition: PrimExpr + true_value: PrimExpr + false_value: PrimExpr + + def __init__( + self, + condition: PrimExpr, + true_value: PrimExpr, + false_value: PrimExpr, + span: Optional[Span] = None, + ) -> None: if isinstance(condition, bool): condition = IntImm("bool", condition) self.__init_handle_by_constructor__( @@ -1026,10 +1096,15 @@ class BufferLoad(PrimExprWithOp): The buffer indices. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, buffer, indices, span=None): + buffer: Buffer + indices: List[PrimExpr] + + def __init__( + self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None + ) -> None: self.__init_handle_by_constructor__( _ffi_api.BufferLoad, buffer, indices, span # type: ignore ) @@ -1048,10 +1123,15 @@ class ProducerLoad(PrimExprWithOp): The buffer indices. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, producer, indices, span=None): + producer: DataProducer + indices: List[PrimExpr] + + def __init__( + self, producer: DataProducer, indices: List[PrimExpr], span: Optional[Span] = None + ) -> None: self.__init_handle_by_constructor__( _ffi_api.ProducerLoad, producer, indices, span # type: ignore ) @@ -1073,10 +1153,16 @@ class Ramp(PrimExprWithOp): The lanes of the expression. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, base, stride, lanes, span=None): + base: PrimExpr + stride: PrimExpr + lanes: int + + def __init__( + self, base: PrimExpr, stride: PrimExpr, lanes: int, span: Optional[Span] = None + ) -> None: self.__init_handle_by_constructor__( _ffi_api.Ramp, base, stride, lanes, span # type: ignore ) @@ -1095,10 +1181,13 @@ class Broadcast(PrimExprWithOp): The lanes of the expression. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, value, lanes, span=None): + value: PrimExpr + lanes: int + + def __init__(self, value: PrimExpr, lanes: int, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) # type: ignore @@ -1108,17 +1197,22 @@ class Shuffle(PrimExprWithOp): Parameters ---------- - vectors : Array of Expr + vectors : List[PrimExpr] The vectors - indices : Array of indices + indices : List[PrimExpr] The indices span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, vectors, indices, span=None): + vectors: List[PrimExpr] + indices: List[PrimExpr] + + def __init__( + self, vectors: List[PrimExpr], indices: List[PrimExpr], span: Optional[Span] = None + ) -> None: self.__init_handle_by_constructor__( _ffi_api.Shuffle, vectors, indices, span # type: ignore ) @@ -1144,7 +1238,7 @@ class Call(PrimExprWithOp): dtype : str The return data type - op : Union[RelayExpr, str] + op : Union[Op, str] The function to be called, or the name to the global tvm.Op @@ -1152,10 +1246,15 @@ class Call(PrimExprWithOp): The input arguments to the call span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, dtype, op, args, span=None): + op: Op + args: List[PrimExpr] + + def __init__( + self, dtype: str, op: Union[Op, str], args: List[PrimExpr], span: Optional[Span] = None + ) -> None: if isinstance(op, str): if not op.startswith("tir."): raise ValueError( @@ -1180,16 +1279,22 @@ class Let(PrimExprWithOp): The variable in the binding. value : PrimExpr - The value in to be binded. + The value in to be bound. body : PrimExpr The body expression. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, var, value, body, span=None): + var: Var + value: PrimExpr + body: PrimExpr + + def __init__( + self, var: Var, value: PrimExpr, body: PrimExpr, span: Optional[Span] = None + ) -> None: self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) # type: ignore @@ -1198,8 +1303,8 @@ class Any(PrimExprWithOp): """Any node. span : Optional[Span] - The location of this itervar in the source code. + The location of this expression in the source code. """ - def __init__(self, span=None): + def __init__(self, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Any, span) # type: ignore