Skip to content

Commit

Permalink
[MLIR][DLTI][Transform] Introduce transform.dlti.query
Browse files Browse the repository at this point in the history
This transform op makes it possible to query attributes associated to
IR by means of the DLTI dialect.

The op takes both a `key` and a target `op` to perform the query at. Facility
functions automatically find the closest ancestor op which defines the
appropriate DLTI interface or has an attribute implementing a DLTI interface.
By default the lookup uses the data layout interfaces of DLTI. If the optional
`device` parameter is provided, the lookup happens with respect to the
interfaces for TargetSystemSpec and TargetDeviceSpec.

This op uses new free-standing functions in the `dlti` namespace to not
only look up specifications via the DataLayoutSpecOpInterface and on ModuleOps
but also on any ancestor op that has an appropriate DLTI attribute.
  • Loading branch information
rolfmorel committed Aug 1, 2024
1 parent 3a8a0b8 commit 2454ee2
Show file tree
Hide file tree
Showing 11 changed files with 560 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
add_subdirectory(TransformOps)

add_mlir_dialect(DLTI dlti)
add_mlir_doc(DLTIAttrs DLTIDialect Dialects/ -gen-dialect-doc)

Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/DLTI.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ namespace detail {
class DataLayoutEntryAttrStorage;
} // namespace detail
} // namespace mlir
namespace mlir {
namespace dlti {
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);
TargetSystemSpecInterface getTargetSystemSpec(Operation *op);
} // namespace dlti
} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/DLTI/DLTIAttrs.h.inc"
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS DLTITransformOps.td)
mlir_tablegen(DLTITransformOps.h.inc -gen-op-decls)
mlir_tablegen(DLTITransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRDLTITransformOpsIncGen)

add_mlir_doc(DLTITransformOps DLTITransformOps Dialects/ -gen-op-doc)
38 changes: 38 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- DLTITransformOps.h - DLTI transform ops ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
#define MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H

#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"

namespace mlir {
namespace transform {
class QueryOp;
} // namespace transform
} // namespace mlir

namespace mlir {
class DialectRegistry;

namespace dlti {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace dlti
} // namespace mlir

////===----------------------------------------------------------------------===//
//// DLTI Transform Operations
////===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h.inc"

#endif // MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
61 changes: 61 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//===- DLTITransformOps.td - DLTI transform ops ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef DLTI_TRANSFORM_OPS
#define DLTI_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def QueryOp : Op<Transform_Dialect, "dlti.query", [
TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Return attribute (as param) associated to key via DTLI";
let description = [{
This op queries data layout and target information associated to payload
IR by way of the DLTI dialect. A lookup is performed for the given `key`
at the `target` op, with the DLTI dialect determining which interfaces and
attributes are consulted.

When only `key` is provided, the lookup occurs with respect to the data
layout specification of DLTI. When `device` is provided, the lookup occurs
with respect to DLTI's target device specifications associated to a DLTI
system device specification.

#### Return modes

When succesfull, the result, `associated_attr`, associates one attribute
as a param for each op in `target`'s payload.

If the lookup fails - as DLTI specifications or entries with the right
names are missing (i.e. the values of `device` and `key`) - a definite
failure is returned.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
OptionalAttr<StrAttr>:$device,
StrAttr:$key);
let results = (outs TransformParamTypeInterface:$associated_attr);
let assemblyFormat =
"(`:``:` $device^ `:``:`)? $key `at` $target attr-dict `:`"
"functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
TransformState &state);
}];
}

#endif // DLTI_TRANSFORM_OPS
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
Expand Down Expand Up @@ -69,6 +70,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
dlti::registerTransformDialectExtension(registry);
func::registerTransformDialectExtension(registry);
gpu::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/DLTI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(TransformOps)
add_mlir_dialect_library(MLIRDLTIDialect
DLTI.cpp
Traits.cpp
Expand Down
35 changes: 35 additions & 0 deletions mlir/lib/Dialect/DLTI/DLTI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,41 @@ TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

DataLayoutSpecInterface dlti::getDataLayoutSpec(Operation *op) {
DataLayoutSpecInterface dlSpec = nullptr;

for (Operation *cur = op; cur && !dlSpec; cur = cur->getParentOp()) {
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
dlSpec = dataLayoutOp.getDataLayoutSpec();
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
dlSpec = moduleOp.getDataLayoutSpec();
else
for (NamedAttribute attr : cur->getAttrs())
if ((dlSpec = llvm::dyn_cast<DataLayoutSpecInterface>(attr.getValue())))
break;
}

return dlSpec;
}

TargetSystemSpecInterface dlti::getTargetSystemSpec(Operation *op) {
TargetSystemSpecInterface sysSpec = nullptr;

for (Operation *cur = op; cur && !sysSpec; cur = cur->getParentOp()) {
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
sysSpec = dataLayoutOp.getTargetSystemSpec();
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
sysSpec = moduleOp.getTargetSystemSpec();
else
for (NamedAttribute attr : cur->getAttrs())
if ((sysSpec =
llvm::dyn_cast<TargetSystemSpecInterface>(attr.getValue())))
break;
}

return sysSpec;
}

//===----------------------------------------------------------------------===//
// DLTIDialect
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/DLTI/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
add_mlir_dialect_library(MLIRDLTITransformOps
DLTITransformOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/DLTI/TransformOps

DEPENDS
MLIRDLTITransformOpsIncGen
MLIRDLTIDialect

LINK_LIBS PUBLIC
MLIRDLTIDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
)
98 changes: 98 additions & 0 deletions mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@

//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

using namespace mlir;
using namespace mlir::transform;

#define DEBUG_TYPE "dlti-transforms"

//===----------------------------------------------------------------------===//
// FuseOp
//===----------------------------------------------------------------------===//

void transform::QueryOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTargetMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
onlyReadsPayload(effects);
}

DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results, TransformState &state) {
StringAttr deviceId = getDeviceAttr();
StringAttr key = getKeyAttr();

DataLayoutEntryInterface entry;
if (deviceId) {
TargetSystemSpecInterface sysSpec = dlti::getTargetSystemSpec(target);
if (!sysSpec)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no target system spec associated to: " << target;

if (auto targetSpec = sysSpec.getDeviceSpecForDeviceID(deviceId))
entry = targetSpec->getSpecForIdentifier(key);
else
return mlir::emitDefiniteFailure(target->getLoc())
<< "no " << deviceId << " target device spec found";
} else {
DataLayoutSpecInterface dlSpec = dlti::getDataLayoutSpec(target);
if (!dlSpec)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no data layout spec associated to: " << target;

entry = dlSpec.getSpecForIdentifier(key);
}

if (!entry)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no DLTI entry for key: " << key;

results.push_back(entry.getValue());

return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class DLTITransformDialectExtension
: public transform::TransformDialectExtension<
DLTITransformDialectExtension> {
public:
using Base::Base;

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"

void mlir::dlti::registerTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<DLTITransformDialectExtension>();
}
Loading

0 comments on commit 2454ee2

Please sign in to comment.