diff --git a/tests/filecheck/dialects/stablehlo/ops.mlir b/tests/filecheck/dialects/stablehlo/ops.mlir index eddd4c0be4..08d568ca50 100644 --- a/tests/filecheck/dialects/stablehlo/ops.mlir +++ b/tests/filecheck/dialects/stablehlo/ops.mlir @@ -31,5 +31,8 @@ // CHECK: %and = "stablehlo.and"(%t0, %t0) : (tensor, tensor) -> tensor %and = "stablehlo.and"(%t0, %t0) : (tensor, tensor) -> tensor +// %bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor) -> tensor<2xi16> +%bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor) -> tensor<2xi16> + // CHECK: "stablehlo.return"(%t0) : (tensor) -> () "stablehlo.return"(%t0) : (tensor) -> () diff --git a/xdsl/dialects/stablehlo.py b/xdsl/dialects/stablehlo.py index a88926f14d..e39ba4872f 100644 --- a/xdsl/dialects/stablehlo.py +++ b/xdsl/dialects/stablehlo.py @@ -176,6 +176,31 @@ def __init__( super().__init__(operands=(lhs, rhs), result_types=(result_type,)) +@irdl_op_definition +class BitcastConvertOp(IRDLOperation): + """ + Performs a bitcast operation on operand tensor and produces a result tensor + where the bits of the entire operand tensor are reinterpreted using the type of the result tensor. + + More formally, given E = element_type(operand), E' = element_type(result), and R = rank(operand): + + If num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]). + If num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]). + If num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]). + + bits returns in-memory representation of a given value, + and its behavior is implementation-defined because the exact representation of tensors is implementation-defined, + and the exact representation of element types is implementation-defined as well. + """ + + name = "stablehlo.bitcast_convert" + input = operand_def(AnyTensorType) + result = result_def(AnyTensorType) + + def __init__(self, input: SSAValue, result: Attribute): + super().__init__(operands=(input,), result_types=(result,)) + + @irdl_op_definition class MultiplyOp(ElementwiseBinaryOperation): """ @@ -294,6 +319,7 @@ def verify_(self) -> None: AbsOp, AddOp, AndOp, + BitcastConvertOp, MultiplyOp, ReturnOp, SubtractOp,