Skip to content

Commit

Permalink
#1190: Added runtime support for doing golden comparison for flatbuf…
Browse files Browse the repository at this point in the history
…fers in ttrt
  • Loading branch information
tapspatel committed Nov 14, 2024
1 parent 78ace9c commit 5a40629
Show file tree
Hide file tree
Showing 16 changed files with 598 additions and 63 deletions.
47 changes: 47 additions & 0 deletions runtime/include/tt/runtime/detail/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
#ifndef TT_RUNTIME_DETAIL_DEBUG_H
#define TT_RUNTIME_DETAIL_DEBUG_H

#include <functional>
#include <optional>
#include <ostream>
#include <utility>

#include "tt/runtime/types.h"

namespace tt::runtime::debug {

struct Env {
Expand Down Expand Up @@ -41,6 +45,49 @@ inline std::ostream &operator<<(std::ostream &os, Env const &env) {
return os;
}

/// Debug hooks the runtime can query for user-supplied callbacks.
///
/// The shape of this struct depends on the build configuration:
///  - TT_RUNTIME_DEBUG == 1: an optional operator callback is stored and
///    handed back by getOperatorCallback(). get() is declared here and defined
///    out of line.
///  - otherwise: the struct is an empty, constexpr-constructible placeholder
///    and getOperatorCallback() always yields std::nullopt.
///
/// NOTE(review): the non-debug get() takes no arguments, so call sites that
/// pass a callback only compile when TT_RUNTIME_DEBUG == 1.
struct Hooks {
  /// Callable type of the operator callback. Both arguments are optional
  /// handles; their exact meaning is defined by the code that invokes the
  /// callback, which is not part of this header.
  using OperatorCallback = std::function<void(std::optional<CallbackContext>,
                                              std::optional<OpContext>)>;

#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
  /// Returns the Hooks instance, optionally supplying the operator callback.
  static Hooks const &
  get(std::optional<OperatorCallback> operatorCallback = std::nullopt);
#else
  /// Returns an empty Hooks value; no callback can be registered in this
  /// configuration.
  constexpr static Hooks get() { return Hooks(); }
#endif

  /// Returns a copy of the stored operator callback, or std::nullopt when no
  /// callback is stored or debug hooks are compiled out.
  std::optional<OperatorCallback> getOperatorCallback() const {
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
    return operatorCallback;
#else
    return std::nullopt;
#endif
  }

private:
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
  // The parameter is taken by value, so move it into the member rather than
  // copying the std::function a second time.
  explicit Hooks(std::optional<OperatorCallback> operatorCallback)
      : operatorCallback(std::move(operatorCallback)) {}

  std::optional<OperatorCallback> operatorCallback;
#else
  constexpr Hooks() = default;
#endif
};

/// Streams a short, human-readable summary of `hooks` to `os`.
///
/// Only the presence of the operator callback is printed (as a bool), never
/// the callback itself. Returns `os` so calls can be chained.
inline std::ostream &operator<<(std::ostream &os, Hooks const &hooks) {
  bool const hasOperatorCallback = hooks.getOperatorCallback().has_value();

  os << "debug::Hooks{\n";
  os << "\t"
     << "operatorCallback: " << hasOperatorCallback << ",\n";
  os << "}";
  return os;
}

} // namespace tt::runtime::debug

#endif // TT_RUNTIME_DETAIL_DEBUG_H
3 changes: 2 additions & 1 deletion runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ void deallocateBuffers(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
std::vector<Tensor> const &outputs,
std::unordered_map<std::string, Tensor> const &goldens = {});

void wait(Event event);

Expand Down
13 changes: 11 additions & 2 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ constexpr std::size_t kL1SmallSize = 1 << 15;

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc();

// Typed overload: builds a StorageType from a pointer to ElementType and an
// element count. Declaration only; the definition is not in this header.
// NOTE(review): presumably the storage refers to `numElements` elements
// starting at `ptr` - confirm ownership/lifetime against the definition.
template <typename StorageType, typename ElementType>
StorageType createStorage(ElementType *ptr, std::uint32_t numElements);

// Type-erased overload: same inputs, but the element type is supplied at
// runtime through `dataType` instead of a template parameter.
// NOTE(review): presumably dispatches on `dataType` to the typed overload
// above - confirm against the definition.
template <typename StorageType>
StorageType createStorage(void *ptr, std::uint32_t numElements,
::tt::target::DataType dataType);

Tensor createTensor(std::shared_ptr<void> data,
std::vector<std::uint32_t> const &shape,
std::vector<std::uint32_t> const &stride,
Expand Down Expand Up @@ -116,14 +123,16 @@ void deallocateBuffers(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
std::vector<Tensor> const &outputs,
std::unordered_map<std::string, Tensor> const &goldens = {});

void wait(Event event);

void runProgram(::ttnn::MeshDevice &meshDevice,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs);
std::vector<::ttnn::Tensor *> const &outputs,
std::unordered_map<std::string, Tensor> const &goldens);

} // namespace tt::runtime::ttnn

Expand Down
3 changes: 2 additions & 1 deletion runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ void closeDevice(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
std::vector<Tensor> const &outputs,
std::unordered_map<std::string, Tensor> const &goldens = {});

void wait(Event event);

Expand Down
19 changes: 17 additions & 2 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,24 @@ struct Event : public detail::RuntimeCheckedObjectImpl {

struct Tensor : public detail::RuntimeCheckedObjectImpl {
std::shared_ptr<void> data;
int volume;
Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {}
DeviceRuntime runtime, int volume)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
volume(volume) {}

std::vector<float> getData();
};

// Runtime-checked handle that is passed to debug callbacks as the program
// context. It inherits its constructors from RuntimeCheckedObjectImpl.
struct CallbackContext : public detail::RuntimeCheckedObjectImpl {
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
// Looks up the golden tensor registered under `loc` in the program context
// behind this handle. Defined out of line.
Tensor getDebugInfoGolden(std::string loc);
};

// Runtime-checked handle that is passed to debug callbacks to identify an
// operation. It inherits its constructors from RuntimeCheckedObjectImpl.
struct OpContext : public detail::RuntimeCheckedObjectImpl {
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
// Returns the output tensor of this operation, resolved through the program
// context `context`. Defined out of line.
Tensor getOpOutputTensor(CallbackContext context);
// Returns the debug string attached to this operation. Defined out of line.
std::string getOpDebugString();
};

} // namespace tt::runtime
Expand Down
12 changes: 12 additions & 0 deletions runtime/lib/common/debug.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ Env const &Env::get(bool loadKernelsFromDisk, bool enableAsyncTTNN) {
return config;
}

#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
/// Returns the shared Hooks instance.
///
/// The instance is a function-local static, so it is constructed on the first
/// call only: the `operatorCallback` passed to that first call is the one that
/// gets stored, and the argument of every later call is ignored.
Hooks const &
Hooks::get(std::optional<std::function<void(std::optional<CallbackContext>,
                                            std::optional<OpContext>)>>
               operatorCallback) {
  static Hooks config(operatorCallback);
  return config;
}
#endif
// No out-of-line definition is needed when debug hooks are disabled: the
// header already defines the zero-argument Hooks::get() inline for that
// configuration. The previous `#else` branch defined an unrelated free
// function `get()` that called the private Hooks constructor.

} // namespace tt::runtime::debug

#endif
171 changes: 168 additions & 3 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#if defined(TT_RUNTIME_ENABLE_TTNN)
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/types.h"
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
Expand Down Expand Up @@ -211,20 +212,21 @@ void closeDevice(Device device) {
Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles,
std::vector<Tensor> const &outputHandles) {
std::vector<Tensor> const &outputHandles,
std::unordered_map<std::string, Tensor> const &goldenHandles) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
outputHandles, goldenHandles);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
outputHandles, goldenHandles);
}
#endif
throw std::runtime_error("runtime is not enabled");
Expand All @@ -245,4 +247,167 @@ void wait(Event event) {
throw std::runtime_error("runtime is not enabled");
}

/// Copies the tensor's host buffer into a vector of floats.
///
/// - TTNN builds: `data` is read as `volume` contiguous floats. The element
///   type is not checked here.
///   NOTE(review): this assumes the buffer really holds float32 values -
///   confirm for tensors created with another data type.
/// - TTMetal-only builds: not implemented yet; logs a warning and returns an
///   empty vector.
///
/// @return The copied values, or an empty vector when there is nothing to
///         copy (null buffer or non-positive volume).
/// @throws std::runtime_error if neither runtime was compiled in.
std::vector<float> Tensor::getData() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  // Placeholder tensors are built with a null buffer and zero volume; bail out
  // before doing any pointer arithmetic on them.
  if (this->data == nullptr || this->volume <= 0) {
    return {};
  }
  float const *first = static_cast<float const *>(this->data.get());
  return std::vector<float>(first, first + this->volume);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
  LOG_WARNING("Getting data for ttmetal tensor is not enabled yet");
  return {};
#else
  // Previously control fell off the end of the function in this configuration.
  throw std::runtime_error("runtime is not enabled");
#endif
}

/// Returns a host-side copy of the output tensor produced by this operation.
///
/// TTNN builds:
///  1. Read the global id of the op's `out` tensor from the flatbuffer
///     operation behind this handle (one case per supported op type).
///  2. Look that id up in the tensor pool of the program context `context`.
///  3. Bring the tensor to host, convert it to row-major layout, and copy its
///     raw bytes into a freshly allocated shared buffer.
///  4. Wrap that buffer in a borrowed-storage ::ttnn::Tensor and return it
///     together with the buffer and the element count.
///
/// A placeholder tensor (null handle, null data, volume 0) is returned, with a
/// warning, for DeallocateOp and when the id is not in the tensor pool.
///
/// TTMetal-only builds: not implemented yet; logs a warning and returns a
/// placeholder tensor.
///
/// NOTE(review): the returned tensor is always labelled Float32, whatever the
/// element type of the source tensor - confirm this is intended.
///
/// @param context Program context whose tensor pool holds the op outputs.
/// @throws std::runtime_error for an op type that is not handled below, or if
///         neither runtime was compiled in.
Tensor OpContext::getOpOutputTensor(CallbackContext context) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  auto *contextPtr =
      static_cast<tt::runtime::ttnn::ProgramContext *>(context.handle.get());
  auto *opContextPtr =
      static_cast<::tt::target::ttnn::Operation *>(this->handle.get());
  const ttnn::ProgramTensorPool &tensorPool = contextPtr->getTensorPool();
  std::int32_t globalId{-1};
  const ::ttnn::Tensor *outPtr = nullptr;

  // Every supported op exposes its result through `out()`; only the accessor
  // used to reach the concrete op table differs.
  switch (opContextPtr->type_type()) {
  case ::tt::target::ttnn::OpType::GetDeviceOp: {
    globalId = opContextPtr->type_as_GetDeviceOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::ToMemoryConfigOp: {
    globalId = opContextPtr->type_as_ToMemoryConfigOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::ToLayoutOp: {
    globalId = opContextPtr->type_as_ToLayoutOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::TypecastOp: {
    globalId = opContextPtr->type_as_TypecastOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::ToDeviceOp: {
    globalId = opContextPtr->type_as_ToDeviceOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::FromDeviceOp: {
    globalId = opContextPtr->type_as_FromDeviceOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::EmptyOp: {
    globalId = opContextPtr->type_as_EmptyOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::FullOp: {
    globalId = opContextPtr->type_as_FullOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::EltwiseOp: {
    globalId = opContextPtr->type_as_EltwiseOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::MatmulOp: {
    globalId = opContextPtr->type_as_MatmulOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::ReductionOp: {
    globalId = opContextPtr->type_as_ReductionOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::EmbeddingOp: {
    globalId = opContextPtr->type_as_EmbeddingOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::SoftmaxOp: {
    globalId = opContextPtr->type_as_SoftmaxOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::TransposeOp: {
    globalId = opContextPtr->type_as_TransposeOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::ConcatOp: {
    globalId = opContextPtr->type_as_ConcatOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::ReshapeOp: {
    globalId = opContextPtr->type_as_ReshapeOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::SliceOp: {
    globalId = opContextPtr->type_as_SliceOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::Conv2dOp: {
    globalId = opContextPtr->type_as_Conv2dOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::MaxPool2dOp: {
    globalId = opContextPtr->type_as_MaxPool2dOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::AllGatherOp: {
    globalId = opContextPtr->type_as_AllGatherOp()->out()->global_id();
    break;
  }
  case ::tt::target::ttnn::OpType::DeallocateOp: {
    LOG_WARNING("getting output tensor for DeallocateOp is not supported");
    return Tensor(nullptr, nullptr, DeviceRuntime::TTNN, 0);
  }
  default: {
    throw std::runtime_error("Unsupported operation type");
  }
  }

  if (tensorPool.contains(globalId)) {
    outPtr = &tensorPool.at(globalId);
  } else {
    LOG_WARNING("Output tensor not found in tensor pool");
    return Tensor(nullptr, nullptr, DeviceRuntime::TTNN, 0);
  }

  // Work on a host, row-major copy so the raw bytes can be memcpy'd out.
  ::ttnn::Tensor hostTensor = ::ttnn::from_device(*outPtr);
  ::ttnn::Tensor outCopy =
      ::ttnn::to_layout(hostTensor, ::ttnn::ROW_MAJOR_LAYOUT, std::nullopt,
                        std::nullopt, static_cast<::ttnn::Device *>(nullptr));

  // Copy into a buffer owned by the returned Tensor, so the result does not
  // point into `outCopy`, which is destroyed when this function returns.
  void *src = ::tt::tt_metal::get_raw_host_data_ptr(outCopy);
  std::uint32_t outCopySize = outCopy.volume() * outCopy.element_size();
  std::shared_ptr<void> data = ::tt::runtime::utils::malloc_shared(outCopySize);
  std::memcpy(data.get(), src, outCopySize);

  auto tensor = std::make_shared<::ttnn::Tensor>(
      ttnn::createStorage<BorrowedStorage>(data.get(), outCopy.volume(),
                                           ::tt::target::DataType::Float32),
      outCopy.shape().value, ::ttnn::DataType::FLOAT32,
      ::ttnn::Layout::ROW_MAJOR);

  return Tensor(std::static_pointer_cast<void>(tensor), data,
                DeviceRuntime::TTNN, outCopy.volume());
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
  LOG_WARNING("Getting device tensor for ttmetal runtime is not enabled yet!");
  return Tensor(nullptr, nullptr, DeviceRuntime::TTMetal, 0);
#else
  // Previously control fell off the end of the function in this configuration.
  throw std::runtime_error("runtime is not enabled");
#endif
}

/// Returns the debug string stored on the flatbuffer operation behind this
/// handle.
///
/// - TTNN builds: returns the op's `debug_info` text. If the handle is null or
///   the op carries no debug info, a warning is logged and an empty string is
///   returned instead of dereferencing a null pointer.
/// - TTMetal-only builds: not implemented; logs a warning and returns an
///   empty string.
///
/// @throws std::runtime_error if neither runtime was compiled in.
std::string OpContext::getOpDebugString() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  auto *opContextPtr =
      static_cast<::tt::target::ttnn::Operation *>(this->handle.get());
  if (opContextPtr == nullptr || opContextPtr->debug_info() == nullptr) {
    LOG_WARNING("Op debug info not found!");
    return std::string();
  }
  return std::string(opContextPtr->debug_info()->c_str());
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
  LOG_WARNING("Getting op debug string for ttmetal runtime is not enabled yet!");
  return std::string();
#else
  throw std::runtime_error("runtime is not enabled");
#endif
}

/// Returns the golden tensor registered under `loc` in the program context
/// behind this handle.
///
/// - TTNN builds: looks `loc` up in the context's golden map. When there is
///   no entry, a warning is logged and a placeholder tensor (null handle,
///   null data, volume 0) is returned.
/// - TTMetal-only builds: not implemented; logs a warning and returns a
///   placeholder tensor.
///
/// @param loc Key of the golden tensor in the golden map.
/// @throws std::runtime_error if neither runtime was compiled in.
Tensor CallbackContext::getDebugInfoGolden(std::string loc) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  auto *contextPtr =
      static_cast<tt::runtime::ttnn::ProgramContext *>(this->handle.get());
  // Bind the map once (a const reference is valid whether the getter returns
  // by reference or by value) and search it a single time.
  auto const &goldenMap = contextPtr->getGoldenMap();
  if (auto it = goldenMap.find(loc); it != goldenMap.end()) {
    return it->second;
  }

  LOG_WARNING("Golden tensor not found!");
  return Tensor(nullptr, nullptr, DeviceRuntime::TTNN, 0);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
  LOG_WARNING("Getting golden tensor for ttmetal runtime is not enabled yet!");
  return Tensor(nullptr, nullptr, DeviceRuntime::TTMetal, 0);
#else
  throw std::runtime_error("runtime is not enabled");
#endif
}

} // namespace tt::runtime
9 changes: 6 additions & 3 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ Tensor createTensor(std::shared_ptr<void> data,
desc.itemsize = itemsize;
desc.dataType = dataType;
std::shared_ptr<MetalTensor> tensor = std::make_shared<MetalTensor>(desc);
return Tensor(static_pointer_cast<void>(tensor), data,
DeviceRuntime::TTMetal);
std::uint32_t volume = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<std::uint32_t>());
return Tensor(static_pointer_cast<void>(tensor), data, DeviceRuntime::TTMetal,
volume);
}

tt::target::DataType getTensorDataType(Tensor tensor) {
Expand Down Expand Up @@ -169,7 +171,8 @@ Events maybeCopyHostOutputs(::tt::tt_metal::Device *device,
Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles,
std::vector<Tensor> const &outputHandles) {
std::vector<Tensor> const &outputHandles,
std::unordered_map<std::string, Tensor> const &goldenHandles) {
::tt::target::metal::TTMetalBinary const &fbb = *getBinary(executableHandle);
::tt::target::metal::Program const *program =
fbb.programs()->Get(programIndex);
Expand Down
Loading

0 comments on commit 5a40629

Please sign in to comment.