Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][DLTI][Transform] Introduce transform.dlti.query - 2nd attempt #102652

Merged
merged 1 commit into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
add_subdirectory(TransformOps)

add_mlir_dialect(DLTI dlti)
add_mlir_doc(DLTIAttrs DLTIDialect Dialects/ -gen-dialect-doc)

Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/DLTI.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ namespace detail {
class DataLayoutEntryAttrStorage;
} // namespace detail
} // namespace mlir
namespace mlir {
namespace dlti {
/// Find the first DataLayoutSpec associated to `op`, via either the
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
/// the interface, on `op` and else on `op`'s ancestors in turn.
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);

/// Find the first TargetSystemSpec associated to `op`, via either the
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
/// the interface, on `op` and else on `op`'s ancestors in turn.
TargetSystemSpecInterface getTargetSystemSpec(Operation *op);
} // namespace dlti
} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/DLTI/DLTIAttrs.h.inc"
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS DLTITransformOps.td)
mlir_tablegen(DLTITransformOps.h.inc -gen-op-decls)
mlir_tablegen(DLTITransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRDLTITransformOpsIncGen)

add_mlir_doc(DLTITransformOps DLTITransformOps Dialects/ -gen-op-doc)
38 changes: 38 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- DLTITransformOps.h - DLTI transform ops ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
#define MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H

#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"

namespace mlir {
namespace transform {
class QueryOp;
} // namespace transform
} // namespace mlir

namespace mlir {
class DialectRegistry;

namespace dlti {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace dlti
} // namespace mlir

////===----------------------------------------------------------------------===//
//// DLTI Transform Operations
////===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h.inc"

#endif // MLIR_DIALECT_DLTI_TRANSFORMOPS_DLTITRANSFORMOPS_H
61 changes: 61 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//===- DLTITransformOps.td - DLTI transform ops ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef DLTI_TRANSFORM_OPS
#define DLTI_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def QueryOp : Op<Transform_Dialect, "dlti.query", [
TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Return attribute (as param) associated to key via DTLI";
let description = [{
This op queries data layout and target information associated to payload
IR by way of the DLTI dialect. A lookup is performed for the given `key`
at the `target` op, with the DLTI dialect determining which interfaces and
attributes are consulted - first checking `target` and then its ancestors.

When only `key` is provided, the lookup occurs with respect to the data
layout specification of DLTI. When `device` is provided, the lookup occurs
with respect to DLTI's target device specifications associated to a DLTI
system device specification.

#### Return modes

When succesful, the result, `associated_attr`, associates one attribute as a
param for each op in `target`'s payload.

If the lookup fails - as DLTI specifications or entries with the right
names are missing (i.e. the values of `device` and `key`) - a definite
failure is returned.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
OptionalAttr<StrAttr>:$device,
StrAttr:$key);
let results = (outs TransformParamTypeInterface:$associated_attr);
let assemblyFormat =
"(`:``:` $device^ `:``:`)? $key `at` $target attr-dict `:`"
"functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
TransformState &state);
}];
}

#endif // DLTI_TRANSFORM_OPS
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
Expand Down Expand Up @@ -69,6 +70,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
dlti::registerTransformDialectExtension(registry);
func::registerTransformDialectExtension(registry);
gpu::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/DLTI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(TransformOps)
add_mlir_dialect_library(MLIRDLTIDialect
DLTI.cpp
Traits.cpp
Expand Down
35 changes: 35 additions & 0 deletions mlir/lib/Dialect/DLTI/DLTI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,41 @@ TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// DLTIDialect
//===----------------------------------------------------------------------===//

DataLayoutSpecInterface dlti::getDataLayoutSpec(Operation *op) {
DataLayoutSpecInterface dlSpec = nullptr;

for (Operation *cur = op; cur && !dlSpec; cur = cur->getParentOp()) {
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
dlSpec = dataLayoutOp.getDataLayoutSpec();
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
dlSpec = moduleOp.getDataLayoutSpec();
else
for (NamedAttribute attr : cur->getAttrs())
if ((dlSpec = llvm::dyn_cast<DataLayoutSpecInterface>(attr.getValue())))
break;
}

return dlSpec;
}

TargetSystemSpecInterface dlti::getTargetSystemSpec(Operation *op) {
TargetSystemSpecInterface sysSpec = nullptr;

for (Operation *cur = op; cur && !sysSpec; cur = cur->getParentOp()) {
if (auto dataLayoutOp = dyn_cast<DataLayoutOpInterface>(cur))
sysSpec = dataLayoutOp.getTargetSystemSpec();
else if (auto moduleOp = dyn_cast<ModuleOp>(cur))
sysSpec = moduleOp.getTargetSystemSpec();
else
for (NamedAttribute attr : cur->getAttrs())
if ((sysSpec =
llvm::dyn_cast<TargetSystemSpecInterface>(attr.getValue())))
break;
}

return sysSpec;
}

constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig;
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/DLTI/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
add_mlir_dialect_library(MLIRDLTITransformOps
DLTITransformOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/DLTI/TransformOps

DEPENDS
MLIRDLTITransformOpsIncGen
MLIRDLTIDialect

LINK_LIBS PUBLIC
MLIRDLTIDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
)
96 changes: 96 additions & 0 deletions mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@

//===- DLTITransformOps.cpp - Implementation of DLTI transform ops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"

#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

using namespace mlir;
using namespace mlir::transform;

#define DEBUG_TYPE "dlti-transforms"

//===----------------------------------------------------------------------===//
// QueryOp
//===----------------------------------------------------------------------===//

void transform::QueryOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTargetMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
onlyReadsPayload(effects);
}

DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results, TransformState &state) {
StringAttr deviceId = getDeviceAttr();
StringAttr key = getKeyAttr();

DataLayoutEntryInterface entry;
if (deviceId) {
TargetSystemSpecInterface sysSpec = dlti::getTargetSystemSpec(target);
if (!sysSpec)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no target system spec associated to: " << target;

if (auto targetSpec = sysSpec.getDeviceSpecForDeviceID(deviceId))
entry = targetSpec->getSpecForIdentifier(key);
else
return mlir::emitDefiniteFailure(target->getLoc())
<< "no " << deviceId << " target device spec found";
} else {
DataLayoutSpecInterface dlSpec = dlti::getDataLayoutSpec(target);
if (!dlSpec)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no data layout spec associated to: " << target;

entry = dlSpec.getSpecForIdentifier(key);
}

if (!entry)
return mlir::emitDefiniteFailure(target->getLoc())
<< "no DLTI entry for key: " << key;

results.push_back(entry.getValue());

return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class DLTITransformDialectExtension
: public transform::TransformDialectExtension<
DLTITransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTITransformDialectExtension)

using Base::Base;

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc"

void mlir::dlti::registerTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<DLTITransformDialectExtension>();
}
Loading
Loading