fix: Replace EliminateExceptions lowering pass #1859
Conversation
That seems to help in certain cases in which the compilation was hanging; however, I am now also seeing some … error. Looking at #1282, it seems that adding … My understanding of the interactions of these different pieces is limited, but I am wondering whether removing the call to …

@gcuendet - Thank you for the follow-up. We are looking into the … Do you have any examples of models you are willing to share which fail with the Uninitialized error when the …?

Thanks for the answer @gs-olive. I am trying to simplify our failing model to something minimal that I can share, but I am not sure I'll be able to 😞. If I do I'll share, of course!
Hi @gcuendet - I found the issue in …

Hi @gs-olive. Thanks for the thorough investigation and the nice finding. I tried the following: …

More specifically, I get the following (truncated) output: output

For reference, here is also the full lowered graph: Lowered graph

From my limited understanding, and looking at it naively, it seems that the top-level node … And that's probably where a proper solution for #1842 would be useful, i.e. if that top-level …

Thank you for the detailed response! I agree with your statement and assessment that a proper solution for #1842 would address this issue. Another alternative is to support nodes such as …

Do you plan to open a PR in pytorch with that fix?

I intend to open an issue with PyTorch first, to assess whether the proposed fix is a viable solution to the issue and whether the intended behavior of that function is retained; then I will open a PR for it. Update: Filed pytorch/pytorch#100730
Hi @gcuendet - I've made some updates to this lowering pass and added a few test cases. I have also verified that for the case in #1823, the model no longer halts mid-compilation (and compiled successfully on my machine, using this PR). Could you please verify on the model with the new updates as well?
if (certainlyThrows(true_block)) {
  n->input(0)->replaceAllUsesDominatedByNodeWith(n, false_const);
} else if (certainlyThrows(false_block)) {
  n->input(0)->replaceAllUsesDominatedByNodeWith(n, true_const);
This does not make sense.

If you check what's actually going on for prim::If here, you will find that:
- every if node has some blocks and attributes
- for an if node, IfNode->input(0) is actually the condition

So, in PyTorch's pass implementation here, what happens is that the pass goes through the two blocks of a prim::If; if one of the blocks contains a prim::RaiseException node, it RESETS the prim::If input (the condition) to force execution down the other branch. For example, if the original condition is True, execution should go to the first block (the true block); but since the first block contains a prim::RaiseException node, the pass rewrites the condition to False to make sure execution won't go to the first block, avoiding the exception.
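To illustrate the intended rewrite, a minimal sketch on a hypothetical graph (invented for illustration, not taken from this PR):

# Before: the true block certainly throws
%cond : bool = aten::eq(%a, %b)
%out : Tensor = prim::If(%cond)
  block0():
    = prim::RaiseException(%msg)
    -> (%x)
  block1():
    -> (%y)

# Intended after: the condition is pinned to the constant False,
# so the exception-free block always runs
%false : bool = prim::Constant[value=0]()
%out : Tensor = prim::If(%false)
  block0():
    = prim::RaiseException(%msg)
    -> (%x)
  block1():
    -> (%y)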
Moreover, if you go to the definition of the function replaceAllUsesDominatedByNodeWith here, you will find that it replaces the uses of a value that are dominated by the given node. In graph theory, "A dominates B" means that execution must flow through A to reach B; for example, "node A dominates value x" means that in the graph, flow must pass from node A to value x. So the line n->input(0)->replaceAllUsesDominatedByNodeWith(n, true_const); replaces the uses of the prim::If condition value that are dominated by that prim::If node. However, the prim::If condition value dominates the prim::If node, not the other way around. In other words, this replaces nothing, because n->input(0) is never dominated by n (the if node's condition flows into the if node).

So this line is basically doing nothing here.

Above is my understanding, please correct me if I'm wrong.
Hi @bowang007 - thank you for the detailed review and comment. I agree with your observation that replacing all uses of the condition dominated by n effectively does nothing here. I've run and added some additional tests to the PR, and I have now updated the call to replaceFirstUseWith, which I believe is the correct replacement in this case.
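A minimal sketch of the updated replacement (assuming the same surrounding loop and certainlyThrows helper as in the snippet quoted above; not the verbatim diff):

if (certainlyThrows(true_block)) {
  // Replace only the first recorded use of the condition, which is intended
  // to be this prim::If's predicate (see the discussion below)
  n->input(0)->replaceFirstUseWith(false_const);
} else if (certainlyThrows(false_block)) {
  n->input(0)->replaceFirstUseWith(true_const);
}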
Based on your comments, I now think the reason the original halting bug in #1823 was occurring is the following:
%out1.1 : Tensor = prim::If(%71)
  block0():
    ...
    -> (%73)
  block1():
    ...
    block0():
      ...
      -> ()
    block1():
      -> ()
    ...
    = prim::RaiseException(%76, %59)
    -> (%63)
%out2.1 : Tensor = prim::If(%71)
  block0():
    ...
    -> (%82)
  block1():
    ...
In the graph above, the condition %71 appears as the predicate of two distinct prim::If nodes. When we call replaceAllUsesWith, as torch::jit::passes::EliminateExceptions does, this replaces downstream prim::If users of the same condition, and seems to cause an infinite loop here. Replacing only the first use of the condition, however, ensures that only the condition of the prim::If being directly analyzed is replaced, avoiding this loop. What do you think of the new changes?

I've verified the new implementation is functionally equivalent to torch::jit::passes::EliminateExceptions on the samples in the test cases, and does not halt for #1823.
Hey @gs-olive, thanks for your detailed reply. Several items I want to discuss with you and @narendasan:

- I'm not so sure we should have this pass at all. The explanation for this pass here also mentions: "This pass is illegal in general case as the modified graph might not throw an exception that the original graph would throw." Generally speaking, this pass cleans up all prim::If branches that throw an exception, so it looks like it could produce unexpected results if users pass in invalid input.
- This exceptionElimination pass was introduced to fix the uninitialized error when we were dealing with customer bugs (see this PR). Since we have a better fix for that uninitialized type bug (Fix: fix the bug that uninitialized tensor cannot be found #1933), maybe it's fine to remove this pass.
- Replacing only the first use of the condition seems a little bit hacky, since we never know where that condition value is first used. For example, for this graph:

  %66 : bool = aten::ne(%48, %dim.2)
  %67 : some operation consuming %66
   = prim::If(%66)
    block0:
      ...
    block1:
      ...

  it's not actually doing what we want.
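To make that concern concrete, a small sketch (the use ordering here is hypothetical) of how replaceFirstUseWith could pick the wrong use:

// Hypothetical use list for %66, in the order the uses were recorded:
//   use 0: %67 = some_op(%66)   <-- replaceFirstUseWith rewrites this use
//   use 1:  = prim::If(%66)     <-- the condition the pass actually targets
n->input(0)->replaceFirstUseWith(true_const); // may never touch the If's condition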
Not too much to add here, the discussion seems accurate. Here is the official documentation on block and conditional semantics in PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/OVERVIEW.md#block
@@ -169,3 +173,189 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
  }
  EXPECT_EQ(1, if_count);
}

TEST(LoweringPasses, EliminateExceptionsSafeIfBlock) {
For these test cases, what's the difference between the graph before and after that pass?
The difference is that the graph will be collapsed, since the exception will be removed from the graph. This is the same as the graph effect of the torch::jit::EliminateExceptions lowering pass. For instance, the change would be:

Original Graph:

graph(%x, %y):
  %dim : int = aten::dim(%x)
  %48 : int = prim::Constant[value=2]()
  %66 : bool = aten::eq(%48, %dim)
  %45 : str = prim::Constant[value="EXCEPTION"]()
  %4 : Tensor = prim::If(%66)
    block0():
      = prim::RaiseException(%45)
      -> (%x)
    block1():
      %res = aten::mul(%x, %y)
      -> (%res)
  return (%4)

New Graph:

graph(%x : Tensor,
      %y : Tensor):
  %6 : Tensor = aten::mul(%x, %y)
  return (%6)
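For reference, a rough sketch of how a test could exercise this collapse (the structure follows the test names visible in the diff, but the body, the pass's namespace, and its exact signature are assumptions, not verbatim from the PR):

#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

TEST(LoweringPasses, EliminateExceptionsSafeIfBlockSketch) {
  // IR mirroring the "Original Graph" above
  std::string source_graph = R"IR(
    graph(%x, %y):
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          = prim::RaiseException(%45)
          -> (%x)
        block1():
          %res = aten::mul(%x, %y)
          -> (%res)
      return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, g.get());

  // Run the new pass (pass name taken from this PR's description; namespace assumed)
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);

  // Neither the conditional nor the exception branch should remain
  for (auto n : g->nodes()) {
    ASSERT_TRUE(n->kind() != c10::prim::If);
    ASSERT_TRUE(n->kind() != c10::prim::RaiseException);
  }
}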
Hi @bowang007 - thanks for the comment, this makes sense.
Block* true_block = n->blocks()[0];
Block* false_block = n->blocks()[1];
bool removed_exception = false;
Value* input_value_replacement;

// If the block throws an exception, replace input with logical opposite
if (certainlyThrows(true_block)) {
  removed_exception = true;
  input_value_replacement = false_const;
} else if (certainlyThrows(false_block)) {
  removed_exception = true;
  input_value_replacement = true_const;
}

// Log node and perform input replacement
if (removed_exception) {
  LOG_WARNING("Detected and removing exception in TorchScript IR for node: " << util::node_info(n));
  n->insertInput(0, input_value_replacement);
  n->removeInput(1);
Directly modifies inputs to the prim::If node only, to avoid incorrect boolean modifications elsewhere. Logs a warning informing the user that an exception was automatically removed from the TorchScript IR.
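Sketched on a shape like the earlier two-If example (values hypothetical), the net effect of the direct input replacement is:

# Before: %71 is the predicate of two distinct prim::If nodes
%out1.1 : Tensor = prim::If(%71)   # true block certainly throws
  ...
%out2.1 : Tensor = prim::If(%71)
  ...

# After: only the analyzed node's condition input is swapped for a constant
%false : bool = prim::Constant[value=0]()
%out1.1 : Tensor = prim::If(%false)
  ...
%out2.1 : Tensor = prim::If(%71)   # other users of %71 are untouched
  ...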
Have we done testing on what happens if we just remove the pass? IIRC, we treat exceptions as compile-time errors through the evaluator; the original purpose was to have fewer no-ops in the graph and remove unnecessary conditionals.
I haven't done specific testing (aside from CI) on removing the pass, but I would expect a performance degradation on models which have conditional exception logic, since the …
- Remove Torch `EliminateExceptions` lowering pass
- Add new `EliminateExceptionsSafe` lowering pass, which has functionally the same task as that of `EliminateExceptions`, but with a safer replacement scheme
- Update EliminateExceptions to use direct node input replacement instead of `replaceAllUsesWith` to avoid an issue with invalid IR causing program halting
- Add testing for new lowering pass
LGTM
Description

Torch-TensorRT has evaluator support for exception elimination, and this particular lowering pass causes certain models to halt compilation indefinitely.

- Modify this lowering pass to use a safer node-replacement system via direct node input replacement
- Add new `EliminateExceptionsSafe` lowering pass, which has functionally the same task as that of `EliminateExceptions`, but with a safer replacement scheme
- Update EliminateExceptions to use direct node input replacement instead of `replaceAllUsesWith` to avoid an issue with invalid IR causing program halting
- Add testing for new lowering pass

The `torch::jit::EliminateExceptions` lowering pass is listed as "risky" in the documentation and, as per #1823, can also cause compilation to halt on certain graphs. We should remove this pass while awaiting a more robust solution through the implementation of #1842 (giving the user the option to eliminate exceptions more effectively in the graph via a compile-time flag). Torch-TRT already has an evaluator for `prim::RaiseException`, and can thus handle these operators: TensorRT/core/conversion/evaluators/prim.cpp, lines 337 to 343 in 09e15b2.

Note: This PR does not remove the existing `EliminateExceptionOrPassPattern` lowering pass, as that pass is intended for a more specialized exception case.

Addresses #1823
Relevant to #1842