diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index ac1e990a96e2..48b283447969 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1229,36 +1229,107 @@ def evaluate(value: PrimExpr) -> None: return _ffi_api.Evaluate(value) # type: ignore[attr-defined] # pylint: disable=no-member -__all__ = [] -for _dtype in ["Float", "UInt", "Int"]: - for _size in ["8", "16", "32", "64"]: - for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]: - _name = _dtype + _size + _lanes # pylint: disable=invalid-name - - def func_gen(name: str): - """Generate a function for each PrimExpr dtype. - - Parameters - ---------- - name: str - The ffi function name to call. - """ - - def func( - expr: Union[ - None, - PrimExpr, - Literal["inf", "-inf", "nan"], - ] = None - ) -> PrimExpr: - if isinstance(expr, str): - expr = float(expr) - return getattr(_ffi_api, name)(expr) - - return func - - globals()[_name.lower()] = func_gen(_name) - __all__.append(_name.lower()) +def func_gen(name: str): + """Generate a function for each PrimExpr dtype. + + Parameters + ---------- + name: str + The ffi function name to call. + """ + + def func( + expr: Union[ + None, + PrimExpr, + Literal["inf", "-inf", "nan"], + int, + float, + ] = None + ) -> PrimExpr: + if isinstance(expr, str): + expr = float(expr) + return getattr(_ffi_api, name)(expr) + + return func + + +# pylint: disable=invalid-name +int8 = func_gen(("Int8")) +int16 = func_gen(("Int16")) +int32 = func_gen(("Int32")) +int64 = func_gen(("Int64")) +int8x4 = func_gen(("Int8x4")) +int16x4 = func_gen(("Int16x4")) +int32x4 = func_gen(("Int32x4")) +int64x4 = func_gen(("Int64x4")) +int8x8 = func_gen(("Int8x8")) +int16x8 = func_gen(("Int16x8")) +int32x8 = func_gen(("Int32x8")) +int64x8 = func_gen(("Int64x8")) +int8x16 = func_gen(("Int8x16")) +int16x16 = func_gen(("Int16x16")) +int32x16 = func_gen(("Int32x16")) +int64x16 = func_gen(("Int64x16")) +int8x32 = func_gen(("Int8x32")) +int16x32 = func_gen(("Int16x32")) +int32x32 = func_gen(("Int32x32")) +int64x32 = func_gen(("Int64x32")) +int8x64 = func_gen(("Int8x64")) +int16x64 = func_gen(("Int16x64")) +int32x64 = func_gen(("Int32x64")) +int64x64 = func_gen(("Int64x64")) + +uint8 = func_gen(("UInt8")) +uint16 = func_gen(("UInt16")) +uint32 = func_gen(("UInt32")) +uint64 = func_gen(("UInt64")) +uint8x4 = func_gen(("UInt8x4")) +uint16x4 = func_gen(("UInt16x4")) +uint32x4 = func_gen(("UInt32x4")) +uint64x4 = func_gen(("UInt64x4")) +uint8x8 = func_gen(("UInt8x8")) +uint16x8 = func_gen(("UInt16x8")) +uint32x8 = func_gen(("UInt32x8")) +uint64x8 = func_gen(("UInt64x8")) +uint8x16 = func_gen(("UInt8x16")) +uint16x16 = func_gen(("UInt16x16")) +uint32x16 = func_gen(("UInt32x16")) +uint64x16 = func_gen(("UInt64x16")) +uint8x32 = func_gen(("UInt8x32")) +uint16x32 = func_gen(("UInt16x32")) +uint32x32 = func_gen(("UInt32x32")) +uint64x32 = func_gen(("UInt64x32")) +uint8x64 = func_gen(("UInt8x64")) +uint16x64 = func_gen(("UInt16x64")) +uint32x64 = func_gen(("UInt32x64")) +uint64x64 = func_gen(("UInt64x64")) + +float8 = func_gen(("Float8")) +float16 = func_gen(("Float16")) +float32 = func_gen(("Float32")) +float64 = func_gen(("Float64")) +float8x4 = func_gen(("Float8x4")) +float16x4 = func_gen(("Float16x4")) +float32x4 = func_gen(("Float32x4")) +float64x4 = func_gen(("Float64x4")) +float8x8 = func_gen(("Float8x8")) +float16x8 = func_gen(("Float16x8")) +float32x8 = func_gen(("Float32x8")) +float64x8 = func_gen(("Float64x8")) +float8x16 = func_gen(("Float8x16")) +float16x16 = func_gen(("Float16x16")) +float32x16 = func_gen(("Float32x16")) +float64x16 = func_gen(("Float64x16")) +float8x32 = func_gen(("Float8x32")) +float16x32 = func_gen(("Float16x32")) +float32x32 = func_gen(("Float32x32")) +float64x32 = func_gen(("Float64x32")) +float8x64 = func_gen(("Float8x64")) +float16x64 = func_gen(("Float16x64")) +float32x64 = func_gen(("Float32x64")) +float64x64 = func_gen(("Float64x64")) +# pylint: enable=invalid-name def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr: @@ -1621,7 +1692,79 @@ def f(): # pylint: enable=invalid-name -__all__ += [ +__all__ = [ + "int8", + "int16", + "int32", + "int64", + "int8x4", + "int16x4", + "int32x4", + "int64x4", + "int8x8", + "int16x8", + "int32x8", + "int64x8", + "int8x16", + "int16x16", + "int32x16", + "int64x16", + "int8x32", + "int16x32", + "int32x32", + "int64x32", + "int8x64", + "int16x64", + "int32x64", + "int64x64", + "uint8", + "uint16", + "uint32", + "uint64", + "uint8x4", + "uint16x4", + "uint32x4", + "uint64x4", + "uint8x8", + "uint16x8", + "uint32x8", + "uint64x8", + "uint8x16", + "uint16x16", + "uint32x16", + "uint64x16", + "uint8x32", + "uint16x32", + "uint32x32", + "uint64x32", + "uint8x64", + "uint16x64", + "uint32x64", + "uint64x64", + "float8", + "float16", + "float32", + "float64", + "float8x4", + "float16x4", + "float32x4", + "float64x4", + "float8x8", + "float16x8", + "float32x8", + "float64x8", + "float8x16", + "float16x16", + "float32x16", + "float64x16", + "float8x32", + "float16x32", + "float32x32", + "float64x32", + "float8x64", + "float16x64", + "float32x64", + "float64x64", "buffer_decl", "prim_func", "arg",