Skip to content

Commit

Permalink
Runtime support for RuntimeArgs on KernelDesc (Tensor Addr + Semaphor…
Browse files Browse the repository at this point in the history
…es) #361/#267 (#650)

- Two types of runtime_args are supported today: RuntimeArgTensorAddress and
   RuntimeArgSemaphoreAddress. The runtime resolves the tensor address or
   creates the semaphore via the Metal API CreateSemaphore(), then passes the
   resolved values to the kernel via the Metal API SetRuntimeArgs().

 - Compiler support is not included yet, so this was only lightly tested by
   injecting dummy values of each type into the flatbuffer via
   TTMetalToFlatbuffer.cpp changes, and by exercising the assertions. For
   now, the .ttm keeps nullptr for runtime_args and runtime_args_type.
  • Loading branch information
kmabeeTT authored Sep 9, 2024
1 parent 97fdf8d commit 6bec67a
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 2 deletions.
20 changes: 20 additions & 0 deletions include/ttmlir/Target/TTMetal/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ enum BinaryType : ushort {
ERISC = 5,
}

// Kind of core a semaphore runtime arg is created for (see the core_type
// field of RuntimeArgSemaphoreAddress). The runtime translates this value to
// the soc_descriptor CoreType of the same name before calling into Metal.
enum CoreType : ushort {
WORKER = 0,
ETH = 1,
}

table KernelSource {
source_type: SourceType;
source: string;
Expand All @@ -34,10 +39,25 @@ union Kernel {
KernelBinary,
}

// Runtime arg that the runtime replaces with the device address of a tensor.
table RuntimeArgTensorAddress {
// Index into the operand list of the enqueue-program command; the runtime
// looks up that operand's buffer and uses its address as the arg value.
operand_idx: uint32;
}

// Runtime arg that the runtime replaces with the value returned by Metal's
// CreateSemaphore() for a semaphore created on the kernel's core range.
table RuntimeArgSemaphoreAddress {
// Initial value forwarded to CreateSemaphore().
initial_value: uint32;
// Core type forwarded to CreateSemaphore().
core_type: CoreType;
}

// One kernel runtime argument. The runtime dispatches on the union type to
// decide how the 32-bit argument value is produced; the implicit NONE member
// is rejected by the runtime as unsupported.
union RuntimeArg {
RuntimeArgTensorAddress,
RuntimeArgSemaphoreAddress,
}

// Description of a single kernel inside a program.
table KernelDesc {
kernel: Kernel;
core_range_set: [Dim2dRange];
// Circular buffers the runtime creates for this kernel.
cbs: [CBRef];
// Vector of unions: the generated C++ code exposes it as two parallel
// vectors, runtime_args() and runtime_args_type(), which the runtime expects
// to have the same length. May be absent (the compiler currently writes
// nullptr), in which case no runtime args are set.
runtime_args: [RuntimeArg];
// NOTE(review): currently written as nullptr by TTMetalToFlatbuffer.cpp.
debug_info: string;
}

Expand Down
3 changes: 2 additions & 1 deletion lib/Target/TTMetal/TTMetalToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ static std::shared_ptr<void> translateModuleToFlatbuffer(Operation *op) {
::tt::target::metal::CreateKernelSourceDirect(
fbb, toFlatbuffer(threadType), source.c_str())
.Union(),
&coreRangeSet, &cbs, nullptr /*TODO debug info*/));
&coreRangeSet, &cbs, nullptr, nullptr, /* TODO rtargs*/
nullptr /*TODO debug info*/));
}
::flatbuffers::Offset<::tt::target::metal::ProgramDesc> program =
::tt::target::metal::CreateProgramDescDirect(fbb, &kernels);
Expand Down
67 changes: 66 additions & 1 deletion runtime/lib/ttmetal/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,17 @@ static ::tt::DataFormat toDataFormat(::tt::target::DataType dataType) {
}
}

// Map the CoreType stored in the flatbuffer onto the soc_descriptor CoreType
// of the same name. Throws std::runtime_error for any value that has no
// counterpart.
static CoreType toCoreType(::tt::target::metal::CoreType coreType) {
  using FbCoreType = ::tt::target::metal::CoreType;

  if (coreType == FbCoreType::WORKER) {
    return CoreType::WORKER;
  }
  if (coreType == FbCoreType::ETH) {
    return CoreType::ETH;
  }
  throw std::runtime_error("Unsupported core type");
}

static ::tt::tt_metal::CircularBufferConfig createCircularBufferConfig(
::tt::target::CBRef const *cbRef,
std::unordered_map<std::uint32_t,
Expand All @@ -192,6 +203,57 @@ static ::tt::tt_metal::CircularBufferConfig createCircularBufferConfig(
.set_page_size(cbRef->desc()->port(), cbRef->desc()->page_size());
}

// Resolve a kernel's runtime args and hand them to Metal.
//
// Each entry of kernelDesc->runtime_args() is turned into one 32-bit value:
//   - RuntimeArgTensorAddress:    address of the buffer that backs operand
//                                 `operand_idx` of the command.
//   - RuntimeArgSemaphoreAddress: value returned by CreateSemaphore() for a
//                                 semaphore created on `coreRange`.
// The resulting values are passed, in order, to SetRuntimeArgs() for `handle`
// on `coreRange`. Does nothing when the kernel has no runtime args.
//
// Parameters:
//   program    - Metal program that owns the kernel; semaphores are created
//                in it.
//   kernelDesc - flatbuffer description of the kernel (must be non-null).
//   handle     - handle of the already created kernel.
//   coreRange  - cores the kernel runs on.
//   operands   - operands of the enqueue-program command; may be null only if
//                no tensor-address arg is present.
//   buffers    - map from tensor global id to its device buffer.
//
// Throws std::runtime_error on malformed input (mismatched arg/type counts,
// operand index out of range, unknown arg type) and std::out_of_range (from
// unordered_map::at) if an operand has no entry in `buffers`. These checks are
// real runtime checks rather than asserts because the flatbuffer is external
// input and asserts are compiled out of release builds.
static void processRuntimeArgs(
    ::tt::tt_metal::Program &program,
    ::tt::target::metal::KernelDesc const *kernelDesc,
    ::tt::tt_metal::KernelHandle &handle, CoreRangeSet &coreRange,
    const ::flatbuffers::Vector<::flatbuffers::Offset<tt::target::TensorRef>>
        *operands,
    std::unordered_map<std::uint32_t,
                       std::shared_ptr<::tt::tt_metal::Buffer>> const
        &buffers) {

  using SemaphoreAddr = ::tt::target::metal::RuntimeArgSemaphoreAddress;
  using TensorAddr = ::tt::target::metal::RuntimeArgTensorAddress;

  // A vector of unions is exposed as two parallel vectors: types and values.
  const auto *rt_args_types = kernelDesc->runtime_args_type();
  const auto *rt_args = kernelDesc->runtime_args();

  if (rt_args == nullptr || rt_args_types == nullptr || rt_args->size() == 0 ||
      rt_args_types->size() == 0) {
    return;
  }

  if (rt_args_types->size() != rt_args->size()) {
    throw std::runtime_error(
        "Mismatched number of runtime args and runtime arg types");
  }

  const ::flatbuffers::uoffset_t num_args = rt_args->size();
  std::vector<uint32_t> rt_args_vec;
  rt_args_vec.reserve(num_args);

  for (::flatbuffers::uoffset_t i = 0; i < num_args; ++i) {
    // Every supported case ends in `continue`; anything that leaves the switch
    // (NONE, or a value this runtime does not know about) is an error.
    // Silently skipping it would shift the position of all following args.
    switch (rt_args_types->Get(i)) {
    case ::tt::target::metal::RuntimeArg::RuntimeArgTensorAddress: {
      const auto *rt_arg = static_cast<const TensorAddr *>(rt_args->Get(i));
      if (operands == nullptr || rt_arg->operand_idx() >= operands->size()) {
        throw std::runtime_error("invalid operand");
      }
      uint32_t global_id = operands->Get(rt_arg->operand_idx())->global_id();
      uint32_t addr = buffers.at(global_id)->address();
      rt_args_vec.push_back(addr);
      continue;
    }
    case ::tt::target::metal::RuntimeArg::RuntimeArgSemaphoreAddress: {
      const auto *rt_arg = static_cast<const SemaphoreAddr *>(rt_args->Get(i));
      auto addr = ::tt::tt_metal::CreateSemaphore(
          program, coreRange, rt_arg->initial_value(),
          toCoreType(rt_arg->core_type()));
      rt_args_vec.push_back(addr);
      continue;
    }
    case ::tt::target::metal::RuntimeArg::NONE:
      break;
    }
    throw std::runtime_error("Unsupported runtime arg type");
  }

  ::tt::tt_metal::SetRuntimeArgs(program, handle, coreRange, rt_args_vec);
}

void CQExecutor::execute(
::tt::target::metal::EnqueueProgramCommand const *command) {
static int gKernelId = 0;
Expand All @@ -214,13 +276,16 @@ void CQExecutor::execute(
createKernelConfig(kernelSource);
::tt::tt_metal::KernelHandle handle =
::tt::tt_metal::CreateKernel(program, fileName, coreRange, config);
(void)handle; // only needed for runtime args, which aren't supported yet

for (::tt::target::CBRef const *cbRef : *kernelDesc->cbs()) {
::tt::tt_metal::CircularBufferConfig config =
createCircularBufferConfig(cbRef, buffers);
::tt::tt_metal::CreateCircularBuffer(program, coreRange, config);
}

// Process Kernel's runtime args based on variant and call metal APIs.
processRuntimeArgs(program, kernelDesc, handle, coreRange,
command->operands(), buffers);
}

constexpr bool blocking = false;
Expand Down

0 comments on commit 6bec67a

Please sign in to comment.