Skip to content

Commit

Permalink
[MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)
Browse files Browse the repository at this point in the history
This transform op makes it possible to query attributes associated to IR
by means of the DLTI dialect.

The op takes both a `key` and a target `op` to perform the query at.
Facility functions automatically find the closest ancestor op which
defines the appropriate DLTI interface or has an attribute implementing
a DLTI interface. By default the lookup uses the data layout interfaces
of DLTI. If the optional `device` parameter is provided, the lookup
happens with respect to the interfaces for TargetSystemSpec and
TargetDeviceSpec.

This op uses new free-standing functions in the `dlti` namespace to not
only look up specifications via the `DataLayoutSpecOpInterface` and on
`ModuleOp`s but also on any ancestor op that has an appropriate DLTI
attribute.
  • Loading branch information
rolfmorel committed Aug 9, 2024
1 parent 95820ca commit 0fcfd72
Show file tree
Hide file tree
Showing 11 changed files with 563 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
add_subdirectory(TransformOps)

add_mlir_dialect(DLTI dlti)
add_mlir_doc(DLTIAttrs DLTIDialect Dialects/ -gen-dialect-doc)

Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/DLTI.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ namespace detail {
class DataLayoutEntryAttrStorage;
} // namespace detail
} // namespace mlir
namespace mlir {
namespace dlti {
/// Find the first DataLayoutSpec associated to `op`, via either the
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
/// the interface, on `op` and else on `op`'s ancestors in turn.
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);

/// Find the first TargetSystemSpec associated to `op`, via either the
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
/// the interface, on `op` and else on `op`'s ancestors in turn.
TargetSystemSpecInterface getTargetSystemSpec(Operation *op);
} // namespace dlti
} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/DLTI/DLTIAttrs.h.inc"
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS DLTITransformOps.td)
mlir_tablegen(DLTITransformOps.h.inc -gen-op-decls)
mlir_tablegen(DLTITransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRDLTITransformOpsIncGen)

add_mlir_doc(DLTITransformOps DLTITransformOps Dialects/ -gen-op-doc)
38 changes: 38 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- DLTITransformOps.h - DLTI transform ops ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
#define MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H

#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"

namespace mlir {
namespace transform {
class QueryOp;
} // namespace transform
} // namespace mlir

namespace mlir {
class DialectRegistry;

namespace dlti {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace dlti
} // namespace mlir

////===----------------------------------------------------------------------===//
//// DLTI Transform Operations
////===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h.inc"

#endif // MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
61 changes: 61 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//===- DLTITransformOps.td - DLTI transform ops ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef DLTI_TRANSFORM_OPS
#define DLTI_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def QueryOp : Op<Transform_Dialect, "dlti.query", [
TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Return attribute (as param) associated to key via DTLI";
let description = [{
This op queries data layout and target information associated to payload
IR by way of the DLTI dialect. A lookup is performed for the given `key`
at the `target` op, with the DLTI dialect determining which interfaces and
attributes are consulted - first checking `target` and then its ancestors.

When only `key` is provided, the lookup occurs with respect to the data
layout specification of DLTI. When `device` is provided, the lookup occurs
with respect to DLTI's target device specifications associated to a DLTI
system device specification.

#### Return modes

When succesful, the result, `associated_attr`, associates one attribute as a
param for each op in `target`'s payload.

If the lookup fails - as DLTI specifications or entries with the right
names are missing (i.e. the values of `device` and `key`) - a definite
failure is returned.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
OptionalAttr<StrAttr>:$device,
StrAttr:$key);
let results = (outs TransformParamTypeInterface:$associated_attr);
let assemblyFormat =
"(`:``:` $device^ `:``:`)? $key `at` $target attr-dict `:`"
"functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
TransformState &state);
}];
}

#endif // DLTI_TRANSFORM_OPS
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
Expand Down Expand Up @@ -69,6 +70,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
dlti::registerTransformDialectExtension(registry);
func::registerTransformDialectExtension(registry);
gpu::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/DLTI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(TransformOps)
add_mlir_dialect_library(MLIRDLTIDialect
DLTI.cpp
Traits.cpp
Expand Down
35 changes: 35 additions & 0 deletions mlir/lib/Dialect/DLTI/DLTI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,41 @@ TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// DLTIDialect
//===----------------------------------------------------------------------===//

DataLayoutSpecInterface dlti::getDataLayoutSpec(Operation *op) {
DataLayoutSpecInterface dlSpec = nullptr;

for (Operation *cur = op; cur && !dlSpec; cur = cur->getParentOp()) {
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
dlSpec = dataLayoutOp.getDataLayoutSpec();
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
dlSpec = moduleOp.getDataLayoutSpec();
else
for (NamedAttribute attr : cur->getAttrs())
if ((dlSpec = llvm::dyn_cast<DataLayoutSpecInterface>(attr.getValue())))
break;
}

return dlSpec;
}

TargetSystemSpecInterface dlti::getTargetSystemSpec(Operation *op) {
TargetSystemSpecInterface sysSpec = nullptr;

for (Operation *cur = op; cur && !sysSpec; cur = cur->getParentOp()) {
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
sysSpec = dataLayoutOp.getTargetSystemSpec();
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
sysSpec = moduleOp.getTargetSystemSpec();
else
for (NamedAttribute attr : cur->getAttrs())
if ((sysSpec =
llvm::dyn_cast<TargetSystemSpecInterface>(attr.getValue())))
break;
}

return sysSpec;
}

constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig;
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/DLTI/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
add_mlir_dialect_library(MLIRDLTITransformOps
DLTITransformOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/DLTI/TransformOps

DEPENDS
MLIRDLTITransformOpsIncGen
MLIRDLTIDialect

LINK_LIBS PUBLIC
MLIRDLTIDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
)
94 changes: 94 additions & 0 deletions mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

//===- DLTITransformOps.cpp - Implementation of DLTI transform ops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"

#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

using namespace mlir;
using namespace mlir::transform;

#define DEBUG_TYPE "dlti-transforms"

//===----------------------------------------------------------------------===//
// QueryOp
//===----------------------------------------------------------------------===//

void transform::QueryOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTargetMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
onlyReadsPayload(effects);
}

DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results, TransformState &state) {
StringAttr deviceId = getDeviceAttr();
StringAttr key = getKeyAttr();

DataLayoutEntryInterface entry;
if (deviceId) {
TargetSystemSpecInterface sysSpec = dlti::getTargetSystemSpec(target);
if (!sysSpec)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no target system spec associated to: " << target;

if (auto targetSpec = sysSpec.getDeviceSpecForDeviceID(deviceId))
entry = targetSpec->getSpecForIdentifier(key);
else
return mlir::emitDefiniteFailure(target->getLoc())
<< "no " << deviceId << " target device spec found";
} else {
DataLayoutSpecInterface dlSpec = dlti::getDataLayoutSpec(target);
if (!dlSpec)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no data layout spec associated to: " << target;

entry = dlSpec.getSpecForIdentifier(key);
}

if (!entry)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no DLTI entry for key: " << key;

results.push_back(entry.getValue());

return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class DLTITransformDialectExtension
: public transform::TransformDialectExtension<
DLTITransformDialectExtension> {
public:
using Base::Base;

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"

void mlir::dlti::registerTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<DLTITransformDialectExtension>();
}
Loading

0 comments on commit 0fcfd72

Please sign in to comment.