From 5a4062968b215b4e24c6c03e71b14a9adac1ab8d Mon Sep 17 00:00:00 2001 From: Tapasvi Patel Date: Fri, 8 Nov 2024 18:46:37 +0000 Subject: [PATCH] #1190: Added runtime support for doing golden comparision for flatbuffers in ttrt --- runtime/include/tt/runtime/detail/debug.h | 47 +++++ runtime/include/tt/runtime/detail/ttmetal.h | 3 +- runtime/include/tt/runtime/detail/ttnn.h | 13 +- runtime/include/tt/runtime/runtime.h | 3 +- runtime/include/tt/runtime/types.h | 19 +- runtime/lib/common/debug.cpp | 12 ++ runtime/lib/runtime.cpp | 171 +++++++++++++++++- runtime/lib/ttmetal/runtime.cpp | 9 +- .../lib/ttnn/include/tt/runtime/ttnn/types.h | 19 +- runtime/lib/ttnn/program.cpp | 23 ++- runtime/lib/ttnn/runtime.cpp | 19 +- runtime/tools/python/ttrt/common/golden.py | 157 ++++++++++++++++ runtime/tools/python/ttrt/common/run.py | 32 ++++ runtime/tools/python/ttrt/common/util.py | 100 ++++++---- runtime/tools/python/ttrt/runtime/__init__.py | 1 + runtime/tools/python/ttrt/runtime/module.cpp | 33 +++- 16 files changed, 598 insertions(+), 63 deletions(-) create mode 100644 runtime/tools/python/ttrt/common/golden.py diff --git a/runtime/include/tt/runtime/detail/debug.h b/runtime/include/tt/runtime/detail/debug.h index c5d84c4d98..d7fdb96917 100644 --- a/runtime/include/tt/runtime/detail/debug.h +++ b/runtime/include/tt/runtime/detail/debug.h @@ -5,8 +5,12 @@ #ifndef TT_RUNTIME_DETAIL_DEBUG_H #define TT_RUNTIME_DETAIL_DEBUG_H +#include +#include #include +#include "tt/runtime/types.h" + namespace tt::runtime::debug { struct Env { @@ -41,6 +45,49 @@ inline std::ostream &operator<<(std::ostream &os, Env const &env) { return os; } +struct Hooks { +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 + static Hooks const & + get(std::optional, + std::optional)>> + operatorCallback = std::nullopt); +#else + constexpr static Hooks get() { return Hooks(); } +#endif + + std::optional, + std::optional)>> + getOperatorCallback() const { +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 + return operatorCallback; +#else + return std::nullopt; +#endif + } + +private: +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 + Hooks(std::optional, + std::optional)>> + operatorCallback) + : operatorCallback(operatorCallback) {} + + std::optional, + std::optional)>> + operatorCallback; +#else + constexpr Hooks() = default; +#endif +}; + +inline std::ostream &operator<<(std::ostream &os, Hooks const &hooks) { + os << "debug::Hooks{\n" + << "\t" + << "operatorCallback: " << bool(hooks.getOperatorCallback()) << ",\n" + << "}"; + return os; +} + } // namespace tt::runtime::debug #endif // TT_RUNTIME_DETAIL_DEBUG_H diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index 01bb9c86eb..4b207adc0c 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -75,7 +75,8 @@ void deallocateBuffers(Device device); Event submit(Device device, Binary executable, std::uint32_t programIndex, std::vector const &inputs, - std::vector const &outputs); + std::vector const &outputs, + std::unordered_map const &goldens = {}); void wait(Event event); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index c027d01585..e51c652ae3 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -80,6 +80,13 @@ constexpr std::size_t kL1SmallSize = 1 << 15; std::pair getCurrentSystemDesc(); +template +StorageType createStorage(ElementType *ptr, std::uint32_t numElements); + +template +StorageType createStorage(void *ptr, std::uint32_t numElements, + ::tt::target::DataType dataType); + Tensor createTensor(std::shared_ptr data, std::vector const &shape, std::vector const &stride, @@ -116,14 +123,16 @@ void deallocateBuffers(Device device); Event submit(Device device, Binary executable, std::uint32_t programIndex, std::vector const &inputs, - std::vector const &outputs); + std::vector const &outputs, + std::unordered_map const &goldens = {}); void wait(Event event); void runProgram(::ttnn::MeshDevice &meshDevice, ::tt::target::ttnn::Program const *program, std::vector<::ttnn::Tensor *> const &inputs, - std::vector<::ttnn::Tensor *> const &outputs); + std::vector<::ttnn::Tensor *> const &outputs, + std::unordered_map const &goldens); } // namespace tt::runtime::ttnn diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index a070f2f0f5..b11dd7075a 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -65,7 +65,8 @@ void closeDevice(Device device); Event submit(Device device, Binary executable, std::uint32_t programIndex, std::vector const &inputs, - std::vector const &outputs); + std::vector const &outputs, + std::unordered_map const &goldens = {}); void wait(Event event); diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index 330fb91965..92f70b4891 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -120,9 +120,24 @@ struct Event : public detail::RuntimeCheckedObjectImpl { struct Tensor : public detail::RuntimeCheckedObjectImpl { std::shared_ptr data; + int volume; Tensor(std::shared_ptr handle, std::shared_ptr data, - DeviceRuntime runtime) - : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {} + DeviceRuntime runtime, int volume) + : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), + volume(volume) {} + + std::vector getData(); +}; + +struct CallbackContext : public detail::RuntimeCheckedObjectImpl { + using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl; + Tensor getDebugInfoGolden(std::string loc); +}; + +struct OpContext : public detail::RuntimeCheckedObjectImpl { + using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl; + Tensor getOpOutputTensor(CallbackContext context); + std::string getOpDebugString(); }; } // namespace tt::runtime diff --git a/runtime/lib/common/debug.cpp b/runtime/lib/common/debug.cpp index f075177642..669cc12c2e 100644 --- a/runtime/lib/common/debug.cpp +++ b/runtime/lib/common/debug.cpp @@ -13,6 +13,18 @@ Env const &Env::get(bool loadKernelsFromDisk, bool enableAsyncTTNN) { return config; } +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 +Hooks const & +Hooks::get(std::optional, + std::optional)>> + operatorCallback) { + static Hooks config(operatorCallback); + return config; +} +#else +Hooks get() { return Hooks(); } +#endif + } // namespace tt::runtime::debug #endif diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 8b0e79daab..0b68365a7d 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -10,6 +10,7 @@ #if defined(TT_RUNTIME_ENABLE_TTNN) #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/types.h" #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) @@ -211,12 +212,13 @@ void closeDevice(Device device) { Event submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles, - std::vector const &outputHandles) { + std::vector const &outputHandles, + std::unordered_map const &goldenHandles) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle, programIndex, inputHandles, - outputHandles); + outputHandles, goldenHandles); } #endif @@ -224,7 +226,7 @@ Event submit(Device deviceHandle, Binary executableHandle, if (getCurrentRuntime() == DeviceRuntime::TTMetal) { return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle, programIndex, inputHandles, - outputHandles); + outputHandles, goldenHandles); } #endif throw std::runtime_error("runtime is not enabled"); @@ -245,4 +247,167 @@ void wait(Event event) { throw std::runtime_error("runtime is not enabled"); } +std::vector Tensor::getData() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + return std::vector(static_cast(this->data.get()), + static_cast(this->data.get()) + + this->volume); +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + LOG_WARNING("Getting data for ttmetal tensor is not enabled yet"); + return {}; +#endif +} + +Tensor OpContext::getOpOutputTensor(CallbackContext context) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + auto *contextPtr = + static_cast(context.handle.get()); + auto *opContextPtr = + static_cast<::tt::target::ttnn::Operation *>(this->handle.get()); + const ttnn::ProgramTensorPool &tensorPool = contextPtr->getTensorPool(); + std::int32_t globalId{-1}; + const ::ttnn::Tensor *outPtr = nullptr; + + switch (opContextPtr->type_type()) { + case ::tt::target::ttnn::OpType::GetDeviceOp: { + globalId = opContextPtr->type_as_GetDeviceOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::ToMemoryConfigOp: { + globalId = opContextPtr->type_as_ToMemoryConfigOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::ToLayoutOp: { + globalId = opContextPtr->type_as_ToLayoutOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::TypecastOp: { + globalId = opContextPtr->type_as_TypecastOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::ToDeviceOp: { + globalId = opContextPtr->type_as_ToDeviceOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::FromDeviceOp: { + globalId = opContextPtr->type_as_FromDeviceOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::EmptyOp: { + globalId = opContextPtr->type_as_EmptyOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::FullOp: { + globalId = opContextPtr->type_as_FullOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::EltwiseOp: { + globalId = opContextPtr->type_as_EltwiseOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::MatmulOp: { + globalId = opContextPtr->type_as_MatmulOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::ReductionOp: { + globalId = opContextPtr->type_as_ReductionOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::EmbeddingOp: { + globalId = opContextPtr->type_as_EmbeddingOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::SoftmaxOp: { + globalId = opContextPtr->type_as_SoftmaxOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::TransposeOp: { + globalId = opContextPtr->type_as_TransposeOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::ConcatOp: { + globalId = opContextPtr->type_as_ConcatOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::ReshapeOp: { + globalId = opContextPtr->type_as_ReshapeOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::SliceOp: { + globalId = opContextPtr->type_as_SliceOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::Conv2dOp: { + globalId = opContextPtr->type_as_Conv2dOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::MaxPool2dOp: { + globalId = opContextPtr->type_as_MaxPool2dOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::AllGatherOp: { + globalId = opContextPtr->type_as_AllGatherOp()->out()->global_id(); + break; + } + case ::tt::target::ttnn::OpType::DeallocateOp: { + LOG_WARNING("getting output tensor for DeallocateOp is not supported"); + return Tensor(nullptr, nullptr, DeviceRuntime::TTNN, 0); + } + default: { + throw std::runtime_error("Unsupported operation type"); + } + } + + if (tensorPool.contains(globalId)) { + outPtr = &tensorPool.at(globalId); + } else { + LOG_WARNING("Output tensor not found in tensor pool"); + return Tensor(nullptr, nullptr, DeviceRuntime::TTNN, 0); + } + + ::ttnn::Tensor hostTensor = ::ttnn::from_device(*outPtr); + ::ttnn::Tensor outCopy = + ::ttnn::to_layout(hostTensor, ::ttnn::ROW_MAJOR_LAYOUT, std::nullopt, + std::nullopt, static_cast<::ttnn::Device *>(nullptr)); + + void *src = ::tt::tt_metal::get_raw_host_data_ptr(outCopy); + std::uint32_t outCopySize = outCopy.volume() * outCopy.element_size(); + std::shared_ptr data = ::tt::runtime::utils::malloc_shared(outCopySize); + std::memcpy(data.get(), src, outCopySize); + + auto tensor = std::make_shared<::ttnn::Tensor>( + ttnn::createStorage(data.get(), outCopy.volume(), + ::tt::target::DataType::Float32), + outCopy.shape().value, ::ttnn::DataType::FLOAT32, + ::ttnn::Layout::ROW_MAJOR); + + return Tensor(std::static_pointer_cast(tensor), data, + DeviceRuntime::TTNN, outCopy.volume()); +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + LOG_WARNING("Getting device tensor for ttmetal runtime is not enabled yet!"); + return Tensor(nullptr, nullptr, DeviceRuntime::TTMetal, 0); +#endif +} + +std::string OpContext::getOpDebugString() { + auto *opContextPtr = + static_cast<::tt::target::ttnn::Operation *>(this->handle.get()); + return std::string(opContextPtr->debug_info()->c_str()); +} + +Tensor CallbackContext::getDebugInfoGolden(std::string loc) { + auto *contextPtr = + static_cast(this->handle.get()); + if (contextPtr->getGoldenMap().contains(loc)) { + return contextPtr->getGoldenMap().at(loc); + } + + LOG_WARNING("Golden tensor not found!"); + return Tensor(nullptr, nullptr, DeviceRuntime::TTNN, 0); +} + } // namespace tt::runtime diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 931349429b..45004162c2 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -39,8 +39,10 @@ Tensor createTensor(std::shared_ptr data, desc.itemsize = itemsize; desc.dataType = dataType; std::shared_ptr tensor = std::make_shared(desc); - return Tensor(static_pointer_cast(tensor), data, - DeviceRuntime::TTMetal); + std::uint32_t volume = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); + return Tensor(static_pointer_cast(tensor), data, DeviceRuntime::TTMetal, + volume); } tt::target::DataType getTensorDataType(Tensor tensor) { @@ -169,7 +171,8 @@ Events maybeCopyHostOutputs(::tt::tt_metal::Device *device, Event submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles, - std::vector const &outputHandles) { + std::vector const &outputHandles, + std::unordered_map const &goldenHandles) { ::tt::target::metal::TTMetalBinary const &fbb = *getBinary(executableHandle); ::tt::target::metal::Program const *program = fbb.programs()->Get(programIndex); diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h index e59bb66d60..86bc6f6bfb 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h @@ -46,6 +46,11 @@ class ProgramTensorPool { return *liveTensors.at(globalId); } + const ::ttnn::Tensor &at(std::uint32_t globalId) const { + assert(liveTensors.contains(globalId)); + return *liveTensors.at(globalId); + } + size_t erase(std::uint32_t globalId) { assert(liveTensors.contains(globalId) && intermedTensors.contains(globalId)); @@ -97,10 +102,11 @@ class ProgramContext { ProgramContext(const TensorMap &liveTensors, const std::unordered_set &programInputs, const std::unordered_set &programOutputs, + const std::unordered_map &goldenTensors, ::ttnn::MeshDevice *parentMesh) : tensorPool( ProgramTensorPool(liveTensors, programInputs, programOutputs)), - parentMesh(parentMesh) { + parentMesh(parentMesh), goldenTensors(goldenTensors) { assert(parentMesh && "Parent mesh cannot be null"); } ProgramContext(const ProgramContext &) = delete; @@ -161,6 +167,14 @@ class ProgramContext { // Tensor Pool Operations // ProgramTensorPool &getTensorPool() { return tensorPool; } + const ProgramTensorPool &getTensorPool() const { return tensorPool; } + + // + // Golden Tensor Operations + // + const std::unordered_map &getGoldenMap() { + return this->goldenTensors; + } private: ProgramTensorPool tensorPool; @@ -172,6 +186,9 @@ class ProgramContext { // Contains subMeshes of the parentMesh that are used by the program // Will be populated by GetDevice ops std::unordered_map> subMeshes; + + // Golden map of all golden tensors and their locations + std::unordered_map goldenTensors; }; } // namespace tt::runtime::ttnn diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index c72b20a456..e36b1eadc3 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -26,8 +26,10 @@ #include "operations/normalization/softmax.h" #include "operations/pool/maxpool2d.h" #include "operations/reduction/reduction.h" +#include "tt/runtime/detail/debug.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/ttnn/types.h" +#include "tt/runtime/utils.h" #include "ttmlir/Target/TTNN/program_generated.h" namespace tt::runtime::ttnn { @@ -38,15 +40,26 @@ class ProgramExecutor { ProgramExecutor(const TensorMap &liveTensors, const std::unordered_set &programInputs, const std::unordered_set &programOutputs, + const std::unordered_map &programGoldens, ::ttnn::MeshDevice *meshDevice) : context(ProgramContext(liveTensors, programInputs, programOutputs, - meshDevice)) {} + programGoldens, meshDevice)) {} void execute(const ::tt::target::ttnn::Program *program) { for (const ::tt::target::ttnn::Operation *op : *program->operations()) { LOG_DEBUG(LogType::LogRuntimeTTNN, "Executing operation: ", op->debug_info()->c_str()); runOperation(op); + if (auto callback = debug::Hooks::get().getOperatorCallback(); callback) { + std::shared_ptr contextPtr = + ::tt::runtime::utils::unsafe_borrow_shared(&context); + std::shared_ptr opPtr = + ::tt::runtime::utils::unsafe_borrow_shared( + const_cast<::tt::target::ttnn::Operation *>(op)); + + (*callback)(CallbackContext(contextPtr, DeviceRuntime::TTNN), + OpContext(opPtr, DeviceRuntime::TTNN)); + } } } @@ -117,8 +130,7 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { return operations::creation::run(op->type_as_FullOp(), context); } case ::tt::target::ttnn::OpType::EltwiseOp: { - const ::tt::target::ttnn::EltwiseOp *eltwiseOp = op->type_as_EltwiseOp(); - return runEltwiseOperation(eltwiseOp); + return runEltwiseOperation(op->type_as_EltwiseOp()); } // ANCHOR: adding_an_op_matmul_runtime_program case ::tt::target::ttnn::OpType::MatmulOp: { @@ -186,7 +198,8 @@ static bool handleNopProgram(::tt::target::ttnn::Program const *program, void runProgram(::ttnn::MeshDevice &meshDevice, ::tt::target::ttnn::Program const *program, std::vector<::ttnn::Tensor *> const &inputs, - std::vector<::ttnn::Tensor *> const &outputs) { + std::vector<::ttnn::Tensor *> const &outputs, + std::unordered_map const &goldens) { if (handleNopProgram(program, inputs, outputs)) { return; } @@ -212,7 +225,7 @@ void runProgram(::ttnn::MeshDevice &meshDevice, LOG_ASSERT(inserted, "Duplicate output tensor"); programOutputs.emplace(output->global_id()); } - ProgramExecutor executor(liveTensors, programInputs, programOutputs, + ProgramExecutor executor(liveTensors, programInputs, programOutputs, goldens, &meshDevice); executor.execute(program); } diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 9eba9986e6..927fdc3fe0 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -5,6 +5,7 @@ #include "tt/runtime/detail/debug.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/types.h" #include "tt/runtime/ttnn/utils.h" #include "tt/runtime/utils.h" #include "ttmlir/Target/TTNN/Target.h" @@ -24,7 +25,7 @@ using ::tt::tt_metal::raise_unsupported_storage; using ::tt::tt_metal::ShardTensor; template -static StorageType createStorage(ElementType *ptr, std::uint32_t numElements) { +StorageType createStorage(ElementType *ptr, std::uint32_t numElements) { if constexpr (std::is_same_v) { return BorrowedStorage( ::tt::tt_metal::borrowed_buffer::Buffer(ptr, numElements), @@ -39,8 +40,8 @@ static StorageType createStorage(ElementType *ptr, std::uint32_t numElements) { } template -static StorageType createStorage(void *ptr, std::uint32_t numElements, - ::tt::target::DataType dataType) { +StorageType createStorage(void *ptr, std::uint32_t numElements, + ::tt::target::DataType dataType) { switch (dataType) { case ::tt::target::DataType::Float32: return createStorage(static_cast(ptr), numElements); @@ -88,7 +89,7 @@ Tensor createTensor(std::shared_ptr data, ::ttnn::Shape(small_vector_shape), utils::toTTNNDataType(dataType), ::ttnn::Layout::ROW_MAJOR); return Tensor(std::static_pointer_cast(tensor), data, - DeviceRuntime::TTNN); + DeviceRuntime::TTNN, tensor.get()->volume()); } Tensor @@ -114,7 +115,7 @@ createTensor(std::vector> &data, std::make_shared>>(data); return Tensor(std::static_pointer_cast(tensor), std::static_pointer_cast(borrowedData), - DeviceRuntime::TTNN); + DeviceRuntime::TTNN, tensor.get()->volume()); } tt::target::DataType getTensorDataType(Tensor tensor) { @@ -176,24 +177,28 @@ static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) { Event submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles, - std::vector const &outputHandles) { + std::vector const &outputHandles, + std::unordered_map const &goldenHandles) { ::ttnn::MeshDevice &meshDevice = deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle); + std::vector<::ttnn::Tensor *> inputs; inputs.reserve(inputHandles.size()); for (auto &input : inputHandles) { LOG_ASSERT(input.matchesRuntime(DeviceRuntime::TTNN)); inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get())); } + std::vector<::ttnn::Tensor *> outputs; outputs.reserve(outputHandles.size()); for (auto &output : outputHandles) { LOG_ASSERT(output.matchesRuntime(DeviceRuntime::TTNN)); outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get())); } + tt::runtime::ttnn::runProgram(meshDevice, fbb.programs()->Get(programIndex), - inputs, outputs); + inputs, outputs, goldenHandles); return Event(nullptr, DeviceRuntime::TTNN); } diff --git a/runtime/tools/python/ttrt/common/golden.py b/runtime/tools/python/ttrt/common/golden.py new file mode 100644 index 0000000000..9f20b5069f --- /dev/null +++ b/runtime/tools/python/ttrt/common/golden.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import json +import importlib.machinery +import sys +import signal +import os +import io +import subprocess +import time +import socket +from pkg_resources import get_distribution +import shutil +import atexit +import re + +from ttrt.common.util import * + + +def get_atol_rtol_pcc(golden, calculated): + import numpy as np + import torch + + # Calculate atol and rtol + cal_atol = torch.max(torch.abs(golden - calculated)).item() + cal_rtol = torch.max(torch.abs(golden - calculated) / torch.abs(calculated)).item() + + # Calculate PCC + def get_pcc(golden, calculated): + # Both tensors are nan + if torch.all(torch.isnan(golden)) and torch.all(torch.isnan(calculated)): + print("Both tensors are 'nan'") + return 1.0 + # One tensor is all nan, the other is not + elif torch.all(torch.isnan(golden)) or torch.all(torch.isnan(calculated)): + print("One tensor is all nan, the other is not.") + return 0.0 + else: + # For now, mask all infs and nans so that we check the rest... TODO + golden = golden.clone() + golden[ + torch.logical_or( + torch.isnan(golden), + torch.logical_or(torch.isinf(golden), torch.isneginf(golden)), + ) + ] = 0 + calculated = calculated.clone() + calculated[ + torch.logical_or( + torch.isnan(calculated), + torch.logical_or( + torch.isinf(calculated), torch.isneginf(calculated) + ), + ) + ] = 0 + + if torch.equal(golden, calculated): + return 1.0 + + if golden.dtype == torch.bfloat16: + golden = golden.type(torch.float32) + calculated = calculated.type(torch.float32) + + # Single element case + if golden.numel() == 1: + return float(torch.equal(golden, calculated)) + + # If both tensors are contant + if torch.max(golden) == torch.min(golden) and torch.max( + calculated + ) == torch.min(calculated): + return torch.isclose(torch.max(golden), torch.max(calculated)).item() + + cal_pcc = np.ma.corrcoef( + np.ma.masked_invalid(torch.squeeze(golden).detach().numpy()).flatten(), + np.ma.masked_invalid( + torch.squeeze(calculated).detach().numpy() + ).flatten(), + ) + # Remove correlation coefficient with self (typically always 1.0) + mask = np.ones(cal_pcc.shape, dtype=bool) + np.fill_diagonal(mask, 0) + cal_pcc = np.min(cal_pcc[mask]) + + if isinstance(cal_pcc, np.ma.core.MaskedConstant): + return 1.0 + + return cal_pcc + + cal_pcc = get_pcc(golden, calculated) + + return ( + cal_atol, + cal_rtol, + cal_pcc, + f"Max ATOL Delta: {cal_atol}, Max RTOL Delta: {cal_rtol}, PCC: {cal_pcc}", + ) + + +def golden(context=None, opContext=None): + import torch + import ttrt.runtime + + print("-----------executing golden comparision-----------") + + try: + op_debug_str = opContext.get_op_debug_str() + + # find matching golden tensor based on loc in op debug string + match = re.search(r"loc\(([^)]+)\)", op_debug_str) + + if not match: + print(f"debug_str={op_debug_str}") + print("No location found in debug string - skipping golden comparison") + return + + loc = match.group(1).replace('"', "") + print(f"found location={loc}") + + golden_tensor = context.get_debug_info_golden(loc) + output_tensor = opContext.get_op_output_tensor(context) + + golden_tensor_list = golden_tensor.get_data() + output_tensor_list = output_tensor.get_data() + + if len(golden_tensor_list) == 0: + print("Golden tensor is empty - skipping golden comparison") + return + + if len(output_tensor_list) == 0: + print("Output tensor is empty - skipping golden comparison") + return + + if len(golden_tensor_list) != len(output_tensor_list): + print( + "Golden and output tensor sizes do not match - skipping golden comparison" + ) + return + + golden_tensor_torch = torch.tensor( + golden_tensor_list, dtype=torch.float32 + ).flatten() + output_tensor_torch = torch.tensor( + output_tensor_list, dtype=torch.float32 + ).flatten() + + _, _, cal_pcc, output_str = get_atol_rtol_pcc( + golden_tensor_torch, output_tensor_torch + ) + + print(f"PCC={cal_pcc}") + print(output_str) + finally: + print("-----------finished executing golden comparision-----------") diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py index 0fe97f093c..e8e1d13901 100644 --- a/runtime/tools/python/ttrt/common/run.py +++ b/runtime/tools/python/ttrt/common/run.py @@ -18,6 +18,7 @@ from ttrt.common.util import * from ttrt.common.query import Query +from ttrt.common.golden import golden class Run: @@ -172,6 +173,13 @@ def initialize_api(): choices=None, help="test file to save results to", ) + Run.register_arg( + name="--golden", + type=bool, + default=False, + choices=[True, False], + help="run golden comparison for intermediate and output tensors", + ) Run.register_arg( name="binary", type=str, @@ -361,6 +369,9 @@ def _execute(binaries): self.logging.warning(f"no binaries found to run - returning early") return + if self["--golden"]: + callback_env = ttrt.runtime.DebugHooks.get(golden) + debug_env = ttrt.runtime.DebugEnv.get( self["--load-kernels-from-disk"], self["--enable-async-ttnn"] ) @@ -403,9 +414,11 @@ def _execute(binaries): program.populate_outputs( Run.TorchInitializer.get_initilizer("zeros") ) + program.populate_goldens() total_inputs = [] total_outputs = [] + total_goldens = [] for loop in range(self["--loops"]): self.logging.debug( f"generating inputs/outputs for loop={loop+1}/{self['--loops']} for binary={bin.file_path}" @@ -438,6 +451,24 @@ def _execute(binaries): total_inputs.append(inputs) total_outputs.append(outputs) + self.logging.debug( + f"generating golden map for loop={loop+1}/{self['--loops']} for binary={bin.file_path}" + ) + goldens = {} + + for key, golden_obj in program.golden_map.map.items(): + goldens[key] = ttrt.runtime.create_tensor( + golden_obj.torch_tensor.data_ptr(), + list(golden_obj.tensor_shape), + list(golden_obj.tensor_stride), + golden_obj.torch_tensor.element_size(), # 4 bytes - float32 + Binary.Program.to_data_type( + golden_obj.torch_tensor.dtype + ), + ) + + total_goldens.append(goldens) + event = None for loop in range(self["--loops"]): self.logging.debug( @@ -450,6 +481,7 @@ def _execute(binaries): program_index, total_inputs[loop], total_outputs[loop], + total_goldens[loop], ) self.logging.debug( diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index ebbf1d6d72..0ef82208f6 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -522,20 +522,40 @@ def get_ttsys_file_extension(): return Flatbuffer.ttsys_file_extension -class Golden: - def __init__(self, tensor_id, tensor_shape, tensor_stride, tensor_data): - self.tensor_id = tensor_id - self.tensor_shape = tensor_shape - self.tensor_stride = tensor_stride - self.tensor_data = tensor_data - - def get_golden_tensor(self): - tensor_byte_data = bytes(self.tensor_data) - float_data = np.frombuffer(tensor_byte_data, dtype=np.float32) - golden_tensor = torch.tensor(float_data, dtype=torch.float32).reshape( - self.tensor_shape - ) - return golden_tensor +class GoldenMap: + def __init__(self): + self.map = {} + + def add_golden(self, element): + self.map[element.tensor_id] = element + + def get_golden(self, tensor_id): + return self.map[tensor_id] + + def get_inputs(self): + inputs = [] + + for i, tensor in self.map.items(): + if i.startswith("input"): + inputs.append(tensor) + + return inputs + + class Golden: + def __init__(self, tensor_id, tensor_shape, tensor_stride, tensor_data): + import torch + import numpy as np + + self.tensor_id = tensor_id + self.tensor_shape = tensor_shape + self.tensor_stride = tensor_stride + self.tensor_data = tensor_data + + tensor_byte_data = bytes(self.tensor_data) + float_data = np.frombuffer(tensor_byte_data, dtype=np.float32) + self.torch_tensor = torch.tensor(float_data, dtype=torch.float32).reshape( + self.tensor_shape + ) class Binary(Flatbuffer): @@ -557,20 +577,6 @@ def __init__(self, logger, file_manager, file_path, capsule=None): program = Binary.Program(i, self.fbb_dict["programs"][i]) self.programs.append(program) - # populate golden tensors if they exist - if "debug_info" in self.fbb_dict["programs"][i]: - golden_info_list = self.fbb_dict["programs"][i]["debug_info"][ - "golden_info" - ]["golden_map"] - - for golden_tensor_dict in golden_info_list: - Golden( - golden_tensor_dict["key"], - golden_tensor_dict["value"]["shape"], - golden_tensor_dict["value"]["stride"], - golden_tensor_dict["value"]["data"], - ) - def check_system_desc(self, query): import ttrt.binary @@ -615,16 +621,23 @@ def __init__(self, index, program): self.program = program self.input_tensors = [] self.output_tensors = [] + self.golden_map = GoldenMap() def populate_inputs(self, init_fn): - for i in self.program["inputs"]: - torch_tensor = init_fn( - i["desc"]["shape"], - dtype=Binary.Program.from_data_type( - i["desc"]["layout"]["memory_desc"]["data_type"] - ), - ) - self.input_tensors.append(torch_tensor) + inputs = self.golden_map.get_inputs() + + if len(inputs) != 0: + for input in inputs: + self.input_tensors.append(input.torch_tensor) + else: + for i in self.program["inputs"]: + torch_tensor = init_fn( + i["desc"]["shape"], + dtype=Binary.Program.from_data_type( + i["desc"]["layout"]["memory_desc"]["data_type"] + ), + ) + self.input_tensors.append(torch_tensor) def populate_outputs(self, init_fn): for i in self.program["outputs"]: @@ -636,6 +649,21 @@ def populate_outputs(self, init_fn): ) self.output_tensors.append(torch_tensor) + def populate_goldens(self): + # populate golden tensors if they exist + if "debug_info" in self.program: + golden_info_list = self.program["debug_info"]["golden_info"][ + "golden_map" + ] + for golden_tensor_dict in golden_info_list: + golden_tensor = GoldenMap.Golden( + golden_tensor_dict["key"], + golden_tensor_dict["value"]["shape"], + golden_tensor_dict["value"]["stride"], + golden_tensor_dict["value"]["data"], + ) + self.golden_map.add_golden(golden_tensor) + @staticmethod def to_data_type(dtype): import torch diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py index 1a616db248..728ba01c57 100644 --- a/runtime/tools/python/ttrt/runtime/__init__.py +++ b/runtime/tools/python/ttrt/runtime/__init__.py @@ -10,6 +10,7 @@ DataType, DeviceRuntime, DebugEnv, + DebugHooks, get_current_runtime, set_compatible_runtime, get_current_system_desc, diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 4f528c02f9..e0f14fbb2b 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -9,6 +9,7 @@ #include "tt/runtime/runtime.h" #include "tt/runtime/utils.h" +#include #include #include @@ -20,7 +21,14 @@ PYBIND11_MODULE(_C, m) { py::class_(m, "Device") .def("deallocate_buffers", &tt::runtime::detail::deallocateBuffers); py::class_(m, "Event"); - py::class_(m, "Tensor"); + py::class_(m, "Tensor") + .def("get_data", &tt::runtime::Tensor::getData); + py::class_(m, "OpContext") + .def("get_op_output_tensor", &tt::runtime::OpContext::getOpOutputTensor) + .def("get_op_debug_str", &tt::runtime::OpContext::getOpDebugString); + py::class_(m, "CallbackContext") + .def("get_debug_info_golden", + &tt::runtime::CallbackContext::getDebugInfoGolden); py::enum_<::tt::target::DataType>(m, "DataType") .value("Float32", ::tt::target::DataType::Float32) .value("Float16", ::tt::target::DataType::Float16) @@ -85,7 +93,8 @@ PYBIND11_MODULE(_C, m) { m.def("close_device", &tt::runtime::closeDevice, "Close a mesh device"); m.def("submit", &tt::runtime::submit, py::arg("device"), py::arg("executable"), py::arg("program_index"), py::arg("inputs"), - py::arg("outputs"), "Submit a binary for execution"); + py::arg("outputs"), py::arg("goldens"), + "Submit a binary for execution"); m.def("wait", &tt::runtime::wait, py::arg("event")); py::class_(m, "DebugEnv") @@ -96,6 +105,26 @@ PYBIND11_MODULE(_C, m) { return os.str(); }); + py::class_(m, "DebugHooks") + .def_static( + "get", + [](py::function func) { +#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1 + tt::runtime::debug::Hooks::get( + [func](std::optional context, + std::optional opContext) { + func(context, opContext); + }); +#else + tt::runtime::debug::Hooks::get(); +#endif + }) + .def("__str__", [](const tt::runtime::debug::Hooks &hooks) { + std::stringstream os; + os << hooks; + return os.str(); + }); + py::class_(m, "WorkaroundEnv") .def_static("get", &tt::runtime::workaround::Env::get) .def("__str__", [](const tt::runtime::workaround::Env &env) {