Skip to content

Commit

Permalink
dialects: (stablehlo) add stablehlo.bitcast_convert (#3100)
Browse files Browse the repository at this point in the history
Adds support for `stablehlo.bitcast_convert`

Co-authored-by: Erick Ochoa <erick@ceci-nest-pas.me>
  • Loading branch information
efferifick and Erick Ochoa authored Aug 27, 2024
1 parent d7d30e7 commit 73bc048
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/stablehlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,8 @@
// CHECK: %and = "stablehlo.and"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%and = "stablehlo.and"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>

// %bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor<i32>) -> tensor<2xi16>
%bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor<i32>) -> tensor<2xi16>

// CHECK: "stablehlo.return"(%t0) : (tensor<i32>) -> ()
"stablehlo.return"(%t0) : (tensor<i32>) -> ()
26 changes: 26 additions & 0 deletions xdsl/dialects/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,31 @@ def __init__(
super().__init__(operands=(lhs, rhs), result_types=(result_type,))


@irdl_op_definition
class BitcastConvertOp(IRDLOperation):
"""
Performs a bitcast operation on operand tensor and produces a result tensor
where the bits of the entire operand tensor are reinterpreted using the type of the result tensor.
More formally, given E = element_type(operand), E' = element_type(result), and R = rank(operand):
If num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
If num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
If num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).
bits returns in-memory representation of a given value,
and its behavior is implementation-defined because the exact representation of tensors is implementation-defined,
and the exact representation of element types is implementation-defined as well.
"""

name = "stablehlo.bitcast_convert"
input = operand_def(AnyTensorType)
result = result_def(AnyTensorType)

def __init__(self, input: SSAValue, result: Attribute):
super().__init__(operands=(input,), result_types=(result,))


@irdl_op_definition
class MultiplyOp(ElementwiseBinaryOperation):
"""
Expand Down Expand Up @@ -294,6 +319,7 @@ def verify_(self) -> None:
AbsOp,
AddOp,
AndOp,
BitcastConvertOp,
MultiplyOp,
ReturnOp,
SubtractOp,
Expand Down

0 comments on commit 73bc048

Please sign in to comment.