From e9a92371065fc26883a899b361cb0ffcdd2e3ed6 Mon Sep 17 00:00:00 2001
From: Dusan Erdeljan
Date: Wed, 6 Nov 2024 16:33:30 +0000
Subject: [PATCH] Add graph capture validation pass

---
 env/CMakeLists.txt                            | 10 +++
 .../Dialect/TTNN/Pipelines/TTNNPipelines.h    | 14 +++
 .../ttmlir/Dialect/TTNN/Transforms/Passes.td  |  7 ++
 lib/CMakeLists.txt                            |  1 +
 lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp  | 16 ++++
 lib/Dialect/TTNN/Transforms/Passes.cpp        | 86 +++++++++++++++++--
 runtime/include/tt/runtime/detail/ttnn.h      |  5 +-
 runtime/include/tt/runtime/runtime.h          |  2 +-
 runtime/lib/runtime.cpp                       |  4 +-
 runtime/lib/ttnn/program.cpp                  | 73 ++++++++++++++--
 runtime/lib/ttnn/runtime.cpp                  |  4 +-
 runtime/tools/python/ttrt/common/run.py       |  8 ++
 runtime/tools/python/ttrt/runtime/module.cpp  |  3 +-
 13 files changed, 210 insertions(+), 23 deletions(-)

diff --git a/env/CMakeLists.txt b/env/CMakeLists.txt
index 0f3c26736b..dcee3a1400 100644
--- a/env/CMakeLists.txt
+++ b/env/CMakeLists.txt
@@ -4,6 +4,7 @@ project(ttmlir-toolchain LANGUAGES CXX C)
 set(FLATBUFFERS_VERSION "fb9afbafc7dfe226b9db54d4923bfb8839635274")
 set(LLVM_PROJECT_VERSION "e813750354bbc08551cf23ff559a54b4a9ea1f29")
 set(STABLEHLO_VERSION "d40285ef3db0687e3f1e2bb0d716d748485a9739")
+set(NLOHMANN_JSON_VERSION "9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03")
 
 include(ExternalProject)
 
@@ -78,5 +79,14 @@ ExternalProject_Add(stablehlo
   INSTALL_COMMAND ""
 )
 
+ExternalProject_Add(nlohmann_json
+  PREFIX ${TTMLIR_TOOLCHAIN_DIR}
+  GIT_REPOSITORY https://github.com/nlohmann/json.git
+  GIT_TAG ${NLOHMANN_JSON_VERSION}
+  CONFIGURE_COMMAND ""
+  BUILD_COMMAND ""
+  INSTALL_COMMAND ""
+)
+
 add_custom_target(llvm-lit ALL COMMAND cp llvm-project-prefix/src/llvm-project-build/bin/llvm-lit ${TTMLIR_TOOLCHAIN_DIR}/bin/llvm-lit DEPENDS llvm-project)
 add_custom_target(run-clang-tidy-install ALL COMMAND cp llvm-project-prefix/src/llvm-project/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py ${TTMLIR_TOOLCHAIN_DIR}/bin/run-clang-tidy.py DEPENDS llvm-project)
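Note: nlohmann_json is fetched source-only (its configure, build and install steps are empty), so it is consumed as a header-only dependency; the include path is wired up in lib/CMakeLists.txt further down. The new validation pass uses it to parse the result file that ttrt writes. A minimal usage sketch, relying only on the "result" field that the pass actually reads (file name and the rest of the report format are illustrative assumptions):

    #include <fstream>
    #include <iostream>

    #include "nlohmann/json.hpp"

    int main() {
      // "report.json" is an illustrative path; the pass itself reads
      // <temp dir>/module_graph_capture.json.
      std::ifstream reportFile("report.json");
      nlohmann::json report = nlohmann::json::parse(reportFile);

      bool allPassed = true;
      for (const auto &entry : report) {
        allPassed &= (entry["result"] == "pass");
      }
      std::cout << (allPassed ? "valid" : "invalid") << "\n";
      return allPassed ? 0 : 1;
    }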
diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
index 9988bbcc18..984e0a059b 100644
--- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
+++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -104,6 +104,14 @@ struct TTIRToTTNNBackendPipelineOptions
   ListOption<int64_t> meshShape{
       *this, "mesh-shape",
       llvm::cl::desc("Set the multi-device mesh shape.")};
+
+  // If this option is true, run the entire graph with graph capture to
+  // validate it.
+  //
+  Option<bool> graphCaptureValidationEnabled{
+      *this, "graph-capture-validation-enabled",
+      llvm::cl::desc("Enable TTNN graph validation using graph capture."),
+      llvm::cl::init(false)};
 };
 
 void createTTNNPipelineTTIRPasses(
@@ -121,6 +129,9 @@ void createTTNNPipelineLayoutDecompositionPass(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);
 
 void createTTNNPipelineDeallocPass(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);
 
+void createTTNNPipelineValidateGraphCapturePass(
+    OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);
+
 void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
                                             std::string options);
 
@@ -136,6 +147,9 @@ void createTTNNPipelineLayoutDecompositionPassFromString(OpPassManager &pm,
 void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
                                              std::string options);
 
+void createTTNNPipelineValidateGraphCapturePassFromString(OpPassManager &pm,
+                                                          std::string options);
+
 void createTTIRToTTNNBackendPipeline(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);
 
diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
index c29fa977b4..49d86a6454 100644
--- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -21,4 +21,11 @@ def TTNNDecomposeLayouts: Pass<"ttnn-decompose-layouts", "::mlir::ModuleOp"> {
   }];
 }
 
+def TTNNValidateGraphCapture: Pass<"ttnn-validate-graph-capture", "::mlir::ModuleOp"> {
+  let summary = "Validate op graph with graph capture.";
+  let description = [{
+    This pass checks that the produced TTNN op graph is valid using graph capture.
+  }];
+}
+
 #endif
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index c3dc3a4b71..85fa440613 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -1,5 +1,6 @@
 include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo)
 include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build)
+include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/nlohmann_json/include)
 
 add_subdirectory(CAPI)
 add_subdirectory(Conversion)
diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
index 7f3baaeaf7..e21c25e2e9 100644
--- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
+++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -74,6 +74,11 @@ void createTTNNPipelineDeallocPass(
   pm.addPass(createTTNNDeallocate());
 }
 
+void createTTNNPipelineValidateGraphCapturePass(
+    OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
+  pm.addPass(createTTNNValidateGraphCapture());
+}
+
 void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
                                             std::string options) {
   auto optionsStruct =
@@ -109,6 +114,13 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
   createTTNNPipelineDeallocPass(pm, *optionsStruct);
 }
 
+void createTTNNPipelineValidateGraphCapturePassFromString(OpPassManager &pm,
+                                                          std::string options) {
+  auto optionsStruct =
+      TTIRToTTNNBackendPipelineOptions::createFromString(options);
+  createTTNNPipelineValidateGraphCapturePass(pm, *optionsStruct);
+}
+
 void createTTIRToTTNNBackendPipeline(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
   createTTNNPipelineTTIRPasses(pm, options);
@@ -116,6 +128,10 @@ void createTTIRToTTNNBackendPipeline(
   createTTNNPipelineAnalysisPasses(pm, options);
   createTTNNPipelineLayoutDecompositionPass(pm, options);
   createTTNNPipelineDeallocPass(pm, options);
+
+  if (options.graphCaptureValidationEnabled) {
+    createTTNNPipelineValidateGraphCapturePass(pm, options);
+  }
 }
 
 //===----------------------------------------------------------------------===//
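With this wiring, the validation step runs only when graph-capture-validation-enabled is set and is appended at the very end of the pipeline, after the deallocation pass, so it sees the final TTNN graph. A minimal sketch of driving the pipeline from C++ with the option turned on; the PassManager setup is standard MLIR and not part of this patch:

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"

    mlir::LogicalResult lowerWithValidation(mlir::ModuleOp module,
                                            mlir::MLIRContext *context) {
      mlir::PassManager pm(context);
      mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options;
      // Same switch that the "graph-capture-validation-enabled" CLI option sets.
      options.graphCaptureValidationEnabled = true;
      mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options);
      return pm.run(module);
    }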
diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp
index cb0d8c8869..42216eba1f 100644
--- a/lib/Dialect/TTNN/Transforms/Passes.cpp
+++ b/lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -8,10 +8,23 @@
 #include "mlir/IR/PatternMatch.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
 #include "ttmlir/Dialect/TTNN/Utils/Utils.h"
+#include <algorithm>
+
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wcovered-switch-default"
+#include "nlohmann/json.hpp"
+#pragma clang diagnostic pop
+
+#include <cstdlib>
+#include <filesystem>
+#include <fstream>
+
+#include <string>
 
 namespace mlir::tt::ttnn {
 #define GEN_PASS_DEF_TTNNDEALLOCATE
 #define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS
+#define GEN_PASS_DEF_TTNNVALIDATEGRAPHCAPTURE
 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"
 
 class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
@@ -98,6 +111,66 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
   }
 };
 
+class TTNNValidateGraphCapture
+    : public impl::TTNNValidateGraphCaptureBase<TTNNValidateGraphCapture> {
+
+public:
+  using impl::TTNNValidateGraphCaptureBase<
+      TTNNValidateGraphCapture>::TTNNValidateGraphCaptureBase;
+
+  void runOnOperation() final {
+    const std::filesystem::path tmpDirPath =
+        std::filesystem::temp_directory_path();
+
+    const std::string mlirFilePath = tmpDirPath / "module.mlir";
+    const std::string flatBufferFilePath = tmpDirPath / "module.ttnn";
+    const std::string outReportPath = tmpDirPath / "module_graph_capture.json";
+
+    outputTTNNIRFile(mlirFilePath);
+    outputFlatBufferFile(mlirFilePath, flatBufferFilePath);
+    runGraphCapture(flatBufferFilePath, outReportPath);
+
+    if (!isValidGraphCaptureReport(outReportPath)) {
+      // TODO (nobradovic/odjuricic): Handle recompile.
+    }
+  }
+
+  void outputTTNNIRFile(const std::string &mlirFilePath) {
+    ModuleOp module = getOperation();
+    std::error_code _ec;
+    auto fs = llvm::raw_fd_stream(mlirFilePath, _ec);
+    module.print(fs);
+  }
+
+  void outputFlatBufferFile(const std::string &mlirFilePath,
+                            const std::string &flatBufferFilePath) {
+    const std::string cmd =
+        "./build/bin/ttmlir-translate --ttnn-to-flatbuffer " + mlirFilePath +
+        " -o " + flatBufferFilePath;
+
+    system(cmd.c_str());
+  }
+
+  void runGraphCapture(const std::string &flatBufferFilePath,
+                       const std::string &outReportFilePath) {
+    // TODO(mbezulj): Add required env variable to be able to run graph capture
+    // with mockup device and without kernel compilation.
+    const std::string cmd = "ttrt run " + flatBufferFilePath +
+                            " --use-graph-capture --result-file " +
+                            outReportFilePath;
+    system(cmd.c_str());
+  }
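Both helpers above shell out via std::system and ignore its return value, so a failed ttmlir-translate or ttrt invocation only shows up later, when the report fails to parse. A possible hardening, not part of this patch, is a shared wrapper that both methods could call and whose result runOnOperation checks:

    #include <cstdlib>
    #include <string>

    // Not in the patch: returns true only if the external tool exited cleanly.
    static bool runCommand(const std::string &cmd) {
      const int status = std::system(cmd.c_str());
      return status == 0; // non-zero (or -1) means the invocation failed
    }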
+
+  bool isValidGraphCaptureReport(const std::string &outReportPath) {
+    std::ifstream reportFile(outReportPath);
+    nlohmann::json jsonData = nlohmann::json::parse(reportFile);
+
+    return std::all_of(jsonData.begin(), jsonData.end(), [](auto &jsonElement) {
+      return jsonElement["result"] == "pass";
+    });
+  }
+};
+
 class TTNNDecomposeLayouts
     : public impl::TTNNDecomposeLayoutsBase<TTNNDecomposeLayouts> {
@@ -163,14 +236,11 @@ class TTNNDecomposeLayouts
   void print() const {
     llvm::errs() << "OpsToCreate{ \n"
-                 << "\t"
-                 << "CreateToDeviceOp: " << createToDeviceOp << "\n"
-                 << "\t"
-                 << "CreateFromDeviceOp: " << createFromDeviceOp << "\n"
-                 << "\t"
-                 << "CreateToLayoutOp: " << createToLayoutOp << "\n"
-                 << "\t"
-                 << "CreateTypecastOp: " << createTypecastOp << "\n"
+                 << "\t" << "CreateToDeviceOp: " << createToDeviceOp << "\n"
+                 << "\t" << "CreateFromDeviceOp: " << createFromDeviceOp
+                 << "\n"
+                 << "\t" << "CreateToLayoutOp: " << createToLayoutOp << "\n"
+                 << "\t" << "CreateTypecastOp: " << createTypecastOp << "\n"
                  << "\t"
                  << "CreateToMemoryConfigOp: " << createToMemoryConfigOp
                  << "\n"
diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h
index 0fdfdbddff..056109df7c 100644
--- a/runtime/include/tt/runtime/detail/ttnn.h
+++ b/runtime/include/tt/runtime/detail/ttnn.h
@@ -98,14 +98,15 @@ void deallocateBuffers(Device device);
 
 Event submit(Device device, Binary executable, std::uint32_t programIndex,
              std::vector<Tensor> const &inputs,
-             std::vector<Tensor> const &outputs);
+             std::vector<Tensor> const &outputs, bool useGraphCapture);
 
 void wait(Event event);
 
 void runProgram(::ttnn::MeshDevice &meshDevice,
                 ::tt::target::ttnn::Program const *program,
                 std::vector<::ttnn::Tensor *> const &inputs,
-                std::vector<::ttnn::Tensor *> const &outputs);
+                std::vector<::ttnn::Tensor *> const &outputs,
+                bool useGraphCapture);
 
 } // namespace tt::runtime::ttnn
 
diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h
index 05971f1600..ddaf58743d 100644
--- a/runtime/include/tt/runtime/runtime.h
+++ b/runtime/include/tt/runtime/runtime.h
@@ -51,7 +51,7 @@ void closeDevice(Device device);
 
 Event submit(Device device, Binary executable, std::uint32_t programIndex,
              std::vector<Tensor> const &inputs,
-             std::vector<Tensor> const &outputs);
+             std::vector<Tensor> const &outputs, bool useGraphCapture = false);
 
 void wait(Event event);
 
diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp
index eca784a95c..9854e67481 100644
--- a/runtime/lib/runtime.cpp
+++ b/runtime/lib/runtime.cpp
@@ -187,12 +187,12 @@ void closeDevice(Device device) {
 
 Event submit(Device deviceHandle, Binary executableHandle,
              std::uint32_t programIndex,
              std::vector<Tensor> const &inputHandles,
-             std::vector<Tensor> const &outputHandles) {
+             std::vector<Tensor> const &outputHandles, bool useGraphCapture) {
 #if defined(TT_RUNTIME_ENABLE_TTNN)
   if (getCurrentRuntime() == DeviceRuntime::TTNN) {
     return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
                                        programIndex, inputHandles,
-                                       outputHandles);
+                                       outputHandles, useGraphCapture);
   }
 #endif
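On the C++ side the new parameter defaults to false in runtime.h, so existing callers are unaffected; passing true routes execution through the graph-capture executor added below. A sketch of a caller, assuming the device, binary, input and output handles have already been created through the usual tt::runtime APIs:

    #include <vector>

    #include "tt/runtime/runtime.h"

    // Sketch only: everything except the final argument predates this patch.
    void runWithGraphCapture(tt::runtime::Device device,
                             tt::runtime::Binary binary,
                             std::vector<tt::runtime::Tensor> const &inputs,
                             std::vector<tt::runtime::Tensor> const &outputs) {
      tt::runtime::Event event = tt::runtime::submit(
          device, binary, /*programIndex=*/0, inputs, outputs,
          /*useGraphCapture=*/true);
      tt::runtime::wait(event);
    }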
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index f150d35c10..184899a0b1 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -28,11 +28,16 @@
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/ttnn/types.h"
 #include "ttmlir/Target/TTNN/program_generated.h"
+#include "ttnn/graph/graph_processor.hpp"
+#include <memory>
+#include <optional>
+#include <sstream>
 
 namespace tt::runtime::ttnn {
 using LogType = ::tt::runtime::logger::LogType;
 
-struct ProgramExecutor {
+class ProgramExecutor {
+public:
   ProgramExecutor(const TensorMap &liveTensors,
                   const std::unordered_set<uint32_t> &programInputs,
                   const std::unordered_set<uint32_t> &programOutputs,
@@ -40,7 +45,9 @@ struct ProgramExecutor {
       : context(ProgramContext(liveTensors, programInputs, programOutputs,
                                meshDevice)) {}
 
-  void execute(const ::tt::target::ttnn::Program *program) {
+  virtual ~ProgramExecutor() = default;
+
+  virtual void execute(const ::tt::target::ttnn::Program *program) {
     for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
       LOG_DEBUG(LogType::LogRuntimeTTNN,
                 "Executing operation: ", op->debug_info()->c_str());
@@ -50,12 +57,49 @@ struct ProgramExecutor {
 
   ProgramContext &getContext() { return context; }
 
-private:
+protected:
   ProgramContext context;
   void runOperation(const ::tt::target::ttnn::Operation *op);
   void runEltwiseOperation(const ::tt::target::ttnn::EltwiseOp *op);
 };
 
+class GraphCaptureProgramExecutor : public ProgramExecutor {
+public:
+  using ProgramExecutor::ProgramExecutor;
+
+  void execute(const ::tt::target::ttnn::Program *program) override {
+    const auto execute_impl = [&]() {
+      unsigned int opIndex = 0;
+      for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
+        LOG_DEBUG(LogType::LogRuntimeTTNN,
+                  "Executing operation: ", op->debug_info()->c_str());
+
+        try {
+          runOperation(op);
+        } catch (const std::exception &ex) {
+          // TODO(mbezulj): Replace opIndex with loc attribute of the operation
+          // which failed (loc attribute needs to be propagated to the flat
+          // buffer).
+          std::stringstream ss;
+          ss << "Failed on op " << opIndex << " ( "
+             << op->debug_info()->c_str() << " )"
+             << " because of: " << ex.what();
+          throw std::runtime_error(ss.str());
+        }
+
+        ++opIndex;
+      }
+
+      return std::nullopt;
+    };
+
+    ::ttnn::graph::GraphProcessor::begin_graph_capture(
+        tt::tt_metal::IGraphProcessor::RunMode::NO_DISPATCH);
+    execute_impl();
+    ::ttnn::graph::GraphProcessor::end_graph_capture();
+  }
+};
+
 void ProgramExecutor::runEltwiseOperation(
     const ::tt::target::ttnn::EltwiseOp *op) {
   auto runUnaryOp = [&]() {
@@ -176,10 +220,25 @@ static bool handleNopProgram(::tt::target::ttnn::Program const *program,
   return isNop;
 }
 
+std::unique_ptr<ProgramExecutor>
+makeProgramExecutor(const TensorMap &liveTensors,
+                    const std::unordered_set<uint32_t> &programInputs,
+                    const std::unordered_set<uint32_t> &programOutputs,
+                    ::ttnn::MeshDevice *meshDevice, bool useGraphCapture) {
+  if (useGraphCapture) {
+    return std::make_unique<GraphCaptureProgramExecutor>(
+        liveTensors, programInputs, programOutputs, meshDevice);
+  }
+
+  return std::make_unique<ProgramExecutor>(liveTensors, programInputs,
+                                           programOutputs, meshDevice);
+}
+
 void runProgram(::ttnn::MeshDevice &meshDevice,
                 ::tt::target::ttnn::Program const *program,
                 std::vector<::ttnn::Tensor *> const &inputs,
-                std::vector<::ttnn::Tensor *> const &outputs) {
+                std::vector<::ttnn::Tensor *> const &outputs,
+                bool useGraphCapture) {
   if (handleNopProgram(program, inputs, outputs)) {
     return;
   }
@@ -205,9 +264,9 @@ void runProgram(::ttnn::MeshDevice &meshDevice,
     LOG_ASSERT(inserted, "Duplicate output tensor");
     programOutputs.emplace(output->global_id());
   }
-  ProgramExecutor executor(liveTensors, programInputs, programOutputs,
-                           &meshDevice);
-  executor.execute(program);
+  std::unique_ptr<ProgramExecutor> executor = makeProgramExecutor(
+      liveTensors, programInputs, programOutputs, &meshDevice, useGraphCapture);
+  executor->execute(program);
 }
 
 } // namespace tt::runtime::ttnn
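Because the capture runs in NO_DISPATCH mode, ops are traced and validated rather than dispatched as real work (see the TODO above about running against a mockup device without kernel compilation), and an invalid op surfaces as the std::runtime_error built in the catch block ("Failed on op <index> ( <debug info> ) because of: ..."). A sketch of how a caller of runProgram might surface that failure; only the runProgram signature comes from this patch, the rest is illustrative:

    #include <iostream>
    #include <stdexcept>
    #include <vector>

    #include "tt/runtime/detail/ttnn.h"

    void validateProgram(::ttnn::MeshDevice &meshDevice,
                         ::tt::target::ttnn::Program const *program,
                         std::vector<::ttnn::Tensor *> const &inputs,
                         std::vector<::ttnn::Tensor *> const &outputs) {
      try {
        tt::runtime::ttnn::runProgram(meshDevice, program, inputs, outputs,
                                      /*useGraphCapture=*/true);
      } catch (const std::runtime_error &err) {
        // ttrt records a non-passing result for the binary in its report file.
        std::cerr << "Graph capture validation failed: " << err.what() << "\n";
      }
    }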
diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp
index 24c372b681..3518d4c93f 100644
--- a/runtime/lib/ttnn/runtime.cpp
+++ b/runtime/lib/ttnn/runtime.cpp
@@ -110,7 +110,7 @@ static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
 
 Event submit(Device deviceHandle, Binary executableHandle,
              std::uint32_t programIndex,
              std::vector<Tensor> const &inputHandles,
-             std::vector<Tensor> const &outputHandles) {
+             std::vector<Tensor> const &outputHandles, bool useGraphCapture) {
   ::ttnn::MeshDevice &meshDevice =
       deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN);
   ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle);
@@ -127,7 +127,7 @@ Event submit(Device deviceHandle, Binary executableHandle,
     outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get()));
   }
   tt::runtime::ttnn::runProgram(meshDevice, fbb.programs()->Get(programIndex),
-                                inputs, outputs);
+                                inputs, outputs, useGraphCapture);
 
   return Event(nullptr, DeviceRuntime::TTNN);
 }
diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py
index 976779e5fe..f09445bfa2 100644
--- a/runtime/tools/python/ttrt/common/run.py
+++ b/runtime/tools/python/ttrt/common/run.py
@@ -172,6 +172,13 @@ def initialize_api():
             choices=None,
             help="test file to save results to",
         )
+        Run.register_arg(
+            name="--use-graph-capture",
+            type=bool,
+            default=False,
+            choices=[True, False],
+            help="use graph capture to simulate workload run",
+        )
         Run.register_arg(
             name="binary",
             type=str,
@@ -450,6 +457,7 @@ def _execute(binaries):
                             program_index,
                             total_inputs[loop],
                             total_outputs[loop],
+                            self["--use-graph-capture"],
                         )
 
                         self.logging.debug(
diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp
index 040b41306c..201b9421b2 100644
--- a/runtime/tools/python/ttrt/runtime/module.cpp
+++ b/runtime/tools/python/ttrt/runtime/module.cpp
@@ -66,7 +66,8 @@ PYBIND11_MODULE(_C, m) {
   m.def("close_device", &tt::runtime::closeDevice, "Close a mesh device");
   m.def("submit", &tt::runtime::submit, py::arg("device"),
         py::arg("executable"), py::arg("program_index"), py::arg("inputs"),
-        py::arg("outputs"), "Submit a binary for execution");
+        py::arg("outputs"), py::arg("use_graph_capture"),
+        "Submit a binary for execution");
   m.def("wait", &tt::runtime::wait, py::arg("event"));
 
   py::class_(m, "DebugEnv")
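One note on the binding: runtime.h gives useGraphCapture a default of false, but the pybind11 signature above does not, so Python callers of submit() must now pass use_graph_capture explicitly (as run.py does). If backward compatibility for other Python callers is wanted, a possible follow-up, not part of this patch, is to mirror the C++ default in the binding:

    // Inside PYBIND11_MODULE(_C, m), replacing the m.def("submit", ...) above.
    m.def("submit", &tt::runtime::submit, py::arg("device"),
          py::arg("executable"), py::arg("program_index"), py::arg("inputs"),
          py::arg("outputs"), py::arg("use_graph_capture") = false,
          "Submit a binary for execution");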