diff --git a/stdlib/src/builtin/dtype.mojo b/stdlib/src/builtin/dtype.mojo index aa06273720..f84bc0994e 100644 --- a/stdlib/src/builtin/dtype.mojo +++ b/stdlib/src/builtin/dtype.mojo @@ -22,6 +22,7 @@ from utils import unroll alias _mIsSigned = UInt8(1) alias _mIsInteger = UInt8(1 << 7) +alias _mIsNotInteger = UInt8(~(1 << 7)) alias _mIsFloat = UInt8(1 << 6) @@ -274,14 +275,12 @@ struct DType( ) @always_inline("nodebug") - fn is_integral(self) -> Bool: - """Returns True if the type parameter is an integer and False otherwise. + fn _is_non_index_integral(self) -> Bool: + """Returns True if the type parameter is a non-index integer value and False otherwise. Returns: - Returns True if the input type parameter is an integer. + Returns True if the input type parameter is a non-index integer. """ - if self is DType.index: - return True return Bool( __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop`]( __mlir_op.`pop.and`(self._as_i8(), _mIsInteger.value), @@ -289,6 +288,17 @@ struct DType( ) ) + @always_inline("nodebug") + fn is_integral(self) -> Bool: + """Returns True if the type parameter is an integer and False otherwise. + + Returns: + Returns True if the input type parameter is an integer. + """ + if self is DType.index: + return True + return self._is_non_index_integral() + @always_inline("nodebug") fn is_floating_point(self) -> Bool: """Returns True if the type parameter is a floating-point and False @@ -315,7 +325,7 @@ struct DType( True if the type is a half-precision float, false otherwise.. """ - return self.bitwidth() == 16 and self.is_floating_point() + return self in (DType.bfloat16, DType.float16) @always_inline("nodebug") fn is_numeric(self) -> Bool: @@ -335,7 +345,40 @@ struct DType( Returns: Returns the size in bytes of the current DType. """ - return __mlir_op.`pop.dtype.sizeof`(self.value) + + if self._is_non_index_integral(): + return int( + UInt8( + __mlir_op.`pop.shl`( + UInt8(1).value, + __mlir_op.`pop.sub`( + __mlir_op.`pop.shr`( + __mlir_op.`pop.and`( + self._as_i8(), _mIsNotInteger.value + ), + UInt8(1).value, + ), + UInt8(3).value, + ), + ) + ) + ) + + if self == DType.bool: + return sizeof[DType.bool]() + if self == DType.index: + return sizeof[DType.index]() + if self == DType.bfloat16: + return sizeof[DType.bfloat16]() + if self == DType.float16: + return sizeof[DType.float16]() + if self == DType.float32: + return sizeof[DType.float32]() + if self == DType.tensor_float32: + return sizeof[DType.tensor_float32]() + if self == DType.float64: + return sizeof[DType.float64]() + return sizeof[DType.invalid]() @always_inline fn bitwidth(self) -> Int: diff --git a/stdlib/test/builtin/test_dtype.mojo b/stdlib/test/builtin/test_dtype.mojo index e1d8b1646c..6799f1bcfd 100644 --- a/stdlib/test/builtin/test_dtype.mojo +++ b/stdlib/test/builtin/test_dtype.mojo @@ -43,8 +43,15 @@ fn test_key_element() raises: assert_true(DType.int64 in set) +fn test_sizeof() raises: + assert_equal(DType.int16.sizeof(), sizeof[DType.int16]()) + assert_equal(DType.float32.sizeof(), sizeof[DType.float32]()) + assert_equal(DType.index.sizeof(), sizeof[DType.index]()) + + fn main() raises: test_equality() test_stringable() test_representable() test_key_element() + test_sizeof()