Skip to content

Commit

Permalink
[stdlib] Remove MLIR magic from DType.sizeof
Browse files Browse the repository at this point in the history
This changes implementation of `DType.sizeof` to compute the size of the
type directly instead of generating `pop.dtype.sizeof`. With this change
`DType.sizeof` for index works correctly instead of giving a
compile-time error.

It also changes implementation of `DType.sizeof.is_half_point` to
compute the result directly instead of going via DType.sizeof.

MODULAR_ORIG_COMMIT_REV_ID: f9e492c0dc122081101d225862fdca24d2c9e1ab
  • Loading branch information
tatianashp authored and modularbot committed Sep 13, 2024
1 parent 20d5f1c commit 99b7cdb
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 7 deletions.
57 changes: 50 additions & 7 deletions stdlib/src/builtin/dtype.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ from utils import unroll

alias _mIsSigned = UInt8(1)
alias _mIsInteger = UInt8(1 << 7)
alias _mIsNotInteger = UInt8(~(1 << 7))
alias _mIsFloat = UInt8(1 << 6)


Expand Down Expand Up @@ -274,21 +275,30 @@ struct DType(
)

@always_inline("nodebug")
fn is_integral(self) -> Bool:
"""Returns True if the type parameter is an integer and False otherwise.
fn _is_non_index_integral(self) -> Bool:
"""Returns True if the type parameter is a non-index integer value and False otherwise.
Returns:
Returns True if the input type parameter is an integer.
Returns True if the input type parameter is a non-index integer.
"""
if self is DType.index:
return True
return Bool(
__mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred ne>`](
__mlir_op.`pop.and`(self._as_i8(), _mIsInteger.value),
UInt8(0).value,
)
)

@always_inline("nodebug")
fn is_integral(self) -> Bool:
"""Returns True if the type parameter is an integer and False otherwise.
Returns:
Returns True if the input type parameter is an integer.
"""
if self is DType.index:
return True
return self._is_non_index_integral()

@always_inline("nodebug")
fn is_floating_point(self) -> Bool:
"""Returns True if the type parameter is a floating-point and False
Expand All @@ -315,7 +325,7 @@ struct DType(
True if the type is a half-precision float, false otherwise..
"""

return self.bitwidth() == 16 and self.is_floating_point()
return self in (DType.bfloat16, DType.float16)

@always_inline("nodebug")
fn is_numeric(self) -> Bool:
Expand All @@ -335,7 +345,40 @@ struct DType(
Returns:
Returns the size in bytes of the current DType.
"""
return __mlir_op.`pop.dtype.sizeof`(self.value)

if self._is_non_index_integral():
return int(
UInt8(
__mlir_op.`pop.shl`(
UInt8(1).value,
__mlir_op.`pop.sub`(
__mlir_op.`pop.shr`(
__mlir_op.`pop.and`(
self._as_i8(), _mIsNotInteger.value
),
UInt8(1).value,
),
UInt8(3).value,
),
)
)
)

if self == DType.bool:
return sizeof[DType.bool]()
if self == DType.index:
return sizeof[DType.index]()
if self == DType.bfloat16:
return sizeof[DType.bfloat16]()
if self == DType.float16:
return sizeof[DType.float16]()
if self == DType.float32:
return sizeof[DType.float32]()
if self == DType.tensor_float32:
return sizeof[DType.tensor_float32]()
if self == DType.float64:
return sizeof[DType.float64]()
return sizeof[DType.invalid]()

@always_inline
fn bitwidth(self) -> Int:
Expand Down
7 changes: 7 additions & 0 deletions stdlib/test/builtin/test_dtype.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,15 @@ fn test_key_element() raises:
assert_true(DType.int64 in set)


fn test_sizeof() raises:
assert_equal(DType.int16.sizeof(), sizeof[DType.int16]())
assert_equal(DType.float32.sizeof(), sizeof[DType.float32]())
assert_equal(DType.index.sizeof(), sizeof[DType.index]())


fn main() raises:
test_equality()
test_stringable()
test_representable()
test_key_element()
test_sizeof()

0 comments on commit 99b7cdb

Please sign in to comment.