Add graph capture validation pass #1195

Closed
wants to merge 2 commits into from

@derdeljanTT derdeljanTT commented Nov 7, 2024

Implement an MLIR pass that performs graph capture validation of the output TTNN graph. The pass is disabled by default. #1183

fyi @mbezuljTT @s-jovic

@@ -450,6 +457,7 @@ def _execute(binaries):
program_index,
total_inputs[loop],
total_outputs[loop],
self["--use-graph-capture"],

Can we add this to ttrt.runtime.DebugEnv? This looks like a debug-mode feature, and the infrastructure for adding specific bools like this one is already there:

debug_env = ttrt.runtime.DebugEnv.get(
self["--load-kernels-from-disk"], self["--enable-async-ttnn"]
)

const std::unordered_set<uint32_t> &programInputs,
const std::unordered_set<uint32_t> &programOutputs,
::ttnn::MeshDevice *meshDevice, bool useGraphCapture) {
if (useGraphCapture) {

this would be something like debug::Env::get().useGraphCapture
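
A minimal sketch of that shape, assuming a process-wide debug environment object. The `Env` struct below is a simplified stand-in, not the actual runtime class, and `useGraphCapture` and `shouldCaptureGraph` are hypothetical names used only for illustration:

```cpp
namespace debug {

// Simplified stand-in for a runtime debug environment (hypothetical layout).
struct Env {
  // The first call fixes the flags; later calls return the same instance and
  // ignore their arguments. That policy is an assumption of this sketch.
  static const Env &get(bool loadKernelsFromDisk = false,
                        bool enableAsyncTTNN = false,
                        bool useGraphCapture = false) {
    static const Env instance(loadKernelsFromDisk, enableAsyncTTNN,
                              useGraphCapture);
    return instance;
  }

  const bool loadKernelsFromDisk;
  const bool enableAsyncTTNN;
  const bool useGraphCapture; // hypothetical new flag

private:
  Env(bool loadKernelsFromDisk, bool enableAsyncTTNN, bool useGraphCapture)
      : loadKernelsFromDisk(loadKernelsFromDisk),
        enableAsyncTTNN(enableAsyncTTNN), useGraphCapture(useGraphCapture) {}
};

} // namespace debug

// Call sites read the flag instead of threading an extra bool parameter.
bool shouldCaptureGraph() { return debug::Env::get().useGraphCapture; }
```

With this layout the executor signature stays unchanged, and the flag is set once wherever the debug environment is first created.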

@tapspatel tapspatel left a comment

changes inlined

@mbezuljTT mbezuljTT self-assigned this Nov 12, 2024
// If this option is true, run the entire graph with graph capture to validate
// it.
//
Option<bool> graphCaptureValidationEnabled{

You lack a test using this option.


void outputTTNNIRFile(const std::string &mlirFilePath) {
ModuleOp module = getOperation();
std::error_code _ec;

What happens in case of failure?
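
One way to make the failure visible is to return the `std::error_code` instead of dropping it. The sketch below uses only the standard library rather than the LLVM stream classes the pass presumably uses, and `writeTextFile` is a hypothetical helper name:

```cpp
#include <filesystem>
#include <fstream>
#include <string>
#include <system_error>

// Writes `contents` to `path`. Returns a default (zero) error_code on success
// and a non-zero one on failure, so the caller has to decide what to do.
std::error_code writeTextFile(const std::filesystem::path &path,
                              const std::string &contents) {
  std::error_code ec;
  if (path.has_parent_path()) {
    // An already existing directory is not reported as an error.
    std::filesystem::create_directories(path.parent_path(), ec);
    if (ec) {
      return ec;
    }
  }
  std::ofstream out(path, std::ios::trunc);
  if (!out) {
    return std::make_error_code(std::errc::io_error);
  }
  out << contents;
  out.flush();
  if (!out) {
    return std::make_error_code(std::errc::io_error);
  }
  return {};
}
```

A caller inside a pass could then emit a diagnostic and signal pass failure when the returned code is non-zero, rather than continuing with a missing file.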

const std::string outReportPath = tmpDirPath / "module_graph_capture.json";

outputTTNNIRFile(mlirFilePath);
outputFlatBufferFile(mlirFilePath, flatBufferFilePath);

How are failures in each stage handled? How do you know you are not reading an artifact of a previous compile session?

const std::string &outReportFilePath) {
// TODO(mbezulj): Add required env variable to be able to run graph capture
// with mockup device and without kernel compilation.
const std::string cmd = "ttrt run " + flatBufferFilePath +

Do we have something like std::format that lets us plug parameters into a constant string?

outputFlatBufferFile(mlirFilePath, flatBufferFilePath);
runGraphCapture(flatBufferFilePath, outReportPath);

if (!isValidGraphCaptureReport(outReportPath)) {

IsValid needs to return more than a bool. It needs to return a mapping from loc <-> op type to exception type.

Then handlers need to be written for every exception type.
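
A rough sketch of what such a result type and handler dispatch could look like. Every name here (`FailureKind`, `OpFailure`, `dispatch`) and the set of failure kinds are hypothetical and only illustrate the shape:

```cpp
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical failure categories; the real set would come from the runtime.
enum class FailureKind { OutOfMemory, UnsupportedLayout, Unknown };

// One failed op from the capture report: where it is and how it failed.
struct OpFailure {
  std::string loc;    // location string of the op in the MLIR module
  std::string opName; // op type, e.g. "ttnn.matmul"
  FailureKind kind;
};

// An empty list means the graph validated cleanly.
using ValidationResult = std::vector<OpFailure>;
using Handler = std::function<void(const OpFailure &)>;

// Calls the handler registered for each failure kind and returns the failures
// that no handler claimed, so the caller can still fail the pass on those.
std::vector<OpFailure> dispatch(const ValidationResult &result,
                                const std::map<FailureKind, Handler> &handlers) {
  std::vector<OpFailure> unhandled;
  for (const OpFailure &failure : result) {
    const auto it = handlers.find(failure.kind);
    if (it == handlers.end()) {
      unhandled.push_back(failure);
    } else {
      it->second(failure);
    }
  }
  return unhandled;
}
```

Returning the unhandled failures keeps the "valid or not" answer available (an empty list) while leaving room to add handlers one exception type at a time.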


try {
runOperation(op);
} catch (const std::exception &ex) {

Can this leave the device in a bad state? Can it leak memory?

"./build/bin/ttmlir-translate --ttnn-to-flatbuffer " + mlirFilePath +
" -o " + flatBufferFilePath;

system(cmd.c_str());

We should look to replace system calls with proper API calls, like Forge does:
// Generate binary from the MLIR module.
auto binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());

@mbezuljTT

Thanks, everyone, for the great feedback on the implementation! I'll close this PR per the explanation I wrote on issue #1183.
I think we have a good plan to move forward, but it will take us some time to get there.

@mbezuljTT mbezuljTT closed this Nov 20, 2024