Add graph capture validation pass #1195
@@ -450,6 +457,7 @@ def _execute(binaries):
program_index,
total_inputs[loop],
total_outputs[loop],
self["--use-graph-capture"],
Can we add this to ttrt.runtime.DebugEnv? This looks like a debug-mode feature, and the infrastructure for adding specific bools like this one is already there:
debug_env = ttrt.runtime.DebugEnv.get(
self["--load-kernels-from-disk"], self["--enable-async-ttnn"]
)
const std::unordered_set<uint32_t> &programInputs,
const std::unordered_set<uint32_t> &programOutputs,
::ttnn::MeshDevice *meshDevice, bool useGraphCapture) {
if (useGraphCapture) {
this would be something like debug::Env::get().useGraphCapture
changes inlined
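To make the suggestion concrete, here is a minimal sketch of a process-wide debug environment that carries the flag, so the runtime can query it instead of threading a `bool` through every signature. The `sketch::debug` namespace, the member names, and `executionMode` are assumptions for illustration only; they are not the actual tt-mlir definitions.

```cpp
#include <string>

namespace sketch::debug {

// Hypothetical stand-in for a debug environment; not the tt-mlir type.
struct Env {
  // The first call fixes the flag values for the lifetime of the process;
  // later calls ignore their arguments and return the same instance.
  static const Env &get(bool loadKernelsFromDisk = false,
                        bool enableAsyncTTNN = false,
                        bool useGraphCapture = false) {
    static const Env instance(loadKernelsFromDisk, enableAsyncTTNN,
                              useGraphCapture);
    return instance;
  }

  const bool loadKernelsFromDisk;
  const bool enableAsyncTTNN;
  const bool useGraphCapture;

private:
  Env(bool loadKernelsFromDisk, bool enableAsyncTTNN, bool useGraphCapture)
      : loadKernelsFromDisk(loadKernelsFromDisk),
        enableAsyncTTNN(enableAsyncTTNN), useGraphCapture(useGraphCapture) {}
};

// Call sites read the flag from the environment rather than from a parameter.
inline std::string executionMode() {
  return Env::get().useGraphCapture ? "graph-capture" : "normal";
}

} // namespace sketch::debug
```

With this shape, the front end would call `Env::get(...)` once with the parsed command-line values, and everything downstream would call `Env::get()` with no arguments.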
// If this option is true, run the entire graph with graph capture to validate
// it.
//
Option<bool> graphCaptureValidationEnabled{
You lack a test using this option.
void outputTTNNIRFile(const std::string &mlirFilePath) {
ModuleOp module = getOperation();
std::error_code _ec;
What happens in case of failure?
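One way to answer this is to have the file-writing step report its error to the caller instead of discarding it. The sketch below assumes a plain `std::ofstream` rather than whatever stream type the pass actually uses, and `writeTextFile` is a hypothetical helper name.

```cpp
#include <filesystem>
#include <fstream>
#include <string>
#include <system_error>

// Hypothetical helper: writes `text` to `path` and returns an error code
// instead of silently ignoring a failed open or a failed write.
std::error_code writeTextFile(const std::filesystem::path &path,
                              const std::string &text) {
  std::ofstream out(path, std::ios::trunc);
  if (!out) {
    return std::make_error_code(std::errc::io_error);
  }
  out << text;
  out.flush();
  if (!out) {
    return std::make_error_code(std::errc::io_error);
  }
  return {}; // A default-constructed error_code means success.
}
```

The caller can then stop the pass with a diagnostic when the returned code is set, rather than continuing with a missing or partial file.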
const std::string outReportPath = tmpDirPath / "module_graph_capture.json";
outputTTNNIRFile(mlirFilePath);
outputFlatBufferFile(mlirFilePath, flatBufferFilePath);
How are failures in each stage handled? How do you know you are not reading an artefact from a previous compile session?
const std::string &outReportFilePath) {
// TODO(mbezulj): Add required env variable to be able to run graph capture
// with mockup device and without kernel compilation.
const std::string cmd = "ttrt run " + flatBufferFilePath +
Do we have something like std::format that lets us plug parameters into a constant string?
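For reference, C++20 provides `std::format`, and LLVM provides `llvm::formatv`. Where neither is available, the idea can be sketched with a small placeholder-substitution helper. `fillPlaceholders` below is a hypothetical function written for illustration, not an existing API.

```cpp
#include <cstddef>
#include <initializer_list>
#include <string>
#include <string_view>

// Hypothetical helper: replaces each "{}" in `pattern` with the next
// argument, left to right. Placeholders without a matching argument are
// copied through unchanged.
std::string fillPlaceholders(std::string_view pattern,
                             std::initializer_list<std::string_view> args) {
  std::string result;
  auto next = args.begin();
  std::size_t pos = 0;
  while (pos < pattern.size()) {
    const std::size_t hit = pattern.find("{}", pos);
    if (hit == std::string_view::npos || next == args.end()) {
      result.append(pattern.substr(pos));
      break;
    }
    result.append(pattern.substr(pos, hit - pos));
    result.append(*next++);
    pos = hit + 2;
  }
  return result;
}
```

A call such as `fillPlaceholders("./build/bin/ttmlir-translate --ttnn-to-flatbuffer {} -o {}", {mlirFilePath, flatBufferFilePath})` keeps the command template in one constant string.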
outputFlatBufferFile(mlirFilePath, flatBufferFilePath);
runGraphCapture(flatBufferFilePath, outReportPath);
if (!isValidGraphCaptureReport(outReportPath)) {
isValidGraphCaptureReport needs to return more than a bool. It should return a mapping from each loc <-> op type pair to an exception type.
Then handlers need to be written for every exception type.
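The two comments above could be sketched as a report type that records, per failing op, its location, its op type, and a failure category, plus a dispatch point where each category gets its own handler. Everything here is hypothetical: the `FailureKind` values, the struct names, and the handler messages are placeholders, not categories the graph capture report is known to contain.

```cpp
#include <string>
#include <vector>

// Hypothetical failure categories; the real set would come from the report.
enum class FailureKind { OutOfMemory, UnsupportedOp, Unknown };

// One entry per failing op: where it is, what it is, and why it failed.
struct OpFailure {
  std::string loc;
  std::string opName;
  FailureKind kind;
};

// Richer replacement for a bare bool: empty means the graph validated.
struct ValidationReport {
  std::vector<OpFailure> failures;
  bool ok() const { return failures.empty(); }
};

// Dispatch point: each failure kind gets its own handling branch.
std::string handleFailure(const OpFailure &failure) {
  const std::string where = failure.loc + ": " + failure.opName;
  switch (failure.kind) {
  case FailureKind::OutOfMemory:
    return where + " ran out of memory";
  case FailureKind::UnsupportedOp:
    return where + " is not supported";
  case FailureKind::Unknown:
    break;
  }
  return where + " failed for an unclassified reason";
}
```

Here the handlers only build a message; in a pass they would be the place to emit a diagnostic on the op or to attempt a recovery specific to that category.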
try {
runOperation(op);
} catch (const std::exception &ex) {
Can this leave the device in a bad state? Can it leak memory?
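A common way to address both concerns is to tie cleanup to scope exit, so it runs whether the operation returns normally or throws. The sketch below uses a hypothetical `FakeDevice` with two pieces of state to stand in for the real device; `ScopeExit` and `runGuarded` are illustrative names, and nothing here reflects how the runtime actually manages device state.

```cpp
#include <functional>
#include <stdexcept>
#include <string>
#include <utility>

// Runs the stored callback when the enclosing scope is left, on any path.
class ScopeExit {
public:
  explicit ScopeExit(std::function<void()> fn) : fn_(std::move(fn)) {}
  ~ScopeExit() {
    if (fn_) {
      fn_();
    }
  }
  ScopeExit(const ScopeExit &) = delete;
  ScopeExit &operator=(const ScopeExit &) = delete;

private:
  std::function<void()> fn_;
};

// Hypothetical stand-in for device state touched while running an op.
struct FakeDevice {
  bool capturing = false;
  int liveBuffers = 0;
};

// Runs the operation and restores the device state on success, on a caught
// std::exception, and on any other exception that propagates out.
bool runGuarded(FakeDevice &device, const std::function<void()> &runOperation,
                std::string &error) {
  device.capturing = true;
  ++device.liveBuffers;
  ScopeExit cleanup([&device] {
    device.capturing = false;
    --device.liveBuffers;
  });
  try {
    runOperation();
  } catch (const std::exception &ex) {
    error = ex.what();
    return false;
  }
  return true;
}
```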
"./build/bin/ttmlir-translate --ttnn-to-flatbuffer " + mlirFilePath +
" -o " + flatBufferFilePath;
system(cmd.c_str());
We should look to replace system calls with proper API calls, like Forge does
// Generate binary from the MLIR module.
auto binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());
Thanks everyone for the great feedback on the implementation! I'll close the PR per the explanation I wrote on issue #1183.
Implement a mlir pass which performs graph capture validation of output TTNN graph. By default, this pass is disabled. #1183
fyi @mbezuljTT @s-jovic