diff --git a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td index ed70d7da6..da4c9ef30 100644 --- a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td +++ b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td @@ -537,6 +537,153 @@ def TTKernel_NocAsyncWriteBarrierOp : TTKernel_Op<"noc_async_write_barrier"> { }]; } +def TTKernel_GetSemaphoreOp : TTKernel_Op<"get_semaphore"> { + let summary = "GetSemaphoreOp"; + let description = [{ + Get L1 addr of the semaphore with specified semaphore id + }]; + + let arguments = (ins I32:$semaphore_id); + let results = (outs TTKernel_L1Addr:$sem_addr); +} + +def TTKernel_NocSemaphoreIncOp : TTKernel_Op<"noc_semaphore_inc"> { + let summary = "NocSemaphoreInc"; + let description = [{ + The Tensix core executing this function call initiates an atomic increment + (with 32-bit wrap) of a remote Tensix core L1 memory address. This L1 memory + address is used as a semaphore of size 4 Bytes, as a synchronization + mechanism. + }]; + + let arguments = (ins TTKernel_NocAddr:$addr, I32:$incr, I32:$noc_id); +} + +def TTKernel_NocSemaphoreSetOp : TTKernel_Op<"noc_semaphore_set"> { + let summary = "NocSemaphoreSet"; + let description = [{ + Sets the value of a local L1 memory address on the Tensix core executing + this function to a specific value. This L1 memory address is used as a + semaphore of size 4 Bytes, as a synchronization mechanism. Also, see + *noc_semaphore_wait*. + }]; + + let arguments = (ins TTKernel_L1AddrPtr:$sem_addr, I32:$val); +} + +def TTKernel_NocSemaphoreWaitOp : TTKernel_Op<"noc_semaphore_wait"> { + let summary = "NocSemaphoreWait"; + let description = [{ + A blocking call that waits until the value of a local L1 memory address on + the Tensix core executing this function becomes equal to a target value. + This L1 memory address is used as a semaphore of size 4 Bytes, as a + synchronization mechanism. Also, see *noc_semaphore_set*. + }]; + + let arguments = (ins TTKernel_L1AddrPtr:$sem_addr, I32:$val); +} + +def TTKernel_NocSemaphoreWaitMinOp : TTKernel_Op<"noc_semaphore_wait_min"> { + let summary = "NocSemaphoreWaitMin"; + let description = [{ + A blocking call that waits until the value of a local L1 memory address on + the Tensix core executing this function becomes equal or greater than a target value. + This L1 memory address is used as a semaphore of size 4 Bytes, as a + synchronization mechanism. Also, see *noc_semaphore_set*. + }]; + + let arguments = (ins TTKernel_L1AddrPtr:$sem_addr, I32:$val); +} + +def TTKernel_NocSemaphoreSetMulticastOp : TTKernel_Op<"noc_semaphore_set_multicast"> { + let summary = "NocSemaphoreSetMulticast"; + let description = [{ + Initiates an asynchronous write from a source address in L1 memory on the + Tensix core executing this function call to a rectangular destination grid. + The destinations are specified using a uint64_t encoding referencing an + on-chip grid of nodes located at NOC coordinate range + (x_start,y_start,x_end,y_end) and a local address created using + *get_noc_multicast_addr* function. The size of data that is sent is 4 Bytes. + This is usually used to set a semaphore value at the destination nodes, as a + way of a synchronization mechanism. The same as *noc_async_write_multicast* + with preset size of 4 Bytes. + With this API, the multicast sender cannot be part of the multicast + destinations. If the multicast sender has to be in the multicast + destinations (i.e. must perform a local L1 write), the other API variant + *noc_semaphore_set_multicast_loopback_src* can be used. + }]; + + let arguments = (ins TTKernel_L1Addr:$src_local_l1_addr, + TTKernel_NocAddr:$dst_noc_addr_multicast, + I32:$num_dests, + BoolAttr:$linked, + BoolAttr:$multicast_path_reserve); +} + +def TTKernel_NocSemaphoreSetMulticastLoopbackOp : TTKernel_Op<"noc_semaphore_set_multicast_loopback_src"> { + let summary = "NocSemaphoreSetMulticastLoopback"; + let description = [{ + Initiates an asynchronous write from a source address in L1 memory on the + Tensix core executing this function call to a rectangular destination grid. + The destinations are specified using a uint64_t encoding referencing an + on-chip grid of nodes located at NOC coordinate range + (x_start,y_start,x_end,y_end) and a local address created using + *get_noc_multicast_addr* function. The size of data that is sent is 4 Bytes. + This is usually used to set a semaphore value at the destination nodes, as a + way of a synchronization mechanism. The same as *noc_async_write_multicast* + with preset size of 4 Bytes. + Note: With this API, sending data only to the source node (when num_dests + is 1) may result in unexpected behaviour. For some parameters, hangs have + been observed. For some other parameters, nothing may happen. Consider using + regular non multicast operations such as *noc_async_write* in this case. + }]; + + let arguments = (ins TTKernel_L1Addr:$src_local_l1_addr, + TTKernel_NocAddr:$dst_noc_addr_multicast, + I32:$num_dests, + BoolAttr:$linked, + BoolAttr:$multicast_path_reserve); +} + +//===----------------------------------------------------------------------===// +// TTKernel Compile and runtime arguments operations +//===----------------------------------------------------------------------===// + +def TTKernel_GetArgValOp : TTKernel_Op<"get_arg_val"> { + let summary = "Get runtime arg value."; + let description = [{ + Get runtime argument value at specified index. + }]; + + let arguments = (ins I32:$arg_index); + + let results = (outs AnyTypeOf<[TTKernel_Semaphore, I32]>:$arg_val); +} + +//===----------------------------------------------------------------------===// +// TTKernel Helper functions +//===----------------------------------------------------------------------===// + +def TTKernel_CastToL1PtrOp : TTKernel_Op<"reinterpret_cast"> { + let summary = "CastToL1Ptr"; + let description = [{ + Cast specified addr to L1 pointer. + }]; + + let arguments = (ins AnyTypeOf<[I32, TTKernel_L1Addr]>:$addr); + + let results = (outs TTKernel_L1AddrPtr:$l1_ptr); +} + +def TTKernel_StoreToL1Op : TTKernel_Op<"store_to_l1"> { + let summary = "StoreToL1"; + let description = [{ + Store value to L1. + }]; + + let arguments = (ins I32:$value, TTKernel_L1AddrPtr:$l1_ptr, I32:$offset); +} + //===----------------------------------------------------------------------===// // TTKernel Multicast NoC operations //===----------------------------------------------------------------------===// @@ -660,24 +807,4 @@ def TTKernel_GetWritePtrOp : TTKernel_Op<"get_write_ptr"> { let results = (outs I32:$writePtr); } -def TTKernel_CastToL1PtrOp : TTKernel_Op<"reinterpret_cast"> { - let summary = "CastToL1Ptr"; - let description = [{ - Cast specified addr to L1 pointer. - }]; - - let arguments = (ins AnyTypeOf<[I32, TTKernel_L1Addr]>:$addr); - - let results = (outs TTKernel_L1AddrPtr:$l1_ptr); -} - -def TTKernel_StoreToL1Op : TTKernel_Op<"store_to_l1"> { - let summary = "StoreToL1"; - let description = [{ - Store value to L1. - }]; - - let arguments = (ins I32:$value, TTKernel_L1AddrPtr:$l1_ptr, I32:$offset); -} - #endif diff --git a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td index 14a98ef2f..f1b8dc88c 100644 --- a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td +++ b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td @@ -105,6 +105,13 @@ def TTKernel_CB : TTKernel_Type<"CB", "cb"> { }]; } +def TTKernel_Semaphore : TTKernel_Type<"Semaphore", "semaphore"> { + let summary = "TTKernel semaphore"; + let description = "Semaphore type in TTKernel dialect"; + let parameters = (ins "uint32_t":$initial_value); + let assemblyFormat = "`<` $initial_value `>`"; +} + def TTKernel_NocAddr : TTKernel_Type<"NocAddr", "noc_addr"> { let summary = "TTKernel noc address"; let description = "Noc address type in TTKernel dialect"; diff --git a/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp b/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp index 227a7b080..fef4b414e 100644 --- a/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp +++ b/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp @@ -181,6 +181,10 @@ class TTMetalToEmitCFuncArgsRewriter rewriter.startOpModification(op); rewriter.setInsertionPointToStart(&op.getCallableRegion()->front()); for (auto arg : blockArgs) { + // Skip initialization if the argument is not a CBType (SemaphoreType) + if (!mlir::isa(arg.getType())) { + continue; + } auto cb = cast(arg.getType()); // Get opaque type i.e emitc::LValueType auto cbType = getTypeConverter()->convertType(cb); @@ -276,6 +280,12 @@ class TTMetalToEmitCOpaqueRewriter : public OpConversionPattern { template_args.push_back( emitc::OpaqueAttr::get(op.getContext(), reduceDim)); return ArrayAttr::get(op.getContext(), template_args); + } else if constexpr (std::is_same_v) { + SmallVector template_args; + + template_args.push_back( + emitc::OpaqueAttr::get(op.getContext(), "uint32_t")); + return ArrayAttr::get(op.getContext(), template_args); } return ArrayAttr(); } @@ -381,59 +391,68 @@ class ConvertTTKernelToEmitCPass target.addLegalOp(); target.addIllegalDialect(); - patterns - .add, - TTKernelMacroOpToEmitCOpRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter< - ttkernel::NocAsyncReadOnePacketSetStateOp>, - TTMetalToEmitCOpaqueRewriter< - ttkernel::NocAsyncReadOnePacketWithStateOp>, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter< - ttkernel::NocAsyncWriteMulticastOnePacketOp>, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter< - ttkernel::NocAsyncWriteMulticastLoopbackSrcOp>, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter, - TTMetalToEmitCOpaqueRewriter>( - typeConverter, funcOp.getContext()); + patterns.add< + TTMetalToEmitCFuncArgsRewriter, TTMetalToEmitCReturnRewriter, + TTKernelMacroOpToEmitCOpRewriter, + TTKernelMacroOpToEmitCOpRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter< + ttkernel::NocSemaphoreSetMulticastLoopbackOp>, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter< + ttkernel::NocAsyncReadOnePacketSetStateOp>, + TTMetalToEmitCOpaqueRewriter< + ttkernel::NocAsyncReadOnePacketWithStateOp>, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter< + ttkernel::NocAsyncWriteMulticastOnePacketOp>, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter< + ttkernel::NocAsyncWriteMulticastLoopbackSrcOp>, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter>( + typeConverter, funcOp.getContext()); patterns.add>( typeConverter, funcOp.getContext(), "get_noc_addr"); diff --git a/lib/Dialect/TTMetal/IR/TTMetalOps.cpp b/lib/Dialect/TTMetal/IR/TTMetalOps.cpp index 7f78c1afc..d1c7dc88e 100644 --- a/lib/Dialect/TTMetal/IR/TTMetalOps.cpp +++ b/lib/Dialect/TTMetal/IR/TTMetalOps.cpp @@ -89,8 +89,9 @@ ::mlir::LogicalResult DispatchOp::verify() { // Assert block inputs are CBs for (auto ®ion : getRegions()) { for (auto arg : region.getArguments()) { - if (not mlir::isa(arg.getType())) { - return emitOpError("Block inputs must be CBType"); + if (!mlir::isa(arg.getType()) && + !mlir::isa(arg.getType())) { + return emitOpError("Block inputs must be CBType or SemType"); } } } diff --git a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp index e82deaf63..94a2d4d34 100644 --- a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp +++ b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp @@ -220,6 +220,15 @@ cbTypeToFlatbuffer(FlatbufferObjectCache &cache, ttkernel::CBType cbType) { memref, cbType.getPageSize(), cbType.getNumBuffers()); } +std::pair<::tt::target::metal::RuntimeArg, ::flatbuffers::Offset> +toFlatbuffer(FlatbufferObjectCache &cache, ttkernel::SemaphoreType sem) { + auto runtimeArgType = + ::tt::target::metal::RuntimeArg::RuntimeArgSemaphoreAddress; + auto semAddr = ::tt::target::metal::CreateRuntimeArgSemaphoreAddress( + *cache.fbb, sem.getInitialValue()); + return std::make_pair(runtimeArgType, semAddr.Union()); +} + std::pair<::tt::target::metal::HostBuffer, ::flatbuffers::Offset> hostBufferToFlatbuffer(FlatbufferObjectCache &cache, ElementsAttr elementsAttr) { @@ -311,14 +320,27 @@ static std::shared_ptr translateModuleToFlatbuffer( dispatchOp.getCoreRanges()[region.getRegionNumber()]))}; std::vector<::flatbuffers::Offset<::tt::target::CBRef>> cbs; size_t argNumber = 0; + std::vector<::tt::target::metal::RuntimeArg> runtime_args_type; + std::vector<::flatbuffers::Offset> runtime_args; for (auto arg : region.getArguments()) { - auto cbType = mlir::cast(arg.getType()); - auto cbDesc = cache.getOrCreate(cbType, cbTypeToFlatbuffer); - auto tensorRef = - argNumber >= operands.size() ? 0 : operands[argNumber++]; - cbs.push_back( - ::tt::target::CreateCBRef(fbb, cache.global_id++, tensorRef, - cbType.getAddress(), cbDesc)); + if (mlir::isa(arg.getType())) { + auto cbType = mlir::cast(arg.getType()); + auto cbDesc = cache.getOrCreate(cbType, cbTypeToFlatbuffer); + auto tensorRef = + argNumber >= operands.size() ? 0 : operands[argNumber++]; + cbs.push_back( + ::tt::target::CreateCBRef(fbb, cache.global_id++, tensorRef, + cbType.getAddress(), cbDesc)); + } else if (mlir::isa(arg.getType())) { + auto semType = mlir::cast(arg.getType()); + auto [runtime_arg_type, runtime_arg] = + toFlatbuffer(cache, semType); + runtime_args_type.push_back(runtime_arg_type); + runtime_args.push_back(runtime_arg); + } else { + llvm_unreachable( + "Block arguments must be either CBType or SemaphoreType"); + } } std::string &source = cppKernels[region.getRegionNumber()]; @@ -335,7 +357,7 @@ static std::shared_ptr translateModuleToFlatbuffer( ::tt::target::metal::CreateKernelSourceDirect( fbb, source.c_str(), kernelConfigType, kernelConfigUnion) .Union(), - &coreRangeSet, &cbs, nullptr, nullptr, /* TODO rtargs*/ + &coreRangeSet, &cbs, &runtime_args_type, &runtime_args, nullptr /*TODO debug info*/)); } ::flatbuffers::Offset<::tt::target::metal::ProgramDesc> program =