From 99b7cdb941fa45cf4557a3f6dcce0b3b90d51217 Mon Sep 17 00:00:00 2001 From: Tatiana Shpeisman Date: Mon, 8 Jul 2024 20:20:11 -0700 Subject: [PATCH] [stdlib] Remove MLIR magic from DType.sizeof This changes implementation of `DType.sizeof` to compute the size of the type directly instead of generating `pop.dtype.sizeof`. With this change `DType.sizeof` for index works correctly instead of giving a compile-time error. It also changes implementation of `DType.sizeof.is_half_point` to compute the result directly instead of going via DType.sizeof. MODULAR_ORIG_COMMIT_REV_ID: f9e492c0dc122081101d225862fdca24d2c9e1ab --- stdlib/src/builtin/dtype.mojo | 57 +++++++++++++++++++++++++---- stdlib/test/builtin/test_dtype.mojo | 7 ++++ 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/stdlib/src/builtin/dtype.mojo b/stdlib/src/builtin/dtype.mojo index aa06273720..f84bc0994e 100644 --- a/stdlib/src/builtin/dtype.mojo +++ b/stdlib/src/builtin/dtype.mojo @@ -22,6 +22,7 @@ from utils import unroll alias _mIsSigned = UInt8(1) alias _mIsInteger = UInt8(1 << 7) +alias _mIsNotInteger = UInt8(~(1 << 7)) alias _mIsFloat = UInt8(1 << 6) @@ -274,14 +275,12 @@ struct DType( ) @always_inline("nodebug") - fn is_integral(self) -> Bool: - """Returns True if the type parameter is an integer and False otherwise. + fn _is_non_index_integral(self) -> Bool: + """Returns True if the type parameter is a non-index integer value and False otherwise. Returns: - Returns True if the input type parameter is an integer. + Returns True if the input type parameter is a non-index integer. """ - if self is DType.index: - return True return Bool( __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop`]( __mlir_op.`pop.and`(self._as_i8(), _mIsInteger.value), @@ -289,6 +288,17 @@ struct DType( ) ) + @always_inline("nodebug") + fn is_integral(self) -> Bool: + """Returns True if the type parameter is an integer and False otherwise. + + Returns: + Returns True if the input type parameter is an integer. + """ + if self is DType.index: + return True + return self._is_non_index_integral() + @always_inline("nodebug") fn is_floating_point(self) -> Bool: """Returns True if the type parameter is a floating-point and False @@ -315,7 +325,7 @@ struct DType( True if the type is a half-precision float, false otherwise.. """ - return self.bitwidth() == 16 and self.is_floating_point() + return self in (DType.bfloat16, DType.float16) @always_inline("nodebug") fn is_numeric(self) -> Bool: @@ -335,7 +345,40 @@ struct DType( Returns: Returns the size in bytes of the current DType. """ - return __mlir_op.`pop.dtype.sizeof`(self.value) + + if self._is_non_index_integral(): + return int( + UInt8( + __mlir_op.`pop.shl`( + UInt8(1).value, + __mlir_op.`pop.sub`( + __mlir_op.`pop.shr`( + __mlir_op.`pop.and`( + self._as_i8(), _mIsNotInteger.value + ), + UInt8(1).value, + ), + UInt8(3).value, + ), + ) + ) + ) + + if self == DType.bool: + return sizeof[DType.bool]() + if self == DType.index: + return sizeof[DType.index]() + if self == DType.bfloat16: + return sizeof[DType.bfloat16]() + if self == DType.float16: + return sizeof[DType.float16]() + if self == DType.float32: + return sizeof[DType.float32]() + if self == DType.tensor_float32: + return sizeof[DType.tensor_float32]() + if self == DType.float64: + return sizeof[DType.float64]() + return sizeof[DType.invalid]() @always_inline fn bitwidth(self) -> Int: diff --git a/stdlib/test/builtin/test_dtype.mojo b/stdlib/test/builtin/test_dtype.mojo index e1d8b1646c..6799f1bcfd 100644 --- a/stdlib/test/builtin/test_dtype.mojo +++ b/stdlib/test/builtin/test_dtype.mojo @@ -43,8 +43,15 @@ fn test_key_element() raises: assert_true(DType.int64 in set) +fn test_sizeof() raises: + assert_equal(DType.int16.sizeof(), sizeof[DType.int16]()) + assert_equal(DType.float32.sizeof(), sizeof[DType.float32]()) + assert_equal(DType.index.sizeof(), sizeof[DType.index]()) + + fn main() raises: test_equality() test_stringable() test_representable() test_key_element() + test_sizeof()