Skip to content

Commit

Permalink
Add noc dataflow API for doing async read/write prereq for #321 (#324)
Browse files Browse the repository at this point in the history
- Add new noc dataflow APIs
- Lower them to EmitC dialect
- Add EmitC dialect type conversion for NocAddr
- Add support for arith dialect in TTKernelToEmitC path
- Protect arith constants from being hoisted outside of ttmetal.dispatch
  region
  • Loading branch information
nsmithtt authored Aug 9, 2024
1 parent 9a91dea commit f4fda50
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 35 deletions.
54 changes: 50 additions & 4 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def TTKernel_CBPushBackOp : TTKernel_Op<"cb_push_back"> {
CBPushBack operation
}];

let arguments = (ins TTKernel_CB:$cb);
let arguments = (ins TTKernel_CB:$cb, I32:$numPages);

let hasVerifier = 1;
}
Expand All @@ -144,7 +144,7 @@ def TTKernel_CBPopFrontOp : TTKernel_Op<"cb_pop_front"> {
CBPopFront operation
}];

let arguments = (ins TTKernel_CB:$cb);
let arguments = (ins TTKernel_CB:$cb, I32:$numPages);

let hasVerifier = 1;
}
Expand All @@ -155,7 +155,7 @@ def TTKernel_CBReserveBackOp : TTKernel_Op<"cb_reserve_back"> {
CBReserveBack operation
}];

let arguments = (ins TTKernel_CB:$cb);
let arguments = (ins TTKernel_CB:$cb, I32:$numPages);

let hasVerifier = 1;
}
Expand All @@ -166,11 +166,57 @@ def TTKernel_CBWaitFrontOp : TTKernel_Op<"cb_wait_front"> {
CBWaitFront operation
}];

let arguments = (ins TTKernel_CB:$cb);
let arguments = (ins TTKernel_CB:$cb, I32:$numPages);

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// TTKernel NOC operations
//===----------------------------------------------------------------------===//

// Op with mnemonic `get_noc_addr`: consumes three I32 operands ($x, $y,
// $l1Address) and yields a single TTKernel_NocAddr result ($nocAddr).
// NOTE(review): presumably $x/$y are the NOC coordinates of the target core
// and $l1Address is an address in that core's L1 memory -- confirm against
// the kernel dataflow API this op is lowered to.
def TTKernel_GetNocAddrOp : TTKernel_Op<"get_noc_addr"> {
let summary = "GetNocAddr";
let description = [{
GetNocAddr
}];

// Operand order here is the order in which the operands are declared on the
// op; the result is the only value this op produces.
let arguments = (ins I32:$x, I32:$y, I32:$l1Address);
let results = (outs TTKernel_NocAddr:$nocAddr);
}

// Op with mnemonic `noc_async_read`: takes a TTKernel_NocAddr source plus two
// I32 operands ($dstLocalL1Addr, $size) and produces no results.
// NOTE(review): presumably $size is a byte count and the transfer only
// completes after a matching `noc_async_read_barrier` -- confirm.
def TTKernel_NocAsyncReadOp : TTKernel_Op<"noc_async_read"> {
let summary = "NocAsyncRead";
let description = [{
NocAsyncRead
}];

// Source comes first, destination second -- the reverse of noc_async_write's
// operand order.
let arguments = (ins TTKernel_NocAddr:$srcNocAddr, I32:$dstLocalL1Addr, I32:$size);
}

// Op with mnemonic `noc_async_read_barrier`: declares no operands and no
// results.
// NOTE(review): presumably waits for previously issued `noc_async_read` ops
// to complete -- confirm against the kernel dataflow API.
def TTKernel_NocAsyncReadBarrierOp : TTKernel_Op<"noc_async_read_barrier"> {
let summary = "NocAsyncReadBarrier";
let description = [{
NocAsyncReadBarrier
}];
}

// Op with mnemonic `noc_async_write`: takes an I32 source ($srcLocalL1Addr), a
// TTKernel_NocAddr destination and an I32 $size, and produces no results.
// NOTE(review): presumably $size is a byte count and the transfer only
// completes after a matching `noc_async_write_barrier` -- confirm.
def TTKernel_NocAsyncWriteOp : TTKernel_Op<"noc_async_write"> {
let summary = "NocAsyncWrite";
let description = [{
NocAsyncWrite
}];

// Source comes first, destination second -- here the local L1 address is the
// source and the NocAddr is the destination.
let arguments = (ins I32:$srcLocalL1Addr, TTKernel_NocAddr:$dstNocAddr, I32:$size);
}

// Op with mnemonic `noc_async_write_barrier`: declares no operands and no
// results.
// NOTE(review): presumably waits for previously issued `noc_async_write` ops
// to complete -- confirm against the kernel dataflow API.
def TTKernel_NocAsyncWriteBarrierOp : TTKernel_Op<"noc_async_write_barrier"> {
let summary = "NocAsyncWriteBarrier";
let description = [{
NocAsyncWriteBarrier
}];
}

//===----------------------------------------------------------------------===//
// TTKernel Misc operations
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def TTKernel_CB : TTKernel_Type<"CB", "cb"> {
}];
}

// Type with C++ class prefix "NocAddr" and mnemonic `noc_addr`. It carries no
// parameters. It is the result type of `get_noc_addr` and the remote-address
// operand type of `noc_async_read` / `noc_async_write`.
def TTKernel_NocAddr : TTKernel_Type<"NocAddr", "noc_addr"> {
let summary = "TTKernel noc address";
let description = "Noc address type in TTKernel dialect";
}

def TTKernel_ThreadTypeAttr : EnumAttr<TTKernel_Dialect, TTKernel_ThreadType, "thread"> {
let assemblyFormat = "`<` $value `>`";
}
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTMetal/IR/TTMetalDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "mlir/IR/DialectImplementation.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTMetal/IR/TTMetalOps.h"
#include "ttmlir/Dialect/TTMetal/IR/TTMetalOpsTypes.h"
Expand Down Expand Up @@ -36,6 +37,19 @@ parseDimensionList(::mlir::AsmParser &odsParser,
// TTMetal dialect.
//===----------------------------------------------------------------------===//

/// Fold interface for the TTMetal dialect. Its only job is to tell the folder
/// which regions constants are allowed to be materialized into.
struct TTMetalDialectFoldInterface : public DialectFoldInterface {
  using DialectFoldInterface::DialectFoldInterface;

  /// Hook queried for a region whose owning operation is *not* isolated from
  /// above. Returning true marks the region as one that should be used when
  /// materializing constants.
  ///
  /// Regions owned by a DispatchOp answer true so that constants used inside
  /// the dispatch body are not hoisted out of it; every other region answers
  /// false.
  bool shouldMaterializeInto(Region *region) const final {
    auto *owner = region->getParentOp();
    const bool ownedByDispatch = isa<DispatchOp>(owner);
    return ownedByDispatch;
  }
};

void TTMetalDialect::initialize() {
addOperations<
#define GET_OP_LIST
Expand All @@ -47,4 +61,6 @@ void TTMetalDialect::initialize() {
#include "ttmlir/Dialect/TTMetal/IR/TTMetalOpsAttrDefs.cpp.inc"
>();
registerTypes();

addInterfaces<TTMetalDialectFoldInterface>();
}
71 changes: 64 additions & 7 deletions lib/Dialect/TTMetal/Transforms/KernelsToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@

#include "llvm/ADT/ScopeExit.h"

#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
Expand All @@ -22,6 +25,16 @@

namespace mlir::tt::ttmetal {

/// Type converter used when lowering TTKernel ops to EmitC.
///
/// Two rules are registered, in this order:
///   1. an identity rule that returns any type unchanged;
///   2. a rule that maps ttkernel::NocAddrType to the builtin i64 type.
/// The registration order is significant to TypeConverter and must be kept:
/// MLIR tries the most recently added conversion first, so the NocAddr rule
/// takes precedence over the identity rule.
class TTKernelToEmitCTypeConverter : public TypeConverter {
public:
  TTKernelToEmitCTypeConverter(MLIRContext *ctx) {
    // Fallback: leave every other type as it is.
    auto passThrough = [](Type type) { return type; };
    addConversion(passThrough);

    // NocAddr values are represented as 64-bit integers in the emitted code.
    auto nocAddrToI64 = [ctx](mlir::tt::ttkernel::NocAddrType type) -> Type {
      Builder typeBuilder(ctx);
      return typeBuilder.getI64Type();
    };
    addConversion(nocAddrToI64);
  }
};

class TTMetalToEmitCFuncArgsRewriter : public OpRewritePattern<func::FuncOp> {
public:
using OpRewritePattern<func::FuncOp>::OpRewritePattern;
Expand Down Expand Up @@ -67,25 +80,33 @@ class TTMetalToEmitCReturnRewriter
template <typename OpTy>
class TTMetalToEmitCOpaqueRewriter : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
TTMetalToEmitCOpaqueRewriter(TTKernelToEmitCTypeConverter &typeConverter,
MLIRContext *ctx)
: OpRewritePattern<OpTy>(ctx), typeConverter(&typeConverter) {}

StringRef getOpName(OpTy op) const {
if constexpr (std::is_same_v<OpTy, ttkernel::BuiltinOp>) {
return op.getOp();
}
auto name = op.getOperation()->getName().getStringRef();
if (name.starts_with("ttmetal.")) {
return name.drop_front(8);
if (name.starts_with("ttkernel.")) {
return name.drop_front(9);
}
return name;
}

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const final {
SmallVector<Type, 4> resultTypes;
for (auto type : op->getResultTypes()) {
resultTypes.push_back(typeConverter->convertType(type));
}
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
op, TypeRange(), getOpName(op), nullptr, nullptr, op->getOperands());
op, resultTypes, getOpName(op), nullptr, nullptr, op->getOperands());
return success();
}

TTKernelToEmitCTypeConverter *typeConverter;
};

LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp,
Expand All @@ -111,6 +132,19 @@ LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp,
threadTypeAttr);
builder.setInsertionPointToStart(&moduleBlock);

builder.create<emitc::IncludeOp>(module.getLoc(), "cstdint",
/*isStandard=*/true);
if (threadTypeAttr.getValue() == ttkernel::ThreadType::Noc0 ||
threadTypeAttr.getValue() == ttkernel::ThreadType::Noc1) {
builder.create<emitc::IncludeOp>(module.getLoc(), "dataflow_api.h",
/*isStandard=*/false);
}
if (threadTypeAttr.getValue() == ttkernel::ThreadType::Tensix) {
builder.create<emitc::IncludeOp>(module.getLoc(),
"compute_kernel_api/common.h",
/*isStandard=*/false);
}

// Create a new func op and move the existing block into it.
auto func = builder.create<func::FuncOp>(
module.getLoc(), "kernel_main",
Expand All @@ -120,14 +154,37 @@ LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp,
IRMapping irMapper;
funcBody->takeBody(region);

// Apply arith to emitc conversion first
{
ConversionTarget target(*module.getContext());
target.addLegalDialect<emitc::EmitCDialect>();
target.addIllegalDialect<arith::ArithDialect>();
RewritePatternSet arithPatterns(module.getContext());
TypeConverter arithTypeConverter;
arithTypeConverter.addConversion([](Type type) { return type; });
populateArithToEmitCPatterns(arithTypeConverter, arithPatterns);
if (failed(
applyPartialConversion(module, target, std::move(arithPatterns)))) {
return failure();
}
}

TTKernelToEmitCTypeConverter typeConverter(module.getContext());
RewritePatternSet patterns(module.getContext());
patterns.add<TTMetalToEmitCFuncArgsRewriter,
TTMetalToEmitCOpaqueRewriter<ttkernel::BuiltinOp>,

patterns.add<TTMetalToEmitCFuncArgsRewriter, TTMetalToEmitCReturnRewriter>(
module.getContext());
patterns.add<TTMetalToEmitCOpaqueRewriter<ttkernel::BuiltinOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBPushBackOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBPopFrontOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBReserveBackOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBWaitFrontOp>,
TTMetalToEmitCReturnRewriter>(module.getContext());
TTMetalToEmitCOpaqueRewriter<ttkernel::GetNocAddrOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncReadOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncReadBarrierOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncWriteOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncWriteBarrierOp>>(
typeConverter, module.getContext());

FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(module, patternSet))) {
Expand Down
48 changes: 25 additions & 23 deletions lib/Dialect/TTMetal/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@
#include "ttmlir/Dialect/TTMetal/Transforms/Passes.h"

#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
Expand Down Expand Up @@ -175,25 +172,29 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
arg.setType(ty);
}

rewriter.setInsertionPointToStart(noc0Block);
auto push0 = rewriter.create<ttkernel::CBPushBackOp>(
op.getLoc(), noc0Block->getArgument(0));
push0->remove();
noc0Block->push_back(push0);
auto return0 =
rewriter.create<ttkernel::ReturnOp>(op.getLoc(), ValueRange());
return0->remove();
noc0Block->push_back(return0);

rewriter.setInsertionPointToStart(noc1Block);
auto push1 = rewriter.create<ttkernel::CBPushBackOp>(
op.getLoc(), noc1Block->getArgument(1));
push1->remove();
noc1Block->push_back(push1);
auto return1 =
rewriter.create<ttkernel::ReturnOp>(op.getLoc(), ValueRange());
return1->remove();
noc1Block->push_back(return1);
{
OpBuilder noc0Builder(noc0Block, noc0Block->begin());
auto one = noc0Builder.create<arith::ConstantOp>(
op.getLoc(), noc0Builder.getI32Type(),
noc0Builder.getI32IntegerAttr(1));
noc0Builder.create<ttkernel::CBReserveBackOp>(
op.getLoc(), noc0Block->getArgument(0), one);
noc0Builder.create<ttkernel::CBPushBackOp>(
op.getLoc(), noc0Block->getArgument(0), one);
noc0Builder.create<ttkernel::ReturnOp>(op.getLoc(), ValueRange());
}

{
OpBuilder noc1Builder(noc1Block, noc1Block->begin());
auto one = noc1Builder.create<arith::ConstantOp>(
op.getLoc(), noc1Builder.getI32Type(),
noc1Builder.getI32IntegerAttr(1));
noc1Builder.create<ttkernel::CBReserveBackOp>(
op.getLoc(), noc1Block->getArgument(0), one);
noc1Builder.create<ttkernel::CBPushBackOp>(
op.getLoc(), noc1Block->getArgument(0), one);
noc1Builder.create<ttkernel::ReturnOp>(op.getLoc(), ValueRange());
}

rewriter.replaceOp(op, metalDispatch);

Expand Down Expand Up @@ -246,6 +247,7 @@ class ConvertTTIRToTTMetal
registry.insert<mlir::tt::ttir::TTIRDialect>();
registry.insert<mlir::tt::ttmetal::TTMetalDialect>();
registry.insert<mlir::tt::ttkernel::TTKernelDialect>();
registry.insert<mlir::arith::ArithDialect>();
}
};

Expand Down
2 changes: 1 addition & 1 deletion test/python/simple_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def emit_call(self, value):
func = TTKernelBuilder.cb_fn_map[value.func.attr]
cb = self.symbol_table[value.func.value.id]
assert ttkernel.ir.CBType.cast(cb.type)
return func(cb)
return func(cb, self.get_constant(1))
elif self.symbol_table[value.func.value.id] == "Tensix":
assert (
value.func.attr in TTKernelBuilder.t6_fn_map
Expand Down

0 comments on commit f4fda50

Please sign in to comment.