Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add graph capture validation pass #1195

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions env/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ project(ttmlir-toolchain LANGUAGES CXX C)
set(FLATBUFFERS_VERSION "fb9afbafc7dfe226b9db54d4923bfb8839635274")
set(LLVM_PROJECT_VERSION "e813750354bbc08551cf23ff559a54b4a9ea1f29")
set(STABLEHLO_VERSION "d40285ef3db0687e3f1e2bb0d716d748485a9739")
set(NLOHMANN_JSON_VERSION "9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03")

include(ExternalProject)

Expand Down Expand Up @@ -78,5 +79,14 @@ ExternalProject_Add(stablehlo
INSTALL_COMMAND ""
)

# Fetch nlohmann/json at the commit pinned by NLOHMANN_JSON_VERSION.
# Configure/build/install steps are intentionally empty: the library is
# consumed header-only straight from the checkout (lib/CMakeLists.txt adds
# ${TTMLIR_TOOLCHAIN_DIR}/src/nlohmann_json/include to the include path).
ExternalProject_Add(nlohmann_json
PREFIX ${TTMLIR_TOOLCHAIN_DIR}
GIT_REPOSITORY https://github.com/nlohmann/json.git
GIT_TAG ${NLOHMANN_JSON_VERSION}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)

add_custom_target(llvm-lit ALL COMMAND cp llvm-project-prefix/src/llvm-project-build/bin/llvm-lit ${TTMLIR_TOOLCHAIN_DIR}/bin/llvm-lit DEPENDS llvm-project)
add_custom_target(run-clang-tidy-install ALL COMMAND cp llvm-project-prefix/src/llvm-project/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py ${TTMLIR_TOOLCHAIN_DIR}/bin/run-clang-tidy.py DEPENDS llvm-project)
14 changes: 14 additions & 0 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ struct TTIRToTTNNBackendPipelineOptions

ListOption<int64_t> meshShape{
*this, "mesh-shape", llvm::cl::desc("Set the multi-device mesh shape.")};

// If this option is true, run the entire graph with graph capture to validate
// it.
//
Option<bool> graphCaptureValidationEnabled{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no test exercising this option — please add one (e.g. a lit test that runs the pipeline with graph-capture-validation-enabled=true).

*this, "graph-capture-validation-enabled",
llvm::cl::desc("Enable TTNN graph validation using graph capture."),
llvm::cl::init(false)};
};

void createTTNNPipelineTTIRPasses(
Expand All @@ -127,6 +135,9 @@ void createTTNNPipelineLayoutDecompositionPass(
void createTTNNPipelineDeallocPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

void createTTNNPipelineValidateGraphCapturePass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
std::string options);

Expand All @@ -142,6 +153,9 @@ void createTTNNPipelineLayoutDecompositionPassFromString(OpPassManager &pm,
void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
std::string options);

void createTTNNPipelineValidateGraphCapturePassFromString(OpPassManager &pm,
std::string options);

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,11 @@ def TTNNDecomposeLayouts: Pass<"ttnn-decompose-layouts", "::mlir::ModuleOp"> {
}];
}

// Pass declaration for TTNN graph-capture validation; the generated base
// class is implemented by TTNNValidateGraphCapture in
// lib/Dialect/TTNN/Transforms/Passes.cpp. Runs on the whole ModuleOp.
def TTNNValidateGraphCapture: Pass<"ttnn-validate-graph-capture", "::mlir::ModuleOp"> {
let summary = "Validate op graph with graph capture.";
let description = [{
This pass validates that the produced TTNN op graph is valid using graph capture.
}];
}

#endif
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/nlohmann_json/include)

add_subdirectory(CAPI)
add_subdirectory(Conversion)
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ void createTTNNPipelineDeallocPass(
pm.addPass(createTTNNDeallocate());
}

// Appends the TTNN graph-capture validation pass to the pipeline. The pass
// itself takes no options; `options` is accepted only so this helper has the
// same shape as the other createTTNNPipeline*Pass functions in this file.
void createTTNNPipelineValidateGraphCapturePass(
    OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
  pm.addPass(createTTNNValidateGraphCapture());
}

void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
Expand Down Expand Up @@ -111,13 +116,24 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

// Parses `options` from its textual form and appends the graph-capture
// validation pass, mirroring the other *FromString helpers in this file.
// NOTE(review): the result of createFromString is dereferenced without a
// validity check — presumably callers always pass a well-formed options
// string; confirm, since a parse failure here would be undefined behavior.
void createTTNNPipelineValidateGraphCapturePassFromString(OpPassManager &pm,
                                                          std::string options) {
  auto optionsStruct =
      TTIRToTTNNBackendPipelineOptions::createFromString(options);
  createTTNNPipelineValidateGraphCapturePass(pm, *optionsStruct);
}

// Assembles the TTIR -> TTNN backend pipeline. Passes are registered in the
// order listed; graph-capture validation, when enabled, is added last so it
// observes the final TTNN IR after layout decomposition and deallocation.
void createTTIRToTTNNBackendPipeline(
    OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
  createTTNNPipelineTTIRPasses(pm, options);
  createTTNNPipelineLoweringPasses(pm, options);
  createTTNNPipelineAnalysisPasses(pm, options);
  createTTNNPipelineLayoutDecompositionPass(pm, options);
  createTTNNPipelineDeallocPass(pm, options);

  // Opt-in: validate the whole graph with graph capture (see
  // graph-capture-validation-enabled pipeline option).
  if (options.graphCaptureValidationEnabled) {
    createTTNNPipelineValidateGraphCapturePass(pm, options);
  }
}

//===----------------------------------------------------------------------===//
Expand Down
86 changes: 78 additions & 8 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,23 @@
#include "mlir/IR/PatternMatch.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include <algorithm>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wcovered-switch-default"
#include "nlohmann/json.hpp"
#pragma clang diagnostic pop

#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iostream>

namespace mlir::tt::ttnn {
#define GEN_PASS_DEF_TTNNDEALLOCATE
#define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS
#define GEN_PASS_DEF_TTNNVALIDATEGRAPHCAPTURE
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"

class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
Expand Down Expand Up @@ -98,6 +111,66 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
}
};

class TTNNValidateGraphCapture
    : public impl::TTNNValidateGraphCaptureBase<TTNNValidateGraphCapture> {

public:
  using impl::TTNNValidateGraphCaptureBase<
      TTNNValidateGraphCapture>::TTNNValidateGraphCaptureBase;

  // Validates the lowered TTNN graph by round-tripping it through the
  // runtime: serialize the module to MLIR, translate it to a TTNN flatbuffer
  // with ttmlir-translate, execute it under `ttrt` graph capture, then
  // inspect the JSON report. Infrastructure failures (file I/O, external
  // tool errors) fail the pass; a failing validation report currently only
  // emits a warning until recompilation handling lands.
  void runOnOperation() final {
    const std::filesystem::path tmpDirPath =
        std::filesystem::temp_directory_path();

    const std::string mlirFilePath = tmpDirPath / "module.mlir";
    const std::string flatBufferFilePath = tmpDirPath / "module.ttnn";
    const std::string outReportPath = tmpDirPath / "module_graph_capture.json";

    // Delete any stale report up front so we never read the artefact of a
    // previous compile session if a later stage fails to produce a new one.
    std::error_code removeEc;
    std::filesystem::remove(outReportPath, removeEc);

    if (!outputTTNNIRFile(mlirFilePath) ||
        !outputFlatBufferFile(mlirFilePath, flatBufferFilePath) ||
        !runGraphCapture(flatBufferFilePath, outReportPath)) {
      // A diagnostic was already emitted by the failing stage.
      signalPassFailure();
      return;
    }

    if (!isValidGraphCaptureReport(outReportPath)) {
      // TODO (nobradovic/odjuricic): Handle recompile.
      getOperation().emitWarning(
          "TTNN graph capture validation reported failing ops");
    }
  }

  // Writes the current module's IR to `mlirFilePath`. Returns false (with a
  // diagnostic) when the file cannot be opened for writing.
  bool outputTTNNIRFile(const std::string &mlirFilePath) {
    ModuleOp module = getOperation();
    std::error_code ec;
    llvm::raw_fd_stream fs(mlirFilePath, ec);
    if (ec) {
      module.emitError("cannot open " + mlirFilePath +
                       " for writing: " + ec.message());
      return false;
    }
    module.print(fs);
    return true;
  }

  // Translates the dumped MLIR into a TTNN flatbuffer by shelling out to
  // ttmlir-translate. Returns false if the tool exits with non-zero status.
  // TODO: replace the system() call with a direct API call
  // (mlir::tt::ttnn::ttnnToFlatbuffer) so failures surface as diagnostics
  // instead of exit codes.
  bool outputFlatBufferFile(const std::string &mlirFilePath,
                            const std::string &flatBufferFilePath) {
    const std::string cmd =
        "./build/bin/ttmlir-translate --ttnn-to-flatbuffer " + mlirFilePath +
        " -o " + flatBufferFilePath;

    if (std::system(cmd.c_str()) != 0) {
      getOperation().emitError("ttmlir-translate failed for " + mlirFilePath);
      return false;
    }
    return true;
  }

  // Runs the flatbuffer under `ttrt` with graph capture enabled, producing a
  // JSON result file. Returns false if ttrt exits with non-zero status.
  bool runGraphCapture(const std::string &flatBufferFilePath,
                       const std::string &outReportFilePath) {
    // TODO(mbezulj): Add required env variable to be able to run graph capture
    // with mockup device and without kernel compilation.
    const std::string cmd = "ttrt run " + flatBufferFilePath +
                            " --use-graph-capture --result-file " +
                            outReportFilePath;
    if (std::system(cmd.c_str()) != 0) {
      getOperation().emitError("ttrt graph capture run failed for " +
                               flatBufferFilePath);
      return false;
    }
    return true;
  }

  // Returns true iff the report exists, parses as JSON, and every entry has
  // result == "pass". A missing or malformed report counts as invalid
  // instead of throwing out of the pass.
  bool isValidGraphCaptureReport(const std::string &outReportPath) {
    std::ifstream reportFile(outReportPath);
    if (!reportFile.is_open()) {
      return false;
    }

    // Parse in non-throwing mode; a malformed report yields a discarded
    // value rather than an exception.
    nlohmann::json jsonData =
        nlohmann::json::parse(reportFile, /*cb=*/nullptr,
                              /*allow_exceptions=*/false);
    if (jsonData.is_discarded()) {
      return false;
    }

    return std::all_of(jsonData.begin(), jsonData.end(), [](auto &jsonElement) {
      return jsonElement["result"] == "pass";
    });
  }
};

class TTNNDecomposeLayouts
: public impl::TTNNDecomposeLayoutsBase<TTNNDecomposeLayouts> {

Expand Down Expand Up @@ -163,14 +236,11 @@ class TTNNDecomposeLayouts

void print() const {
llvm::errs() << "OpsToCreate{ \n"
<< "\t"
<< "CreateToDeviceOp: " << createToDeviceOp << "\n"
<< "\t"
<< "CreateFromDeviceOp: " << createFromDeviceOp << "\n"
<< "\t"
<< "CreateToLayoutOp: " << createToLayoutOp << "\n"
<< "\t"
<< "CreateTypecastOp: " << createTypecastOp << "\n"
<< "\t" << "CreateToDeviceOp: " << createToDeviceOp << "\n"
<< "\t" << "CreateFromDeviceOp: " << createFromDeviceOp
<< "\n"
<< "\t" << "CreateToLayoutOp: " << createToLayoutOp << "\n"
<< "\t" << "CreateTypecastOp: " << createTypecastOp << "\n"
<< "\t"
<< "CreateToMemoryConfigOp: " << createToMemoryConfigOp
<< "\n"
Expand Down
5 changes: 3 additions & 2 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,15 @@ void deallocateBuffers(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
std::vector<Tensor> const &outputs, bool useGraphCapture);

void wait(Event event);

void runProgram(::ttnn::MeshDevice &meshDevice,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs);
std::vector<::ttnn::Tensor *> const &outputs,
bool useGraphCapture);

} // namespace tt::runtime::ttnn

Expand Down
2 changes: 1 addition & 1 deletion runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void closeDevice(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
std::vector<Tensor> const &outputs, bool useGraphCapture = false);

void wait(Event event);

Expand Down
4 changes: 2 additions & 2 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,12 @@ void closeDevice(Device device) {
Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles,
std::vector<Tensor> const &outputHandles) {
std::vector<Tensor> const &outputHandles, bool useGraphCapture) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
outputHandles, useGraphCapture);
}
#endif

Expand Down
70 changes: 64 additions & 6 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
#include "ttnn/graph/graph_processor.hpp"
#include <exception>
#include <memory>
#include <optional>

namespace tt::runtime::ttnn {
using LogType = ::tt::runtime::logger::LogType;
Expand All @@ -41,7 +45,9 @@ class ProgramExecutor {
: context(ProgramContext(liveTensors, programInputs, programOutputs,
meshDevice)) {}

void execute(const ::tt::target::ttnn::Program *program) {
virtual ~ProgramExecutor() = default;

virtual void execute(const ::tt::target::ttnn::Program *program) {
for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
LOG_DEBUG(LogType::LogRuntimeTTNN,
"Executing operation: ", op->debug_info()->c_str());
Expand All @@ -51,12 +57,49 @@ class ProgramExecutor {

ProgramContext &getContext() { return context; }

private:
protected:
ProgramContext context;
void runOperation(const ::tt::target::ttnn::Operation *op);
void runEltwiseOperation(const ::tt::target::ttnn::EltwiseOp *op);
};

class GraphCaptureProgramExecutor : public ProgramExecutor {
public:
  using ProgramExecutor::ProgramExecutor;

  // Executes the program with TTNN graph capture active in NO_DISPATCH mode,
  // so ops are traced/validated rather than dispatched to hardware. An op
  // failure is rethrown as std::runtime_error annotated with the failing
  // op's index and debug info. Graph capture is always terminated — even
  // when an op throws — so the graph processor is never left in capture
  // mode after a failure.
  void execute(const ::tt::target::ttnn::Program *program) override {
    const auto execute_impl = [&]() {
      unsigned int opIndex = 0;
      for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
        LOG_DEBUG(LogType::LogRuntimeTTNN,
                  "Executing operation: ", op->debug_info()->c_str());

        try {
          runOperation(op);
        } catch (const std::exception &ex) {
          // TODO(mbezulj): Replace opIndex with loc attribute of the operation
          // which failed (loc attribute needs to be propagated to the flat
          // buffer).
          std::stringstream ss;
          ss << "Failed on op " << std::to_string(opIndex) << "( "
             << op->debug_info()->c_str() << " ) "
             << " because of: " << ex.what();
          throw std::runtime_error(ss.str());
        }

        ++opIndex;
      }
    };

    ::ttnn::graph::GraphProcessor::begin_graph_capture(
        tt::tt_metal::IGraphProcessor::RunMode::NO_DISPATCH);
    try {
      execute_impl();
    } catch (...) {
      // End capture before propagating, so a failed op cannot leave the
      // graph processor stuck in capture mode (and leak the capture state).
      ::ttnn::graph::GraphProcessor::end_graph_capture();
      throw;
    }
    ::ttnn::graph::GraphProcessor::end_graph_capture();
  }
};

void ProgramExecutor::runEltwiseOperation(
const ::tt::target::ttnn::EltwiseOp *op) {
auto runUnaryOp = [&]() {
Expand Down Expand Up @@ -177,10 +220,25 @@ static bool handleNopProgram(::tt::target::ttnn::Program const *program,
return isNop;
}

std::unique_ptr<ProgramExecutor>
makeProgramExecutor(const TensorMap &liveTensors,
const std::unordered_set<uint32_t> &programInputs,
const std::unordered_set<uint32_t> &programOutputs,
::ttnn::MeshDevice *meshDevice, bool useGraphCapture) {
if (useGraphCapture) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would be something like debug::Env::get().useGraphCapture

return std::make_unique<GraphCaptureProgramExecutor>(
liveTensors, programInputs, programOutputs, meshDevice);
}

return std::make_unique<ProgramExecutor>(liveTensors, programInputs,
programOutputs, meshDevice);
}

void runProgram(::ttnn::MeshDevice &meshDevice,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs) {
std::vector<::ttnn::Tensor *> const &outputs,
bool useGraphCapture) {
if (handleNopProgram(program, inputs, outputs)) {
return;
}
Expand All @@ -206,9 +264,9 @@ void runProgram(::ttnn::MeshDevice &meshDevice,
LOG_ASSERT(inserted, "Duplicate output tensor");
programOutputs.emplace(output->global_id());
}
ProgramExecutor executor(liveTensors, programInputs, programOutputs,
&meshDevice);
executor.execute(program);
std::unique_ptr<ProgramExecutor> executor = makeProgramExecutor(
liveTensors, programInputs, programOutputs, &meshDevice, useGraphCapture);
executor->execute(program);
}

} // namespace tt::runtime::ttnn
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles,
std::vector<Tensor> const &outputHandles) {
std::vector<Tensor> const &outputHandles, bool useGraphCapture) {
::ttnn::MeshDevice &meshDevice =
deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN);
::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle);
Expand All @@ -180,7 +180,7 @@ Event submit(Device deviceHandle, Binary executableHandle,
outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get()));
}
tt::runtime::ttnn::runProgram(meshDevice, fbb.programs()->Get(programIndex),
inputs, outputs);
inputs, outputs, useGraphCapture);
return Event(nullptr, DeviceRuntime::TTNN);
}

Expand Down
Loading
Loading