-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add graph capture validation pass #1195
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,10 +8,23 @@ | |
#include "mlir/IR/PatternMatch.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" | ||
#include "ttmlir/Dialect/TTNN/Utils/Utils.h" | ||
#include <algorithm> | ||
|
||
#pragma clang diagnostic push | ||
#pragma clang diagnostic ignored "-Wcovered-switch-default" | ||
#include "nlohmann/json.hpp" | ||
#pragma clang diagnostic pop | ||
|
||
#include <cstdio> | ||
#include <filesystem> | ||
#include <fstream> | ||
|
||
#include <iostream> | ||
|
||
namespace mlir::tt::ttnn { | ||
#define GEN_PASS_DEF_TTNNDEALLOCATE | ||
#define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS | ||
#define GEN_PASS_DEF_TTNNVALIDATEGRAPHCAPTURE | ||
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc" | ||
|
||
class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> { | ||
|
@@ -98,6 +111,66 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> { | |
} | ||
}; | ||
|
||
class TTNNValidateGraphCapture | ||
: public impl::TTNNValidateGraphCaptureBase<TTNNValidateGraphCapture> { | ||
|
||
public: | ||
using impl::TTNNValidateGraphCaptureBase< | ||
TTNNValidateGraphCapture>::TTNNValidateGraphCaptureBase; | ||
|
||
void runOnOperation() final { | ||
const std::filesystem::path tmpDirPath = | ||
std::filesystem::temp_directory_path(); | ||
|
||
const std::string mlirFilePath = tmpDirPath / "module.mlir"; | ||
const std::string flatBufferFilePath = tmpDirPath / "module.ttnn"; | ||
const std::string outReportPath = tmpDirPath / "module_graph_capture.json"; | ||
|
||
outputTTNNIRFile(mlirFilePath); | ||
outputFlatBufferFile(mlirFilePath, flatBufferFilePath); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How are failures in each stage handled, how do you know you are not reading artefact of some previous compile session? |
||
runGraphCapture(flatBufferFilePath, outReportPath); | ||
|
||
if (!isValidGraphCaptureReport(outReportPath)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IsValid needs to return more than bool. It needs to be loc<->op type mapping to -> exception type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then handlers need to be written for every exception type. |
||
// TODO (nobradovic/odjuricic): Handle recompile. | ||
} | ||
} | ||
|
||
void outputTTNNIRFile(const std::string &mlirFilePath) { | ||
ModuleOp module = getOperation(); | ||
std::error_code _ec; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens in case of failure? |
||
auto fs = llvm::raw_fd_stream(mlirFilePath, _ec); | ||
module.print(fs); | ||
} | ||
|
||
void outputFlatBufferFile(const std::string &mlirFilePath, | ||
const std::string &flatBufferFilePath) { | ||
const std::string cmd = | ||
"./build/bin/ttmlir-translate --ttnn-to-flatbuffer " + mlirFilePath + | ||
" -o " + flatBufferFilePath; | ||
|
||
system(cmd.c_str()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should look to replace system calls with proper API calls, like Forge does |
||
} | ||
|
||
void runGraphCapture(const std::string &flatBufferFilePath, | ||
const std::string &outReportFilePath) { | ||
// TODO(mbezulj): Add required env variable to be able to run graph capture | ||
// with mockup device and without kernel compilation. | ||
const std::string cmd = "ttrt run " + flatBufferFilePath + | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have something like std::format to allow plugin of parameters within const string? |
||
" --use-graph-capture --result-file " + | ||
outReportFilePath; | ||
system(cmd.c_str()); | ||
} | ||
|
||
bool isValidGraphCaptureReport(const std::string &outReportPath) { | ||
std::ifstream reportFile(outReportPath); | ||
nlohmann::json jsonData = nlohmann::json::parse(reportFile); | ||
|
||
return std::all_of(jsonData.begin(), jsonData.end(), [](auto &jsonElement) { | ||
return jsonElement["result"] == "pass"; | ||
}); | ||
} | ||
}; | ||
|
||
class TTNNDecomposeLayouts | ||
: public impl::TTNNDecomposeLayoutsBase<TTNNDecomposeLayouts> { | ||
|
||
|
@@ -163,14 +236,11 @@ class TTNNDecomposeLayouts | |
|
||
void print() const { | ||
llvm::errs() << "OpsToCreate{ \n" | ||
<< "\t" | ||
<< "CreateToDeviceOp: " << createToDeviceOp << "\n" | ||
<< "\t" | ||
<< "CreateFromDeviceOp: " << createFromDeviceOp << "\n" | ||
<< "\t" | ||
<< "CreateToLayoutOp: " << createToLayoutOp << "\n" | ||
<< "\t" | ||
<< "CreateTypecastOp: " << createTypecastOp << "\n" | ||
<< "\t" << "CreateToDeviceOp: " << createToDeviceOp << "\n" | ||
<< "\t" << "CreateFromDeviceOp: " << createFromDeviceOp | ||
<< "\n" | ||
<< "\t" << "CreateToLayoutOp: " << createToLayoutOp << "\n" | ||
<< "\t" << "CreateTypecastOp: " << createTypecastOp << "\n" | ||
<< "\t" | ||
<< "CreateToMemoryConfigOp: " << createToMemoryConfigOp | ||
<< "\n" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,10 @@ | |
#include "tt/runtime/detail/logger.h" | ||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
#include "ttnn/graph/graph_processor.hpp" | ||
#include <exception> | ||
#include <memory> | ||
#include <optional> | ||
|
||
namespace tt::runtime::ttnn { | ||
using LogType = ::tt::runtime::logger::LogType; | ||
|
@@ -41,7 +45,9 @@ class ProgramExecutor { | |
: context(ProgramContext(liveTensors, programInputs, programOutputs, | ||
meshDevice)) {} | ||
|
||
void execute(const ::tt::target::ttnn::Program *program) { | ||
virtual ~ProgramExecutor() = default; | ||
|
||
virtual void execute(const ::tt::target::ttnn::Program *program) { | ||
for (const ::tt::target::ttnn::Operation *op : *program->operations()) { | ||
LOG_DEBUG(LogType::LogRuntimeTTNN, | ||
"Executing operation: ", op->debug_info()->c_str()); | ||
|
@@ -51,12 +57,49 @@ class ProgramExecutor { | |
|
||
ProgramContext &getContext() { return context; } | ||
|
||
private: | ||
protected: | ||
ProgramContext context; | ||
void runOperation(const ::tt::target::ttnn::Operation *op); | ||
void runEltwiseOperation(const ::tt::target::ttnn::EltwiseOp *op); | ||
}; | ||
|
||
class GraphCaptureProgramExecutor : public ProgramExecutor { | ||
public: | ||
using ProgramExecutor::ProgramExecutor; | ||
|
||
void execute(const ::tt::target::ttnn::Program *program) override { | ||
const auto execute_impl = [&]() { | ||
unsigned int opIndex = 0; | ||
for (const ::tt::target::ttnn::Operation *op : *program->operations()) { | ||
LOG_DEBUG(LogType::LogRuntimeTTNN, | ||
"Executing operation: ", op->debug_info()->c_str()); | ||
|
||
try { | ||
runOperation(op); | ||
} catch (const std::exception &ex) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this leave device in a bad state? Can it leak memory? |
||
// TODO(mbezulj): Replace opIndex with loc attribute of the operation | ||
// which failed (loc attribute needs to be propagated to the flat | ||
// buffer). | ||
std::stringstream ss; | ||
ss << "Failed on op " << std::to_string(opIndex) << "( " | ||
<< op->debug_info()->c_str() << " ) " | ||
<< " because of: " << ex.what(); | ||
throw std::runtime_error(ss.str()); | ||
} | ||
|
||
++opIndex; | ||
} | ||
|
||
return std::nullopt; | ||
}; | ||
|
||
::ttnn::graph::GraphProcessor::begin_graph_capture( | ||
tt::tt_metal::IGraphProcessor::RunMode::NO_DISPATCH); | ||
execute_impl(); | ||
::ttnn::graph::GraphProcessor::GraphProcessor::end_graph_capture(); | ||
} | ||
}; | ||
|
||
void ProgramExecutor::runEltwiseOperation( | ||
const ::tt::target::ttnn::EltwiseOp *op) { | ||
auto runUnaryOp = [&]() { | ||
|
@@ -177,10 +220,25 @@ static bool handleNopProgram(::tt::target::ttnn::Program const *program, | |
return isNop; | ||
} | ||
|
||
std::unique_ptr<ProgramExecutor> | ||
makeProgramExecutor(const TensorMap &liveTensors, | ||
const std::unordered_set<uint32_t> &programInputs, | ||
const std::unordered_set<uint32_t> &programOutputs, | ||
::ttnn::MeshDevice *meshDevice, bool useGraphCapture) { | ||
if (useGraphCapture) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this would be something like debug::Env::get().useGraphCapture |
||
return std::make_unique<GraphCaptureProgramExecutor>( | ||
liveTensors, programInputs, programOutputs, meshDevice); | ||
} | ||
|
||
return std::make_unique<ProgramExecutor>(liveTensors, programInputs, | ||
programOutputs, meshDevice); | ||
} | ||
|
||
void runProgram(::ttnn::MeshDevice &meshDevice, | ||
::tt::target::ttnn::Program const *program, | ||
std::vector<::ttnn::Tensor *> const &inputs, | ||
std::vector<::ttnn::Tensor *> const &outputs) { | ||
std::vector<::ttnn::Tensor *> const &outputs, | ||
bool useGraphCapture) { | ||
if (handleNopProgram(program, inputs, outputs)) { | ||
return; | ||
} | ||
|
@@ -206,9 +264,9 @@ void runProgram(::ttnn::MeshDevice &meshDevice, | |
LOG_ASSERT(inserted, "Duplicate output tensor"); | ||
programOutputs.emplace(output->global_id()); | ||
} | ||
ProgramExecutor executor(liveTensors, programInputs, programOutputs, | ||
&meshDevice); | ||
executor.execute(program); | ||
std::unique_ptr<ProgramExecutor> executor = makeProgramExecutor( | ||
liveTensors, programInputs, programOutputs, &meshDevice, useGraphCapture); | ||
executor->execute(program); | ||
} | ||
|
||
} // namespace tt::runtime::ttnn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You lack a test using this option.