Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Runtime support for RuntimeArgs on KernelDesc (Tensor Addr + Semaphores) #650

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/ttmlir/Target/TTMetal/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ enum BinaryType : ushort {
ERISC = 5,
}

// Kind of core a runtime resource is created for. Referenced by
// RuntimeArgSemaphoreAddress.core_type.
enum CoreType : ushort {
WORKER = 0,
ETH = 1,
}

table KernelSource {
source_type: SourceType;
source: string;
Expand All @@ -34,10 +39,25 @@ union Kernel {
KernelBinary,
}

// Runtime argument that stands for the address of a tensor's buffer.
table RuntimeArgTensorAddress {
// Index into the operand list of the command that launches the kernel; the
// runtime resolves that operand to a buffer and passes the buffer address.
operand_idx: uint32;
}

// Runtime argument that stands for a semaphore created by the runtime for the
// kernel.
table RuntimeArgSemaphoreAddress {
// Value the semaphore is created with.
initial_value: uint32;
// Core type the semaphore is created for.
core_type: CoreType;
}

// One kernel runtime argument: either a tensor buffer address or a semaphore.
union RuntimeArg {
RuntimeArgTensorAddress,
RuntimeArgSemaphoreAddress,
}

// Description of a single kernel within a program.
table KernelDesc {
// Kernel payload (see the Kernel union).
kernel: Kernel;
// Core ranges associated with the kernel.
core_range_set: [Dim2dRange];
// Circular buffers referenced by the kernel.
cbs: [CBRef];
// Runtime arguments, in the order they are handed to the kernel.
runtime_args: [RuntimeArg];
// Optional debug information string.
debug_info: string;
}

Expand Down
3 changes: 2 additions & 1 deletion lib/Target/TTMetal/TTMetalToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ static std::shared_ptr<void> translateModuleToFlatbuffer(Operation *op) {
::tt::target::metal::CreateKernelSourceDirect(
fbb, toFlatbuffer(threadType), source.c_str())
.Union(),
&coreRangeSet, &cbs, nullptr /*TODO debug info*/));
&coreRangeSet, &cbs, nullptr, nullptr, /* TODO rtargs*/
nullptr /*TODO debug info*/));
}
::flatbuffers::Offset<::tt::target::metal::ProgramDesc> program =
::tt::target::metal::CreateProgramDescDirect(fbb, &kernels);
Expand Down
67 changes: 66 additions & 1 deletion runtime/lib/ttmetal/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,17 @@ static ::tt::DataFormat toDataFormat(::tt::target::DataType dataType) {
}
}

// Translate the Flatbuffer schema's CoreType into the soc_descriptor CoreType.
// Throws std::runtime_error for any value that is neither WORKER nor ETH.
static CoreType toCoreType(::tt::target::metal::CoreType coreType) {
  // Alias the schema enum so it cannot be confused with the unqualified
  // CoreType this function returns.
  using FbCoreType = ::tt::target::metal::CoreType;

  if (coreType == FbCoreType::WORKER) {
    return CoreType::WORKER;
  }
  if (coreType == FbCoreType::ETH) {
    return CoreType::ETH;
  }
  throw std::runtime_error("Unsupported core type");
}

static ::tt::tt_metal::CircularBufferConfig createCircularBufferConfig(
::tt::target::CBRef const *cbRef,
std::unordered_map<std::uint32_t,
Expand All @@ -192,6 +203,57 @@ static ::tt::tt_metal::CircularBufferConfig createCircularBufferConfig(
.set_page_size(cbRef->desc()->port(), cbRef->desc()->page_size());
}

// Build the runtime-argument vector for one kernel and register it with Metal.
//
// Each entry of kernelDesc->runtime_args() contributes exactly one uint32_t, in
// order:
//   - RuntimeArgTensorAddress: the address of the buffer found in `buffers`
//     under the global_id of operand `operand_idx`.
//   - RuntimeArgSemaphoreAddress: the value returned by CreateSemaphore for
//     `coreRange` with the requested initial value and core type.
//
// Does nothing when the kernel carries no runtime args.
//
// Throws std::runtime_error when the flatbuffer is malformed (type/value
// vector length mismatch, operand index out of range, missing operand list,
// unknown or NONE union tag). Throws std::out_of_range when the referenced
// operand has no entry in `buffers`.
static void processRuntimeArgs(
    ::tt::tt_metal::Program &program,
    ::tt::target::metal::KernelDesc const *kernelDesc,
    ::tt::tt_metal::KernelHandle &handle, CoreRangeSet &coreRange,
    const ::flatbuffers::Vector<::flatbuffers::Offset<tt::target::TensorRef>>
        *operands,
    std::unordered_map<std::uint32_t,
                       std::shared_ptr<::tt::tt_metal::Buffer>> const
        &buffers) {

  using SemaphoreAddr = ::tt::target::metal::RuntimeArgSemaphoreAddress;
  using TensorAddr = ::tt::target::metal::RuntimeArgTensorAddress;

  const auto *rt_args_types = kernelDesc->runtime_args_type();
  const auto *rt_args = kernelDesc->runtime_args();

  if (rt_args == nullptr || rt_args_types == nullptr || rt_args->size() == 0 ||
      rt_args_types->size() == 0) {
    return;
  }

  // Checked at runtime rather than with assert(): the flatbuffer is external
  // input, and an assert disappears in NDEBUG builds, leaving an out-of-bounds
  // read below.
  if (rt_args_types->size() != rt_args->size()) {
    throw std::runtime_error(
        "Runtime arg type/value vectors have different sizes");
  }

  std::vector<uint32_t> rt_args_vec;
  rt_args_vec.reserve(rt_args->size());

  for (size_t i = 0; i < rt_args->size(); i++) {
    // Every handled case ends in `continue`; anything that leaves the switch
    // (NONE, or a tag value outside the known enumerators) is rejected below
    // instead of being skipped, which would shift later argument positions.
    switch (rt_args_types->Get(i)) {
    case ::tt::target::metal::RuntimeArg::RuntimeArgTensorAddress: {
      const auto *rt_arg = static_cast<const TensorAddr *>(rt_args->Get(i));
      if (operands == nullptr || rt_arg->operand_idx() >= operands->size()) {
        throw std::runtime_error(
            "Runtime arg tensor address refers to an invalid operand");
      }
      uint32_t global_id = operands->Get(rt_arg->operand_idx())->global_id();
      uint32_t addr = buffers.at(global_id)->address();
      rt_args_vec.push_back(addr);
      continue;
    }
    case ::tt::target::metal::RuntimeArg::RuntimeArgSemaphoreAddress: {
      const auto *rt_arg = static_cast<const SemaphoreAddr *>(rt_args->Get(i));
      auto addr = ::tt::tt_metal::CreateSemaphore(
          program, coreRange, rt_arg->initial_value(),
          toCoreType(rt_arg->core_type()));
      rt_args_vec.push_back(addr);
      continue;
    }
    case ::tt::target::metal::RuntimeArg::NONE:
      break;
    }
    throw std::runtime_error("Unsupported runtime arg type");
  }

  ::tt::tt_metal::SetRuntimeArgs(program, handle, coreRange, rt_args_vec);
}

void CQExecutor::execute(
::tt::target::metal::EnqueueProgramCommand const *command) {
static int gKernelId = 0;
Expand All @@ -214,13 +276,16 @@ void CQExecutor::execute(
createKernelConfig(kernelSource);
::tt::tt_metal::KernelHandle handle =
::tt::tt_metal::CreateKernel(program, fileName, coreRange, config);
(void)handle; // only needed for runtime args, which aren't supported yet

for (::tt::target::CBRef const *cbRef : *kernelDesc->cbs()) {
::tt::tt_metal::CircularBufferConfig config =
createCircularBufferConfig(cbRef, buffers);
::tt::tt_metal::CreateCircularBuffer(program, coreRange, config);
}

// Process Kernel's runtime args based on variant and call metal APIs.
processRuntimeArgs(program, kernelDesc, handle, coreRange,
command->operands(), buffers);
}

constexpr bool blocking = false;
Expand Down
Loading