Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Introduce a SelectLikeOpInterface #104751

Merged
merged 5 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Arith/IR/Arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -1578,6 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
BooleanConditionOrMatchingShape<"condition", "result">,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
DeclareOpInterfaceMethods<SelectOpInterface>,
] # ElementwiseMappable.traits> {
let summary = "select operation";
let description = [{
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,8 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
def LLVM_SelectOp
: LLVM_Op<"select",
[Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>,
DeclareOpInterfaceMethods<FastmathFlagsInterface>]>,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<SelectOpInterface>]>,
LLVM_Builder<
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,27 @@ def RegionBranchTerminatorOpInterface :
}];
}

def SelectOpInterface : OpInterface<"SelectOpInterface"> {
let description = [{
This interface provides information for select-like operations, i.e.,
operations that forward specific operands to the output, depending on a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a bit more documentation in terms of the invariant and promises associated with the interface I believe, in particular in terms of transforms: can we always fold or propagate based on the only behavior of a regular select?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, I'm curious about poison propagation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attempted to address this by restricting the semantics to fit LLVM's and arith dialect's select instruction/operation.

condition.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Returns the operand that would be chosen for a false condition.
}], "::mlir::Value", "getFalseValue", (ins)>,
InterfaceMethod<[{
Returns the operand that would be chosen for a true condition.
}], "::mlir::Value", "getTrueValue", (ins)>,
InterfaceMethod<[{
Returns the condition operand.
}], "::mlir::Value", "getCondition", (ins)>
];
}

//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Analysis/SliceWalk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ getBlockPredecessorOperands(BlockArgument blockArg) {

std::optional<SmallVector<Value>>
mlir::getControlFlowPredecessors(Value value) {
SmallVector<Value> result;
if (OpResult opResult = dyn_cast<OpResult>(value)) {
auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
if (auto selectOp = opResult.getDefiningOp<SelectOpInterface>())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the comment on getControlFlowPredecessors that the function also uses the new SelectOpInterface to find the control flow predecessors?

return SmallVector<Value>(
{selectOp.getTrueValue(), selectOp.getFalseValue()});
auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
// If the interface is not implemented, there are no control flow
// predecessors to work with.
if (!regionOp)
Expand Down
5 changes: 0 additions & 5 deletions mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,6 @@ getUnderlyingObjectSet(Value pointerValue) {
if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
return WalkContinuation::advanceTo(addrCast.getOperand());

// TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
return WalkContinuation::advanceTo(
{selectOp.getTrueValue(), selectOp.getFalseValue()});

// Attempt to advance to control flow predecessors.
std::optional<SmallVector<Value>> controlFlowPredecessors =
getControlFlowPredecessors(val);
Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,51 @@ llvm.func @noalias_with_region(%arg0: !llvm.ptr) {
llvm.call @region(%arg0) : (!llvm.ptr) -> ()
llvm.return
}

// -----

// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}>
// CHECK-DAG: #[[$ARG_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>

llvm.func @foo(%arg: i32)

llvm.func @func(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
%cond = llvm.load %arg1 : !llvm.ptr -> i1
%1 = llvm.getelementptr inbounds %arg0[1] : (!llvm.ptr) -> !llvm.ptr, f32
%selected = llvm.select %cond, %arg0, %1 : i1, !llvm.ptr
%2 = llvm.load %selected : !llvm.ptr -> i32
llvm.call @foo(%2) : (i32) -> ()
llvm.return
}

// CHECK-LABEL: llvm.func @selects
// CHECK: llvm.load
// CHECK-NOT: alias_scopes
// CHECK-SAME: noalias_scopes = [#[[$ARG_SCOPE]]]
// CHECK: llvm.load
// CHECK-SAME: alias_scopes = [#[[$ARG_SCOPE]]]
llvm.func @selects(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
llvm.call @func(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
llvm.return
}

// -----

llvm.func @foo(%arg: i32)

llvm.func @func(%cond: i1, %arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
%selected = llvm.select %cond, %arg0, %arg1 : i1, !llvm.ptr
%2 = llvm.load %selected : !llvm.ptr -> i32
llvm.call @foo(%2) : (i32) -> ()
llvm.return
}

// CHECK-LABEL: llvm.func @multi_ptr_select
// CHECK: llvm.load
// CHECK-NOT: alias_scopes
// CHECK-NOT: noalias_scopes
// CHECK: llvm.call @foo
llvm.func @multi_ptr_select(%cond: i1, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
llvm.call @func(%cond, %arg0, %arg1) : (i1, !llvm.ptr, !llvm.ptr) -> ()
llvm.return
}
Loading