Skip to content

Commit

Permalink
Support tt-metal semaphores
Browse files Browse the repository at this point in the history
  • Loading branch information
pjanevskiTT committed Oct 15, 2024
1 parent 90fbe57 commit 7de191f
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 46 deletions.
130 changes: 130 additions & 0 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,136 @@ def TTKernel_NocAsyncWriteBarrierOp : TTKernel_Op<"noc_async_write_barrier"> {
}];
}

def TTKernel_GetSemaphoreOp : TTKernel_Op<"get_semaphore"> {
let summary = "GetSemaphoreOp";
let description = [{
Get L1 addr of the semaphore with specified semaphore id
}];

let arguments = (ins I32:$semaphore_id);
let results = (outs TTKernel_L1Addr:$sem_addr);
}

def TTKernel_NocSemaphoreIncOp : TTKernel_Op<"noc_semaphore_inc"> {
let summary = "NocSemaphoreInc";
let description = [{
The Tensix core executing this function call initiates an atomic increment
(with 32-bit wrap) of a remote Tensix core L1 memory address. This L1 memory
address is used as a semaphore of size 4 Bytes, as a synchronization
mechanism.
}];

let arguments = (ins TTKernel_NocAddr:$addr, I32:$incr, I32:$noc_id);
}

def TTKernel_NocSemaphoreSetOp : TTKernel_Op<"noc_semaphore_set"> {
let summary = "NocSemaphoreSet";
let description = [{
Sets the value of a local L1 memory address on the Tensix core executing
this function to a specific value. This L1 memory address is used as a
semaphore of size 4 Bytes, as a synchronization mechanism. Also, see
*noc_semaphore_wait*.
}];

let arguments = (ins I32:$sem_addr, I32:$val);
}

def TTKernel_NocSemaphoreWaitOp : TTKernel_Op<"noc_semaphore_wait"> {
let summary = "NocSemaphoreWait";
let description = [{
A blocking call that waits until the value of a local L1 memory address on
the Tensix core executing this function becomes equal to a target value.
This L1 memory address is used as a semaphore of size 4 Bytes, as a
synchronization mechanism. Also, see *noc_semaphore_set*.
}];

let arguments = (ins TTKernel_L1AddrPtr:$sem_addr, I32:$val);
}

def TTKernel_NocSemaphoreWaitMinOp : TTKernel_Op<"noc_semaphore_wait_min"> {
let summary = "NocSemaphoreWaitMin";
let description = [{
A blocking call that waits until the value of a local L1 memory address on
the Tensix core executing this function becomes equal to a target value.
This L1 memory address is used as a semaphore of size 4 Bytes, as a
synchronization mechanism. Also, see *noc_semaphore_set*.
}];

let arguments = (ins TTKernel_L1AddrPtr:$sem_addr, I32:$val);
}

def TTKernel_NocSemaphoreSetMulticastOp : TTKernel_Op<"noc_semaphore_set_multicast"> {
let summary = "NocSemaphoreSetMulticast";
let description = [{
Initiates an asynchronous write from a source address in L1 memory on the
Tensix core executing this function call to a rectangular destination grid.
The destinations are specified using a uint64_t encoding referencing an
on-chip grid of nodes located at NOC coordinate range
(x_start,y_start,x_end,y_end) and a local address created using
*get_noc_multicast_addr* function. The size of data that is sent is 4 Bytes.
This is usually used to set a semaphore value at the destination nodes, as a
way of a synchronization mechanism. The same as *noc_async_write_multicast*
with preset size of 4 Bytes.
With this API, the multicast sender cannot be part of the multicast
destinations. If the multicast sender has to be in the multicast
destinations (i.e. must perform a local L1 write), the other API variant
*noc_semaphore_set_multicast_loopback_src* can be used.
}];

let arguments = (ins I32:$src_local_l1_addr, TTKernel_NocAddr:$dst_noc_addr_multicast, I32:$num_dests, BoolAttr:$linked, BoolAttr:$multicast_path_reserve);
}

def TTKernel_NocSemaphoreSetMulticastLoopbackOp : TTKernel_Op<"noc_semaphore_set_multicast_loopback_src"> {
let summary = "NocSemaphoreSetMulticastLoopback";
let description = [{
Initiates an asynchronous write from a source address in L1 memory on the
Tensix core executing this function call to a rectangular destination grid.
The destinations are specified using a uint64_t encoding referencing an
on-chip grid of nodes located at NOC coordinate range
(x_start,y_start,x_end,y_end) and a local address created using
*get_noc_multicast_addr* function. The size of data that is sent is 4 Bytes.
This is usually used to set a semaphore value at the destination nodes, as a
way of a synchronization mechanism. The same as *noc_async_write_multicast*
with preset size of 4 Bytes.
Note: With this API, sending data only to the source node (when num_dests
is 1) may result in unexpected behaviour. For some parameters, hangs have
been observed. For some other parameters, nothing may happen. Consider using
regular non multicast operations such as *noc_async_write* in this case.
}];

let arguments = (ins I32:$src_local_l1_addr, TTKernel_NocAddr:$dst_noc_addr_multicast, I32:$num_dests, BoolAttr:$linked, BoolAttr:$multicast_path_reserve);
}

//===----------------------------------------------------------------------===//
// TTKernel Compile and runtime arguments operations
//===----------------------------------------------------------------------===//

def TTKernel_GetArgValOp : TTKernel_Op<"get_arg_val"> {
let summary = "Get runtime arg value.";
let description = [{
Get runtime argument value at specified index.
}];

let arguments = (ins I32:$arg_index);

let results = (outs I32:$arg_val);
}

//===----------------------------------------------------------------------===//
// TTKernel Helper functions
//===----------------------------------------------------------------------===//

def TTKernel_CastToL1PtrOp : TTKernel_Op<"reinterpret_cast<volatile tt_l1_ptr uint32_t*>"> {
let summary = "CastToL1Ptr";
let description = [{
Cast specified addr to L1 pointer.
}];

let arguments = (ins TTKernel_L1Addr:$addr);

let results = (outs TTKernel_L1AddrPtr:$l1_ptr);
}

//===----------------------------------------------------------------------===//
// TTKernel Misc operations
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 17 additions & 0 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,28 @@ def TTKernel_CB : TTKernel_Type<"CB", "cb"> {
}];
}

def TTKernel_Semaphore : TTKernel_Type<"Semaphore", "semaphore"> {
let summary = "TTKernel semaphore";
let description = "Semaphore type in TTKernel dialect";
let parameters = (ins "uint32_t":$initial_value);
let assemblyFormat = "`<` $initial_value `>`";
}

def TTKernel_NocAddr : TTKernel_Type<"NocAddr", "noc_addr"> {
let summary = "TTKernel noc address";
let description = "Noc address type in TTKernel dialect";
}

def TTKernel_L1Addr : TTKernel_Type<"L1Addr", "l1_addr"> {
let summary = "TTKernel l1 address";
let description = "L1 address type in TTKernel dialect";
}

def TTKernel_L1AddrPtr : TTKernel_Type<"L1AddrPtr", "l1_addr_ptr"> {
let summary = "TTKernel l1 address pointer";
let description = "L1 pointer address type in TTKernel dialect";
}

def TTKernel_ThreadTypeAttr : EnumAttr<TTKernel_Dialect, TTKernel_ThreadType, "thread"> {
let assemblyFormat = "`<` $value `>`";
}
Expand Down
108 changes: 72 additions & 36 deletions lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ class TTKernelToEmitCTypeConverter : public TypeConverter {
addConversion([ctx](mlir::tt::ttkernel::CBType type) -> Type {
return Builder(ctx).getType<emitc::OpaqueType>("::tt::CB");
});
addConversion([ctx](mlir::tt::ttkernel::L1AddrType type) -> Type {
return Builder(ctx).getI32Type();
});
addConversion([ctx](mlir::tt::ttkernel::L1AddrPtrType type) -> Type {
// TODO: This is a very hacky way to get around the limitation that emitc
// won't emit strings that have * at the back of the string. That is why
// we have space at the end.
return Builder(ctx).getType<emitc::OpaqueType>(
"volatile tt_l1_ptr uint32_t* ");
});
}
};

Expand All @@ -137,6 +147,10 @@ class TTMetalToEmitCFuncArgsRewriter
rewriter.startOpModification(op);
rewriter.setInsertionPointToStart(&op.getCallableRegion()->front());
for (auto arg : blockArgs) {
// Skip initialization if the argument is not a CBType (SemaphoreType)
if (!mlir::isa<ttkernel::CBType>(arg.getType())) {
continue;
}
auto cb = cast<ttkernel::CBType>(arg.getType());
auto cbType = getTypeConverter()->convertType(cb);
auto var = rewriter.create<emitc::VariableOp>(
Expand Down Expand Up @@ -187,6 +201,18 @@ class TTMetalToEmitCOpaqueRewriter : public OpConversionPattern<SourceOp> {
return name;
}

ArrayAttr getTemplateArgs(SourceOp op) const {
if constexpr (std::is_same_v<SourceOp, ttkernel::GetArgValOp>) {
SmallVector<Attribute, 1> template_args;

template_args.push_back(
emitc::OpaqueAttr::get(op.getContext(), "uint32_t"));

return ArrayAttr::get(op.getContext(), template_args);
}
return ArrayAttr();
}

LogicalResult
matchAndRewrite(SourceOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Expand All @@ -199,7 +225,7 @@ class TTMetalToEmitCOpaqueRewriter : public OpConversionPattern<SourceOp> {
resultTypes.push_back(ct);
}
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
op, resultTypes, getOpName(op), nullptr, nullptr,
op, resultTypes, getOpName(op), nullptr, getTemplateArgs(op),
adaptor.getOperands());
return success();
}
Expand Down Expand Up @@ -258,41 +284,51 @@ class ConvertTTKernelToEmitCPass
target.addLegalOp<func::ReturnOp>();
target.addIllegalDialect<ttkernel::TTKernelDialect>();

patterns
.add<TTMetalToEmitCFuncArgsRewriter, TTMetalToEmitCReturnRewriter,
TTMetalToEmitCOpaqueRewriter<ttkernel::BuiltinOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CopyTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::RecipTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::RecipTileOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsAcquireOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsCommitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsWaitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsReleaseOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::PackTileOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBPushBackOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBPopFrontOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBReserveBackOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBWaitFrontOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TilizeInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::UntilizeInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TilizeBlockOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::UntilizeBlockOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::BinaryOpInitCommonOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::AddTilesInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesInitFOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::AddTilesOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::GetNocAddrOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncReadOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncReadBarrierOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncWriteOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncWriteBarrierOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::UnaryOpInitCommonOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CopyTileOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::ExpTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::ExpTileOp>>(
typeConverter, funcOp.getContext());
patterns.add<
TTMetalToEmitCFuncArgsRewriter, TTMetalToEmitCReturnRewriter,
TTMetalToEmitCOpaqueRewriter<ttkernel::BuiltinOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::GetArgValOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CastToL1PtrOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::GetSemaphoreOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocSemaphoreSetOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocSemaphoreWaitMinOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocSemaphoreIncOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocSemaphoreWaitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocSemaphoreSetMulticastOp>,
TTMetalToEmitCOpaqueRewriter<
ttkernel::NocSemaphoreSetMulticastLoopbackOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CopyTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::RecipTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::RecipTileOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsAcquireOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsCommitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsWaitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TileRegsReleaseOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::PackTileOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBPushBackOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBPopFrontOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBReserveBackOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBWaitFrontOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TilizeInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::UntilizeInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TilizeBlockOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::UntilizeBlockOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::BinaryOpInitCommonOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::AddTilesInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesInitFOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::AddTilesOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::MulTilesOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::GetNocAddrOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncReadOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncReadBarrierOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncWriteOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncWriteBarrierOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::UnaryOpInitCommonOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CopyTileOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::ExpTileInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::ExpTileOp>>(
typeConverter, funcOp.getContext());

if (failed(applyFullConversion(funcOp, target, std::move(patterns)))) {
signalPassFailure();
Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/TTMetal/IR/TTMetalOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ ::mlir::LogicalResult DispatchOp::verify() {
// Assert block inputs are CBs
for (auto &region : getRegions()) {
for (auto arg : region.getArguments()) {
if (not mlir::isa<ttkernel::CBType>(arg.getType())) {
return emitOpError("Block inputs must be CBType");
if (not(mlir::isa<ttkernel::CBType>(arg.getType()) ||
mlir::isa<ttkernel::SemaphoreType>(arg.getType()))) {
return emitOpError("Block inputs must be CBType or SemType");
}
}
}
Expand Down
37 changes: 29 additions & 8 deletions lib/Target/TTMetal/TTMetalToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,15 @@ cbTypeToFlatbuffer(FlatbufferObjectCache &cache, ttkernel::CBType cbType) {
memref, cbType.getPageSize(), cbType.getNumBuffers());
}

std::pair<::tt::target::metal::RuntimeArg, ::flatbuffers::Offset<void>>
toFlatbuffer(FlatbufferObjectCache &cache, ttkernel::SemaphoreType sem) {
auto runtimeArgType =
::tt::target::metal::RuntimeArg::RuntimeArgSemaphoreAddress;
auto semAddr = ::tt::target::metal::CreateRuntimeArgSemaphoreAddress(
*cache.fbb, sem.getInitialValue());
return std::make_pair(runtimeArgType, semAddr.Union());
}

std::pair<::tt::target::metal::HostBuffer, ::flatbuffers::Offset<void>>
hostBufferToFlatbuffer(FlatbufferObjectCache &cache,
ElementsAttr elementsAttr) {
Expand Down Expand Up @@ -239,14 +248,26 @@ static std::shared_ptr<void> translateModuleToFlatbuffer(Operation *op) {
toFlatbuffer(mlir::cast<CoreRangeAttr>(
dispatchOp.getCoreRanges()[region.getRegionNumber()]))};
std::vector<::flatbuffers::Offset<::tt::target::CBRef>> cbs;
std::vector<::tt::target::metal::RuntimeArg> runtime_args_type;
std::vector<::flatbuffers::Offset<void>> runtime_args;
for (auto arg : region.getArguments()) {
assert(arg.getArgNumber() < operands.size());
auto cbType = mlir::cast<ttkernel::CBType>(arg.getType());
auto cbDesc = cache.getOrCreate(cbType, cbTypeToFlatbuffer);
auto tensorRef = operands[arg.getArgNumber()];
cbs.push_back(
::tt::target::CreateCBRef(fbb, cache.global_id++, tensorRef,
cbType.getAddress(), cbDesc));
if (mlir::isa<ttkernel::CBType>(arg.getType())) {
auto cbType = mlir::cast<ttkernel::CBType>(arg.getType());
auto cbDesc = cache.getOrCreate(cbType, cbTypeToFlatbuffer);
auto tensorRef = operands[arg.getArgNumber()];
cbs.push_back(
::tt::target::CreateCBRef(fbb, cache.global_id++, tensorRef,
cbType.getAddress(), cbDesc));
} else if (mlir::isa<ttkernel::SemaphoreType>(arg.getType())) {
auto semType = mlir::cast<ttkernel::SemaphoreType>(arg.getType());
auto [runtime_arg_type, runtime_arg] =
toFlatbuffer(cache, semType);
runtime_args_type.push_back(runtime_arg_type);
runtime_args.push_back(runtime_arg);
} else {
llvm_unreachable(
"Block arguments must be either CBType or SemaphoreType");
}
}

std::string &source = cppKernels[region.getRegionNumber()];
Expand All @@ -263,7 +284,7 @@ static std::shared_ptr<void> translateModuleToFlatbuffer(Operation *op) {
::tt::target::metal::CreateKernelSourceDirect(
fbb, source.c_str(), kernelConfigType, kernelConfigUnion)
.Union(),
&coreRangeSet, &cbs, nullptr, nullptr, /* TODO rtargs*/
&coreRangeSet, &cbs, &runtime_args_type, &runtime_args,
nullptr /*TODO debug info*/));
}
::flatbuffers::Offset<::tt::target::metal::ProgramDesc> program =
Expand Down

0 comments on commit 7de191f

Please sign in to comment.