diff --git a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td index 423eec2aa..9231e986b 100644 --- a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td +++ b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td @@ -133,7 +133,7 @@ def TTKernel_CBPushBackOp : TTKernel_Op<"cb_push_back"> { CBPushBack operation }]; - let arguments = (ins TTKernel_CB:$cb); + let arguments = (ins TTKernel_CB:$cb, I32:$numPages); let hasVerifier = 1; } @@ -144,7 +144,7 @@ def TTKernel_CBPopFrontOp : TTKernel_Op<"cb_pop_front"> { CBPopFront operation }]; - let arguments = (ins TTKernel_CB:$cb); + let arguments = (ins TTKernel_CB:$cb, I32:$numPages); let hasVerifier = 1; } @@ -155,7 +155,7 @@ def TTKernel_CBReserveBackOp : TTKernel_Op<"cb_reserve_back"> { CBReserveBack operation }]; - let arguments = (ins TTKernel_CB:$cb); + let arguments = (ins TTKernel_CB:$cb, I32:$numPages); let hasVerifier = 1; } @@ -166,11 +166,57 @@ def TTKernel_CBWaitFrontOp : TTKernel_Op<"cb_wait_front"> { CBWaitFront operation }]; - let arguments = (ins TTKernel_CB:$cb); + let arguments = (ins TTKernel_CB:$cb, I32:$numPages); let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// TTKernel NOC operations +//===----------------------------------------------------------------------===// + +def TTKernel_GetNocAddrOp : TTKernel_Op<"get_noc_addr"> { + let summary = "GetNocAddr"; + let description = [{ + GetNocAddr + }]; + + let arguments = (ins I32:$x, I32:$y, I32:$l1Address); + let results = (outs TTKernel_NocAddr:$nocAddr); +} + +def TTKernel_NocAsyncReadOp : TTKernel_Op<"noc_async_read"> { + let summary = "NocAsyncRead"; + let description = [{ + NocAsyncRead + }]; + + let arguments = (ins TTKernel_NocAddr:$srcNocAddr, I32:$dstLocalL1Addr, I32:$size); +} + +def TTKernel_NocAsyncReadBarrierOp : TTKernel_Op<"noc_async_read_barrier"> { + let summary = "NocAsyncReadBarrier"; + let description = [{ + NocAsyncReadBarrier + }]; +} + +def TTKernel_NocAsyncWriteOp : TTKernel_Op<"noc_async_write"> { + let summary = "NocAsyncWrite"; + let description = [{ + NocAsyncWrite + }]; + + let arguments = (ins I32:$srcLocalL1Addr, TTKernel_NocAddr:$dstNocAddr, I32:$size); +} + +def TTKernel_NocAsyncWriteBarrierOp : TTKernel_Op<"noc_async_write_barrier"> { + let summary = "NocAsyncWriteBarrier"; + let description = [{ + NocAsyncWriteBarrier + }]; +} + //===----------------------------------------------------------------------===// // TTKernel Misc operations //===----------------------------------------------------------------------===// diff --git a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td index 5007abb31..d720ae46f 100644 --- a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td +++ b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td @@ -30,6 +30,11 @@ def TTKernel_CB : TTKernel_Type<"CB", "cb"> { }]; } +def TTKernel_NocAddr : TTKernel_Type<"NocAddr", "noc_addr"> { + let summary = "TTKernel noc address"; + let description = "Noc address type in TTKernel dialect"; +} + def TTKernel_ThreadTypeAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } diff --git a/lib/Dialect/TTMetal/IR/TTMetalDialect.cpp b/lib/Dialect/TTMetal/IR/TTMetalDialect.cpp index 2ee615b1f..d27a3d168 100644 --- a/lib/Dialect/TTMetal/IR/TTMetalDialect.cpp +++ b/lib/Dialect/TTMetal/IR/TTMetalDialect.cpp @@ -6,6 +6,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/InitAllDialects.h" +#include "mlir/Interfaces/FoldInterfaces.h" #include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TTMetal/IR/TTMetalOps.h" #include "ttmlir/Dialect/TTMetal/IR/TTMetalOpsTypes.h" @@ -36,6 +37,19 @@ parseDimensionList(::mlir::AsmParser &odsParser, // TTMetal dialect. //===----------------------------------------------------------------------===// +struct TTMetalDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; + + /// Registered hook to check if the given region, which is attached to an + /// operation that is *not* isolated from above, should be used when + /// materializing constants. + bool shouldMaterializeInto(Region *region) const final { + // If this is a DispatchOp, protect it from hoisting constants outside of + // its region body + return isa(region->getParentOp()); + } +}; + void TTMetalDialect::initialize() { addOperations< #define GET_OP_LIST @@ -47,4 +61,6 @@ void TTMetalDialect::initialize() { #include "ttmlir/Dialect/TTMetal/IR/TTMetalOpsAttrDefs.cpp.inc" >(); registerTypes(); + + addInterfaces(); } diff --git a/lib/Dialect/TTMetal/Transforms/KernelsToCpp.cpp b/lib/Dialect/TTMetal/Transforms/KernelsToCpp.cpp index 0d9848731..4bd12ebf5 100644 --- a/lib/Dialect/TTMetal/Transforms/KernelsToCpp.cpp +++ b/lib/Dialect/TTMetal/Transforms/KernelsToCpp.cpp @@ -4,12 +4,15 @@ #include "llvm/ADT/ScopeExit.h" +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Target/Cpp/CppEmitter.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" @@ -22,6 +25,16 @@ namespace mlir::tt::ttmetal { +class TTKernelToEmitCTypeConverter : public TypeConverter { +public: + TTKernelToEmitCTypeConverter(MLIRContext *ctx) { + addConversion([](Type type) { return type; }); + addConversion([ctx](mlir::tt::ttkernel::NocAddrType type) -> Type { + return Builder(ctx).getI64Type(); + }); + } +}; + class TTMetalToEmitCFuncArgsRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -67,25 +80,33 @@ class TTMetalToEmitCReturnRewriter template class TTMetalToEmitCOpaqueRewriter : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + TTMetalToEmitCOpaqueRewriter(TTKernelToEmitCTypeConverter &typeConverter, + MLIRContext *ctx) + : OpRewritePattern(ctx), typeConverter(&typeConverter) {} StringRef getOpName(OpTy op) const { if constexpr (std::is_same_v) { return op.getOp(); } auto name = op.getOperation()->getName().getStringRef(); - if (name.starts_with("ttmetal.")) { - return name.drop_front(8); + if (name.starts_with("ttkernel.")) { + return name.drop_front(9); } return name; } LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final { + SmallVector resultTypes; + for (auto type : op->getResultTypes()) { + resultTypes.push_back(typeConverter->convertType(type)); + } rewriter.replaceOpWithNewOp( - op, TypeRange(), getOpName(op), nullptr, nullptr, op->getOperands()); + op, resultTypes, getOpName(op), nullptr, nullptr, op->getOperands()); return success(); } + + TTKernelToEmitCTypeConverter *typeConverter; }; LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp, @@ -111,6 +132,19 @@ LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp, threadTypeAttr); builder.setInsertionPointToStart(&moduleBlock); + builder.create(module.getLoc(), "cstdint", + /*isStandard=*/true); + if (threadTypeAttr.getValue() == ttkernel::ThreadType::Noc0 || + threadTypeAttr.getValue() == ttkernel::ThreadType::Noc1) { + builder.create(module.getLoc(), "dataflow_api.h", + /*isStandard=*/false); + } + if (threadTypeAttr.getValue() == ttkernel::ThreadType::Tensix) { + builder.create(module.getLoc(), + "compute_kernel_api/common.h", + /*isStandard=*/false); + } + // Create a new func op and move the existing block into it. auto func = builder.create( module.getLoc(), "kernel_main", @@ -120,14 +154,37 @@ LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp, IRMapping irMapper; funcBody->takeBody(region); + // Apply arith to emitc conversion first + { + ConversionTarget target(*module.getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + RewritePatternSet arithPatterns(module.getContext()); + TypeConverter arithTypeConverter; + arithTypeConverter.addConversion([](Type type) { return type; }); + populateArithToEmitCPatterns(arithTypeConverter, arithPatterns); + if (failed( + applyPartialConversion(module, target, std::move(arithPatterns)))) { + return failure(); + } + } + + TTKernelToEmitCTypeConverter typeConverter(module.getContext()); RewritePatternSet patterns(module.getContext()); - patterns.add, + + patterns.add( + module.getContext()); + patterns.add, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCReturnRewriter>(module.getContext()); + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter>( + typeConverter, module.getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); if (failed(applyPatternsAndFoldGreedily(module, patternSet))) { diff --git a/lib/Dialect/TTMetal/Transforms/Passes.cpp b/lib/Dialect/TTMetal/Transforms/Passes.cpp index 05b5051b3..6fd7474de 100644 --- a/lib/Dialect/TTMetal/Transforms/Passes.cpp +++ b/lib/Dialect/TTMetal/Transforms/Passes.cpp @@ -5,16 +5,13 @@ #include "ttmlir/Dialect/TTMetal/Transforms/Passes.h" #include "mlir/Analysis/Liveness.h" -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MLProgram/IR/MLProgram.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/PassManager.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTIR/IR/TTIR.h" #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" @@ -175,25 +172,29 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { arg.setType(ty); } - rewriter.setInsertionPointToStart(noc0Block); - auto push0 = rewriter.create( - op.getLoc(), noc0Block->getArgument(0)); - push0->remove(); - noc0Block->push_back(push0); - auto return0 = - rewriter.create(op.getLoc(), ValueRange()); - return0->remove(); - noc0Block->push_back(return0); - - rewriter.setInsertionPointToStart(noc1Block); - auto push1 = rewriter.create( - op.getLoc(), noc1Block->getArgument(1)); - push1->remove(); - noc1Block->push_back(push1); - auto return1 = - rewriter.create(op.getLoc(), ValueRange()); - return1->remove(); - noc1Block->push_back(return1); + { + OpBuilder noc0Builder(noc0Block, noc0Block->begin()); + auto one = noc0Builder.create( + op.getLoc(), noc0Builder.getI32Type(), + noc0Builder.getI32IntegerAttr(1)); + noc0Builder.create( + op.getLoc(), noc0Block->getArgument(0), one); + noc0Builder.create( + op.getLoc(), noc0Block->getArgument(0), one); + noc0Builder.create(op.getLoc(), ValueRange()); + } + + { + OpBuilder noc1Builder(noc1Block, noc1Block->begin()); + auto one = noc1Builder.create( + op.getLoc(), noc1Builder.getI32Type(), + noc1Builder.getI32IntegerAttr(1)); + noc1Builder.create( + op.getLoc(), noc1Block->getArgument(0), one); + noc1Builder.create( + op.getLoc(), noc1Block->getArgument(0), one); + noc1Builder.create(op.getLoc(), ValueRange()); + } rewriter.replaceOp(op, metalDispatch); @@ -246,6 +247,7 @@ class ConvertTTIRToTTMetal registry.insert(); registry.insert(); registry.insert(); + registry.insert(); } }; diff --git a/test/python/simple_kernel.py b/test/python/simple_kernel.py index d1fc267ce..0e708d7d5 100644 --- a/test/python/simple_kernel.py +++ b/test/python/simple_kernel.py @@ -150,7 +150,7 @@ def emit_call(self, value): func = TTKernelBuilder.cb_fn_map[value.func.attr] cb = self.symbol_table[value.func.value.id] assert ttkernel.ir.CBType.cast(cb.type) - return func(cb) + return func(cb, self.get_constant(1)) elif self.symbol_table[value.func.value.id] == "Tensix": assert ( value.func.attr in TTKernelBuilder.t6_fn_map