diff --git a/env/CMakeLists.txt b/env/CMakeLists.txt
index 0f3c26736b..dcee3a1400 100644
--- a/env/CMakeLists.txt
+++ b/env/CMakeLists.txt
@@ -4,6 +4,7 @@ project(ttmlir-toolchain LANGUAGES CXX C)
 set(FLATBUFFERS_VERSION "fb9afbafc7dfe226b9db54d4923bfb8839635274")
 set(LLVM_PROJECT_VERSION "e813750354bbc08551cf23ff559a54b4a9ea1f29")
 set(STABLEHLO_VERSION "d40285ef3db0687e3f1e2bb0d716d748485a9739")
+set(NLOHMANN_JSON_VERSION "9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03")
 
 include(ExternalProject)
 
@@ -78,5 +79,14 @@ ExternalProject_Add(stablehlo
   INSTALL_COMMAND ""
 )
 
+ExternalProject_Add(nlohmann_json
+  PREFIX ${TTMLIR_TOOLCHAIN_DIR}
+  GIT_REPOSITORY https://github.com/nlohmann/json.git
+  GIT_TAG ${NLOHMANN_JSON_VERSION}
+  CONFIGURE_COMMAND ""
+  BUILD_COMMAND ""
+  INSTALL_COMMAND ""
+)
+
 add_custom_target(llvm-lit ALL COMMAND cp llvm-project-prefix/src/llvm-project-build/bin/llvm-lit ${TTMLIR_TOOLCHAIN_DIR}/bin/llvm-lit DEPENDS llvm-project)
 add_custom_target(run-clang-tidy-install ALL COMMAND cp llvm-project-prefix/src/llvm-project/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py ${TTMLIR_TOOLCHAIN_DIR}/bin/run-clang-tidy.py DEPENDS llvm-project)
diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
index 7e58298730..b84f3f1183 100644
--- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
+++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -110,6 +110,14 @@ struct TTIRToTTNNBackendPipelineOptions
   ListOption<int64_t> meshShape{
       *this, "mesh-shape",
       llvm::cl::desc("Set the multi-device mesh shape.")};
+
+  // If this option is true, run the entire graph with graph capture to
+  // validate it.
+  //
+  Option<bool> graphCaptureValidationEnabled{
+      *this, "graph-capture-validation-enabled",
+      llvm::cl::desc("Enable TTNN graph validation using graph capture."),
+      llvm::cl::init(false)};
 };
 
 void createTTNNPipelineTTIRPasses(
@@ -127,6 +135,9 @@ void createTTNNPipelineLayoutDecompositionPass(
 void createTTNNPipelineDeallocPass(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);
 
+void createTTNNPipelineValidateGraphCapturePass(
+    OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);
+
 void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
                                             std::string options);
 
@@ -142,6 +153,9 @@ void createTTNNPipelineLayoutDecompositionPassFromString(OpPassManager &pm,
 void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
                                              std::string options);
 
+void createTTNNPipelineValidateGraphCapturePassFromString(OpPassManager &pm,
+                                                          std::string options);
+
 void createTTIRToTTNNBackendPipeline(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);
diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
index c29fa977b4..49d86a6454 100644
--- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -21,4 +21,11 @@ def TTNNDecomposeLayouts: Pass<"ttnn-decompose-layouts", "::mlir::ModuleOp"> {
   }];
 }
 
+def TTNNValidateGraphCapture: Pass<"ttnn-validate-graph-capture", "::mlir::ModuleOp"> {
+  let summary = "Validate op graph with graph capture.";
+  let description = [{
+    This pass validates that the produced TTNN op graph is valid using graph capture.
+ }]; +} + #endif diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c3dc3a4b71..85fa440613 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,5 +1,6 @@ include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo) include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build) +include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/nlohmann_json/include) add_subdirectory(CAPI) add_subdirectory(Conversion) diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 772b51b04a..48577665b4 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -76,6 +76,11 @@ void createTTNNPipelineDeallocPass( pm.addPass(createTTNNDeallocate()); } +void createTTNNPipelineValidateGraphCapturePass( + OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { + pm.addPass(createTTNNValidateGraphCapture()); +} + void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm, std::string options) { auto optionsStruct = @@ -111,6 +116,13 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm, createTTNNPipelineDeallocPass(pm, *optionsStruct); } +void createTTNNPipelineValidateGraphCapturePassFromString(OpPassManager &pm, + std::string options) { + auto optionsStruct = + TTIRToTTNNBackendPipelineOptions::createFromString(options); + createTTNNPipelineValidateGraphCapturePass(pm, *optionsStruct); +} + void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { createTTNNPipelineTTIRPasses(pm, options); @@ -118,6 +130,10 @@ void createTTIRToTTNNBackendPipeline( createTTNNPipelineAnalysisPasses(pm, options); createTTNNPipelineLayoutDecompositionPass(pm, options); createTTNNPipelineDeallocPass(pm, options); + + if (options.graphCaptureValidationEnabled) { + createTTNNPipelineValidateGraphCapturePass(pm, options); + } } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index cb0d8c8869..42216eba1f 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -8,10 +8,23 @@ #include "mlir/IR/PatternMatch.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" #include "ttmlir/Dialect/TTNN/Utils/Utils.h" +#include + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wcovered-switch-default" +#include "nlohmann/json.hpp" +#pragma clang diagnostic pop + +#include +#include +#include + +#include namespace mlir::tt::ttnn { #define GEN_PASS_DEF_TTNNDEALLOCATE #define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS +#define GEN_PASS_DEF_TTNNVALIDATEGRAPHCAPTURE #include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc" class TTNNDeallocate : public impl::TTNNDeallocateBase { @@ -98,6 +111,66 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase { } }; +class TTNNValidateGraphCapture + : public impl::TTNNValidateGraphCaptureBase { + +public: + using impl::TTNNValidateGraphCaptureBase< + TTNNValidateGraphCapture>::TTNNValidateGraphCaptureBase; + + void runOnOperation() final { + const std::filesystem::path tmpDirPath = + std::filesystem::temp_directory_path(); + + const std::string mlirFilePath = tmpDirPath / "module.mlir"; + const std::string flatBufferFilePath = tmpDirPath / "module.ttnn"; + const std::string outReportPath = tmpDirPath / "module_graph_capture.json"; + + outputTTNNIRFile(mlirFilePath); + outputFlatBufferFile(mlirFilePath, flatBufferFilePath); + 
+    runGraphCapture(flatBufferFilePath, outReportPath);
+
+    if (!isValidGraphCaptureReport(outReportPath)) {
+      // TODO (nobradovic/odjuricic): Handle recompile.
+    }
+  }
+
+  void outputTTNNIRFile(const std::string &mlirFilePath) {
+    ModuleOp module = getOperation();
+    std::error_code _ec;
+    auto fs = llvm::raw_fd_stream(mlirFilePath, _ec);
+    module.print(fs);
+  }
+
+  void outputFlatBufferFile(const std::string &mlirFilePath,
+                            const std::string &flatBufferFilePath) {
+    const std::string cmd =
+        "./build/bin/ttmlir-translate --ttnn-to-flatbuffer " + mlirFilePath +
+        " -o " + flatBufferFilePath;
+
+    system(cmd.c_str());
+  }
+
+  void runGraphCapture(const std::string &flatBufferFilePath,
+                       const std::string &outReportFilePath) {
+    // TODO(mbezulj): Add required env variable to be able to run graph capture
+    // with mockup device and without kernel compilation.
+    const std::string cmd = "ttrt run " + flatBufferFilePath +
+                            " --use-graph-capture --result-file " +
+                            outReportFilePath;
+    system(cmd.c_str());
+  }
+
+  bool isValidGraphCaptureReport(const std::string &outReportPath) {
+    std::ifstream reportFile(outReportPath);
+    nlohmann::json jsonData = nlohmann::json::parse(reportFile);
+
+    return std::all_of(jsonData.begin(), jsonData.end(), [](auto &jsonElement) {
+      return jsonElement["result"] == "pass";
+    });
+  }
+};
+
 class TTNNDecomposeLayouts
     : public impl::TTNNDecomposeLayoutsBase<TTNNDecomposeLayouts> {
@@ -163,14 +236,11 @@ class TTNNDecomposeLayouts
 
     void print() const {
       llvm::errs() << "OpsToCreate{ \n"
-                   << "\t"
-                   << "CreateToDeviceOp: " << createToDeviceOp << "\n"
-                   << "\t"
-                   << "CreateFromDeviceOp: " << createFromDeviceOp << "\n"
-                   << "\t"
-                   << "CreateToLayoutOp: " << createToLayoutOp << "\n"
-                   << "\t"
-                   << "CreateTypecastOp: " << createTypecastOp << "\n"
+                   << "\t" << "CreateToDeviceOp: " << createToDeviceOp << "\n"
+                   << "\t" << "CreateFromDeviceOp: " << createFromDeviceOp
+                   << "\n"
+                   << "\t" << "CreateToLayoutOp: " << createToLayoutOp << "\n"
+                   << "\t" << "CreateTypecastOp: " << createTypecastOp << "\n"
                    << "\t"
                    << "CreateToMemoryConfigOp: " << createToMemoryConfigOp
                    << "\n"
diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h
index 4580b290be..0cc1d0238d 100644
--- a/runtime/include/tt/runtime/detail/ttnn.h
+++ b/runtime/include/tt/runtime/detail/ttnn.h
@@ -113,14 +113,15 @@ void deallocateBuffers(Device device);
 
 Event submit(Device device, Binary executable, std::uint32_t programIndex,
              std::vector<Tensor> const &inputs,
-             std::vector<Tensor> const &outputs);
+             std::vector<Tensor> const &outputs, bool useGraphCapture);
 
 void wait(Event event);
 
 void runProgram(::ttnn::MeshDevice &meshDevice,
                 ::tt::target::ttnn::Program const *program,
                 std::vector<::ttnn::Tensor *> const &inputs,
-                std::vector<::ttnn::Tensor *> const &outputs);
+                std::vector<::ttnn::Tensor *> const &outputs,
+                bool useGraphCapture);
 
 } // namespace tt::runtime::ttnn
diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h
index a070f2f0f5..13ef3af5e6 100644
--- a/runtime/include/tt/runtime/runtime.h
+++ b/runtime/include/tt/runtime/runtime.h
@@ -65,7 +65,7 @@ void closeDevice(Device device);
 
 Event submit(Device device, Binary executable, std::uint32_t programIndex,
              std::vector<Tensor> const &inputs,
-             std::vector<Tensor> const &outputs);
+             std::vector<Tensor> const &outputs, bool useGraphCapture = false);
 
 void wait(Event event);
diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp
index 8b0e79daab..77cbe3c178 100644
--- a/runtime/lib/runtime.cpp
+++ b/runtime/lib/runtime.cpp
@@ -211,12 +211,12 @@ void closeDevice(Device device) {
 Event submit(Device deviceHandle, Binary executableHandle,
              std::uint32_t programIndex,
              std::vector<Tensor> const &inputHandles,
-             std::vector<Tensor> const &outputHandles) {
+             std::vector<Tensor> const &outputHandles, bool useGraphCapture) {
 #if defined(TT_RUNTIME_ENABLE_TTNN)
   if (getCurrentRuntime() == DeviceRuntime::TTNN) {
     return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
                                        programIndex, inputHandles,
-                                       outputHandles);
+                                       outputHandles, useGraphCapture);
   }
 #endif
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index af1b28d990..184899a0b1 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -28,6 +28,10 @@
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/ttnn/types.h"
 #include "ttmlir/Target/TTNN/program_generated.h"
+#include "ttnn/graph/graph_processor.hpp"
+#include <memory>
+#include <optional>
+#include <sstream>
 
 namespace tt::runtime::ttnn {
 using LogType = ::tt::runtime::logger::LogType;
@@ -41,7 +45,9 @@ class ProgramExecutor {
       : context(ProgramContext(liveTensors, programInputs, programOutputs,
                                meshDevice)) {}
 
-  void execute(const ::tt::target::ttnn::Program *program) {
+  virtual ~ProgramExecutor() = default;
+
+  virtual void execute(const ::tt::target::ttnn::Program *program) {
     for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
       LOG_DEBUG(LogType::LogRuntimeTTNN,
                 "Executing operation: ", op->debug_info()->c_str());
@@ -51,12 +57,49 @@ class ProgramExecutor {
 
   ProgramContext &getContext() { return context; }
 
-private:
+protected:
   ProgramContext context;
   void runOperation(const ::tt::target::ttnn::Operation *op);
   void runEltwiseOperation(const ::tt::target::ttnn::EltwiseOp *op);
 };
 
+class GraphCaptureProgramExecutor : public ProgramExecutor {
+public:
+  using ProgramExecutor::ProgramExecutor;
+
+  void execute(const ::tt::target::ttnn::Program *program) override {
+    const auto execute_impl = [&]() {
+      unsigned int opIndex = 0;
+      for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
+        LOG_DEBUG(LogType::LogRuntimeTTNN,
+                  "Executing operation: ", op->debug_info()->c_str());
+
+        try {
+          runOperation(op);
+        } catch (const std::exception &ex) {
+          // TODO(mbezulj): Replace opIndex with the loc attribute of the
+          // operation which failed (the loc attribute needs to be propagated
+          // to the flatbuffer).
+          std::stringstream ss;
+          ss << "Failed on op " << opIndex << " (" << op->debug_info()->c_str()
+             << ") because of: " << ex.what();
+          throw std::runtime_error(ss.str());
+        }
+
+        ++opIndex;
+      }
+
+      return std::nullopt;
+    };
+
+    ::ttnn::graph::GraphProcessor::begin_graph_capture(
+        tt::tt_metal::IGraphProcessor::RunMode::NO_DISPATCH);
+    execute_impl();
+    ::ttnn::graph::GraphProcessor::end_graph_capture();
+  }
+};
+
 void ProgramExecutor::runEltwiseOperation(
     const ::tt::target::ttnn::EltwiseOp *op) {
   auto runUnaryOp = [&]() {
@@ -177,10 +220,25 @@ static bool handleNopProgram(::tt::target::ttnn::Program const *program,
   return isNop;
 }
 
+std::unique_ptr<ProgramExecutor>
+makeProgramExecutor(const TensorMap &liveTensors,
+                    const std::unordered_set<uint32_t> &programInputs,
+                    const std::unordered_set<uint32_t> &programOutputs,
+                    ::ttnn::MeshDevice *meshDevice, bool useGraphCapture) {
+  if (useGraphCapture) {
+    return std::make_unique<GraphCaptureProgramExecutor>(
+        liveTensors, programInputs, programOutputs, meshDevice);
+  }
+
+  return std::make_unique<ProgramExecutor>(liveTensors, programInputs,
+                                           programOutputs, meshDevice);
+}
+
 void runProgram(::ttnn::MeshDevice &meshDevice,
                 ::tt::target::ttnn::Program const *program,
                 std::vector<::ttnn::Tensor *> const &inputs,
-                std::vector<::ttnn::Tensor *> const &outputs) {
+                std::vector<::ttnn::Tensor *> const &outputs,
+                bool useGraphCapture) {
   if (handleNopProgram(program, inputs, outputs)) {
     return;
   }
@@ -206,9 +264,9 @@ void runProgram(::ttnn::MeshDevice &meshDevice,
     LOG_ASSERT(inserted, "Duplicate output tensor");
     programOutputs.emplace(output->global_id());
   }
-  ProgramExecutor executor(liveTensors, programInputs, programOutputs,
-                           &meshDevice);
-  executor.execute(program);
+  std::unique_ptr<ProgramExecutor> executor = makeProgramExecutor(
+      liveTensors, programInputs, programOutputs, &meshDevice, useGraphCapture);
+  executor->execute(program);
 }
 
 } // namespace tt::runtime::ttnn
diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp
index b06ae893aa..4bd21b0e39 100644
--- a/runtime/lib/ttnn/runtime.cpp
+++ b/runtime/lib/ttnn/runtime.cpp
@@ -163,7 +163,7 @@ static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
 Event submit(Device deviceHandle, Binary executableHandle,
              std::uint32_t programIndex,
              std::vector<Tensor> const &inputHandles,
-             std::vector<Tensor> const &outputHandles) {
+             std::vector<Tensor> const &outputHandles, bool useGraphCapture) {
   ::ttnn::MeshDevice &meshDevice =
       deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN);
   ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle);
@@ -180,7 +180,7 @@ Event submit(Device deviceHandle, Binary executableHandle,
     outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get()));
   }
   tt::runtime::ttnn::runProgram(meshDevice, fbb.programs()->Get(programIndex),
-                                inputs, outputs);
+                                inputs, outputs, useGraphCapture);
   return Event(nullptr, DeviceRuntime::TTNN);
 }
diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py
index 976779e5fe..f09445bfa2 100644
--- a/runtime/tools/python/ttrt/common/run.py
+++ b/runtime/tools/python/ttrt/common/run.py
@@ -172,6 +172,13 @@ def initialize_api():
            choices=None,
            help="test file to save results to",
        )
+        Run.register_arg(
+            name="--use-graph-capture",
+            type=bool,
+            default=False,
+            choices=[True, False],
+            help="use graph capture to simulate workload run",
+        )
        Run.register_arg(
            name="binary",
            type=str,
@@ -450,6 +457,7 @@ def _execute(binaries):
                        program_index,
                        total_inputs[loop],
                        total_outputs[loop],
+                        self["--use-graph-capture"],
                    )
 
                    self.logging.debug(
diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp
index 4f528c02f9..f86da10aee 100644
--- a/runtime/tools/python/ttrt/runtime/module.cpp
+++ b/runtime/tools/python/ttrt/runtime/module.cpp
@@ -85,7 +85,8 @@ PYBIND11_MODULE(_C, m) {
   m.def("close_device", &tt::runtime::closeDevice, "Close a mesh device");
   m.def("submit", &tt::runtime::submit, py::arg("device"),
        py::arg("executable"), py::arg("program_index"), py::arg("inputs"),
-        py::arg("outputs"), "Submit a binary for execution");
+        py::arg("outputs"), py::arg("use_graph_capture"),
+        "Submit a binary for execution");
   m.def("wait", &tt::runtime::wait, py::arg("event"));
 
   py::class_<tt::runtime::debug::Env>(m, "DebugEnv")
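
Reviewer note, appended after the patch rather than part of it: the use_graph_capture flag threads from the ttrt CLI (ttrt run <flatbuffer> --use-graph-capture --result-file <report.json>) through the pybind submit binding into tt::runtime::ttnn::runProgram, which selects GraphCaptureProgramExecutor so the op graph is traced in NO_DISPATCH mode rather than dispatched to the device. The sketch below shows the same flow through the C++ runtime API; validateThenRun and the device/binary/tensor setup are hypothetical, while the submit/wait signatures are the ones declared in runtime/include/tt/runtime/runtime.h in this diff.

// Sketch only: dry-run a program via graph capture before executing it for
// real. Assumes `device`, `binary`, `inputs`, and `outputs` were obtained
// through the existing tt::runtime APIs (not shown in this diff).
#include "tt/runtime/runtime.h"

#include <cstdint>
#include <vector>

void validateThenRun(tt::runtime::Device device, tt::runtime::Binary binary,
                     std::vector<tt::runtime::Tensor> const &inputs,
                     std::vector<tt::runtime::Tensor> const &outputs) {
  // Dry run: with useGraphCapture=true the TTNN runtime traces ops through the
  // graph processor in NO_DISPATCH mode instead of dispatching them.
  tt::runtime::Event traced = tt::runtime::submit(
      device, binary, /*programIndex=*/0, inputs, outputs,
      /*useGraphCapture=*/true);
  tt::runtime::wait(traced);

  // Real execution: the defaulted useGraphCapture=false keeps existing callers
  // source-compatible.
  tt::runtime::Event event =
      tt::runtime::submit(device, binary, /*programIndex=*/0, inputs, outputs);
  tt::runtime::wait(event);
}

This is the same validation the new TTNNValidateGraphCapture pass automates at compile time: it serializes the module, shells out to ttrt run ... --use-graph-capture, and accepts the graph only if every entry in the JSON report has "result": "pass".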