Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Introduce a SelectLikeOpInterface #104751

Merged
merged 5 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Analysis/SliceWalk.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ WalkContinuation walkSlice(mlir::ValueRange rootValues,
WalkCallback walkCallback);

/// Computes a vector of all control predecessors of `value`. Relies on
/// RegionBranchOpInterface and BranchOpInterface to determine predecessors.
/// Returns nullopt if `value` has no predecessors or when the relevant
/// operations are missing the interface implementations.
/// RegionBranchOpInterface, BranchOpInterface, and SelectLikeOpInterface to
/// determine predecessors. Returns nullopt if `value` has no predecessors or
/// when the relevant operations are missing the interface implementations.
std::optional<SmallVector<Value>> getControlFlowPredecessors(Value value);

} // namespace mlir
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Arith/IR/Arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -1578,6 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
BooleanConditionOrMatchingShape<"condition", "result">,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
DeclareOpInterfaceMethods<SelectLikeOpInterface>,
] # ElementwiseMappable.traits> {
let summary = "select operation";
let description = [{
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,8 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
def LLVM_SelectOp
: LLVM_Op<"select",
[Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>,
DeclareOpInterfaceMethods<FastmathFlagsInterface>]>,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<SelectLikeOpInterface>]>,
LLVM_Builder<
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,8 @@ def SPIRV_SLessThanEqualOp : SPIRV_LogicalBinaryOp<"SLessThanEqual",
def SPIRV_SelectOp : SPIRV_Op<"Select",
[Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
UsableInSpecConstantOp]> {
UsableInSpecConstantOp,
DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
let summary = [{
Select between two objects. Before version 1.4, results are only
computed per component.
Expand Down
32 changes: 32 additions & 0 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,38 @@ def RegionBranchTerminatorOpInterface :
}];
}

def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
let description = [{
This interface provides information for select-like operations, i.e.,
operations that forward specific operands to the output, depending on a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a bit more documentation in terms of the invariant and promises associated with the interface I believe, in particular in terms of transforms: can we always fold or propagate based on the only behavior of a regular select?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, I'm curious about poison propagation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attempted to address this by restricting the semantics to fit LLVM's and arith dialect's select instruction/operation.

binary condition.

If the value of the condition is 1, then the `true` operand is returned,
and the third operand is ignored, even if it was poison.

If the value of the condition is 0, then the `false` operand is returned,
and the second operand is ignored, even if it was poison.

If the condition is poison, then poison is returned.

Implementing operations can also accept shaped conditions, in which case
the operation works element-wise.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Returns the operand that would be chosen for a false condition.
}], "::mlir::Value", "getFalseValue", (ins)>,
InterfaceMethod<[{
Returns the operand that would be chosen for a true condition.
}], "::mlir::Value", "getTrueValue", (ins)>,
InterfaceMethod<[{
Returns the condition operand.
}], "::mlir::Value", "getCondition", (ins)>
];
}

//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Analysis/SliceWalk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ getBlockPredecessorOperands(BlockArgument blockArg) {

std::optional<SmallVector<Value>>
mlir::getControlFlowPredecessors(Value value) {
SmallVector<Value> result;
if (OpResult opResult = dyn_cast<OpResult>(value)) {
auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
if (auto selectOp = opResult.getDefiningOp<SelectLikeOpInterface>())
return SmallVector<Value>(
{selectOp.getTrueValue(), selectOp.getFalseValue()});
auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
// If the interface is not implemented, there are no control flow
// predecessors to work with.
if (!regionOp)
Expand Down
5 changes: 0 additions & 5 deletions mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,6 @@ getUnderlyingObjectSet(Value pointerValue) {
if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
return WalkContinuation::advanceTo(addrCast.getOperand());

// TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
return WalkContinuation::advanceTo(
{selectOp.getTrueValue(), selectOp.getFalseValue()});

// Attempt to advance to control flow predecessors.
std::optional<SmallVector<Value>> controlFlowPredecessors =
getControlFlowPredecessors(val);
Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,51 @@ llvm.func @noalias_with_region(%arg0: !llvm.ptr) {
llvm.call @region(%arg0) : (!llvm.ptr) -> ()
llvm.return
}

// -----

// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}>
// CHECK-DAG: #[[$ARG_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>

llvm.func @foo(%arg: i32)

llvm.func @func(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
%cond = llvm.load %arg1 : !llvm.ptr -> i1
%1 = llvm.getelementptr inbounds %arg0[1] : (!llvm.ptr) -> !llvm.ptr, f32
%selected = llvm.select %cond, %arg0, %1 : i1, !llvm.ptr
%2 = llvm.load %selected : !llvm.ptr -> i32
llvm.call @foo(%2) : (i32) -> ()
llvm.return
}

// CHECK-LABEL: llvm.func @selects
// CHECK: llvm.load
// CHECK-NOT: alias_scopes
// CHECK-SAME: noalias_scopes = [#[[$ARG_SCOPE]]]
// CHECK: llvm.load
// CHECK-SAME: alias_scopes = [#[[$ARG_SCOPE]]]
llvm.func @selects(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
llvm.call @func(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
llvm.return
}

// -----

llvm.func @foo(%arg: i32)

llvm.func @func(%cond: i1, %arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
%selected = llvm.select %cond, %arg0, %arg1 : i1, !llvm.ptr
%2 = llvm.load %selected : !llvm.ptr -> i32
llvm.call @foo(%2) : (i32) -> ()
llvm.return
}

// CHECK-LABEL: llvm.func @multi_ptr_select
// CHECK: llvm.load
// CHECK-NOT: alias_scopes
// CHECK-NOT: noalias_scopes
// CHECK: llvm.call @foo
llvm.func @multi_ptr_select(%cond: i1, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
llvm.call @func(%cond, %arg0, %arg1) : (i1, !llvm.ptr, !llvm.ptr) -> ()
llvm.return
}
Loading