From 6bec67a393697313a8e4dda99f856c86847bf455 Mon Sep 17 00:00:00 2001 From: Kyle Mabee <118925087+kmabeeTT@users.noreply.github.com> Date: Mon, 9 Sep 2024 16:45:58 -0400 Subject: [PATCH] Runtime support for RuntimeArgs on KernelDesc (Tensor Addr + Semaphores) #361/#267 (#650) - 2 types of runtime_args today: RuntimeArgTensorAddress, RuntimeArgSemaphoreAddress where runtime will resolve tensor address or semaphore id, and call metal APIs SetRuntimeArgs() and CreateSemaphore() - Compiler support not included, so did some light testing with various dummy values for each type in flatbuffer via TTMetalToFlatbuffer.cpp changes, and tested assertions. For now, keep nullptr in .ttm for runtime_args and runtime_args_types --- include/ttmlir/Target/TTMetal/program.fbs | 20 +++++++ lib/Target/TTMetal/TTMetalToFlatbuffer.cpp | 3 +- runtime/lib/ttmetal/command_queue.cpp | 67 +++++++++++++++++++++- 3 files changed, 88 insertions(+), 2 deletions(-) diff --git a/include/ttmlir/Target/TTMetal/program.fbs b/include/ttmlir/Target/TTMetal/program.fbs index 9b81d5285..e69f942ec 100644 --- a/include/ttmlir/Target/TTMetal/program.fbs +++ b/include/ttmlir/Target/TTMetal/program.fbs @@ -18,6 +18,11 @@ enum BinaryType : ushort { ERISC = 5, } +enum CoreType : ushort { + WORKER = 0, + ETH = 1, +} + table KernelSource { source_type: SourceType; source: string; @@ -34,10 +39,25 @@ union Kernel { KernelBinary, } +table RuntimeArgTensorAddress { + operand_idx: uint32; +} + +table RuntimeArgSemaphoreAddress { + initial_value: uint32; + core_type: CoreType; +} + +union RuntimeArg { + RuntimeArgTensorAddress, + RuntimeArgSemaphoreAddress, +} + table KernelDesc { kernel: Kernel; core_range_set: [Dim2dRange]; cbs: [CBRef]; + runtime_args: [RuntimeArg]; debug_info: string; } diff --git a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp index 28af2abbc..3c31cb2c4 100644 --- a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp +++ b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp @@ -189,7 +189,8 @@ static std::shared_ptr translateModuleToFlatbuffer(Operation *op) { ::tt::target::metal::CreateKernelSourceDirect( fbb, toFlatbuffer(threadType), source.c_str()) .Union(), - &coreRangeSet, &cbs, nullptr /*TODO debug info*/)); + &coreRangeSet, &cbs, nullptr, nullptr, /* TODO rtargs*/ + nullptr /*TODO debug info*/)); } ::flatbuffers::Offset<::tt::target::metal::ProgramDesc> program = ::tt::target::metal::CreateProgramDescDirect(fbb, &kernels); diff --git a/runtime/lib/ttmetal/command_queue.cpp b/runtime/lib/ttmetal/command_queue.cpp index d93e012c7..399e4a719 100644 --- a/runtime/lib/ttmetal/command_queue.cpp +++ b/runtime/lib/ttmetal/command_queue.cpp @@ -176,6 +176,17 @@ static ::tt::DataFormat toDataFormat(::tt::target::DataType dataType) { } } +// Convert from Flatbuffer CoreType to soc_descriptor CoreType. +static CoreType toCoreType(::tt::target::metal::CoreType coreType) { + switch (coreType) { + case ::tt::target::metal::CoreType::WORKER: + return CoreType::WORKER; + case ::tt::target::metal::CoreType::ETH: + return CoreType::ETH; + } + throw std::runtime_error("Unsupported core type"); +} + static ::tt::tt_metal::CircularBufferConfig createCircularBufferConfig( ::tt::target::CBRef const *cbRef, std::unordered_mapdesc()->port(), cbRef->desc()->page_size()); } +// Process various types of runtime args if present and call Metal APIs. +static void processRuntimeArgs( + ::tt::tt_metal::Program &program, + ::tt::target::metal::KernelDesc const *kernelDesc, + ::tt::tt_metal::KernelHandle &handle, CoreRangeSet &coreRange, + const ::flatbuffers::Vector<::flatbuffers::Offset> + *operands, + std::unordered_map> const + &buffers) { + + using SemaphoreAddr = ::tt::target::metal::RuntimeArgSemaphoreAddress; + using TensorAddr = ::tt::target::metal::RuntimeArgTensorAddress; + + const auto *rt_args_types = kernelDesc->runtime_args_type(); + const auto *rt_args = kernelDesc->runtime_args(); + + if (rt_args == nullptr || rt_args_types == nullptr || rt_args->size() == 0 || + rt_args_types->size() == 0) { + return; + } + + assert(rt_args_types->size() == rt_args->size()); + std::vector rt_args_vec; + + for (size_t i = 0; i < rt_args->size(); i++) { + switch (rt_args_types->Get(i)) { + case ::tt::target::metal::RuntimeArg::RuntimeArgTensorAddress: { + const auto *rt_arg = static_cast(rt_args->Get(i)); + assert(rt_arg->operand_idx() < operands->size() && "invalid operand"); + uint32_t global_id = operands->Get(rt_arg->operand_idx())->global_id(); + uint32_t addr = buffers.at(global_id)->address(); + rt_args_vec.push_back(addr); + break; + } + case ::tt::target::metal::RuntimeArg::RuntimeArgSemaphoreAddress: { + const auto *rt_arg = static_cast(rt_args->Get(i)); + auto addr = ::tt::tt_metal::CreateSemaphore( + program, coreRange, rt_arg->initial_value(), + toCoreType(rt_arg->core_type())); + rt_args_vec.push_back(addr); + break; + } + case ::tt::target::metal::RuntimeArg::NONE: + throw std::runtime_error("Unsupported runtime arg type"); + } + } + + ::tt::tt_metal::SetRuntimeArgs(program, handle, coreRange, rt_args_vec); +} + void CQExecutor::execute( ::tt::target::metal::EnqueueProgramCommand const *command) { static int gKernelId = 0; @@ -214,13 +276,16 @@ void CQExecutor::execute( createKernelConfig(kernelSource); ::tt::tt_metal::KernelHandle handle = ::tt::tt_metal::CreateKernel(program, fileName, coreRange, config); - (void)handle; // only needed for runtime args, which aren't supported yet for (::tt::target::CBRef const *cbRef : *kernelDesc->cbs()) { ::tt::tt_metal::CircularBufferConfig config = createCircularBufferConfig(cbRef, buffers); ::tt::tt_metal::CreateCircularBuffer(program, coreRange, config); } + + // Process Kernel's runtime args based on variant and call metal APIs. + processRuntimeArgs(program, kernelDesc, handle, coreRange, + command->operands(), buffers); } constexpr bool blocking = false;