
fix: Replace EliminateExceptions lowering pass #1859

Merged
gs-olive merged 1 commit into pytorch:main from eliminate_exceptions_removal on Jul 15, 2023

Conversation

gs-olive
Collaborator

@gs-olive gs-olive commented Apr 25, 2023

Description

  • Torch-TensorRT already has evaluator support for prim::RaiseException, and this particular lowering pass causes certain models to halt compilation indefinitely

  • Modify this lowering pass to use a safer node-replacement system via direct node-input replacement

  • Add a new EliminateExceptionsSafe lowering pass, which performs the same task as EliminateExceptions but with a safer replacement scheme

  • Update EliminateExceptions to use direct node-input replacement instead of replaceAllUsesWith, to avoid an issue where invalid IR caused the program to halt

  • Add testing for the new lowering pass

The torch::jit::EliminateExceptions lowering pass is listed as "risky" in the documentation and, as per #1823, can also cause compilation to halt on certain graphs. We should remove this pass while awaiting a more robust solution through the implementation of #1842 (giving the user the option to eliminate exceptions more effectively in the graph via a compile-time flag). Torch-TRT already has an evaluator for prim::RaiseException and can thus handle these operators:

.evaluator(
    {c10::Symbol::fromQualString("prim::RaiseException"),
     [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
       auto exception = args.at(n->input(0)).IValue();
       TORCHTRT_THROW_ERROR("Error from TorchScript: " << *exception);
       return {};
     }});

Note: This PR does not remove the existing EliminateExceptionOrPassPattern lowering pass, as that pass is intended for a more specialized exception case.

Addresses #1823
Relevant to #1842

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • [x] My code follows the style guidelines of this project (You can use the linters)
  • [x] I have performed a self-review of my own code
  • [x] I have commented my code, particularly in hard-to-understand areas and hacks
  • [x] I have made corresponding changes to the documentation
  • [x] I have added tests to verify my fix or my feature
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have added the relevant labels to my PR so that relevant reviewers are notified

@gs-olive gs-olive self-assigned this Apr 25, 2023
@github-actions github-actions bot added component: core Issues re: The core compiler component: lowering Issues re: The lowering / preprocessing passes labels Apr 25, 2023
@gs-olive gs-olive changed the title fix: Remove eliminate exceptions lowering pass fix: Remove EliminateExceptions lowering pass Apr 25, 2023
@gs-olive gs-olive requested a review from bowang007 April 25, 2023 19:27
@github-actions github-actions bot requested a review from peri044 April 25, 2023 19:27
@gcuendet
Contributor

gcuendet commented May 1, 2023

This seems to help in certain cases where compilation was hanging; however, I am now also seeing an

"Expected Tensor but got Uninitialized"

error. Looking at #1282, it seems that adding torch::jit::EliminateExceptions was the solution to such errors...

My understanding of how these different pieces interact is limited, but I wonder whether removing the call to torch::jit::EliminateExceptions just fixes one thing while breaking (again) another.

@gs-olive
Collaborator Author

gs-olive commented May 1, 2023

@gcuendet - Thank you for the follow-up. We are looking into the Expected Tensor but got Uninitialized error simultaneously with this change, as it seems the usage of torch::jit::EliminateExceptions is more of a temporary fix for that problem (see: #1785 (review)).

Do you have any examples of models you are willing to share which fail with the Uninitialized error, when the torch::jit::EliminateExceptions pass is removed?

@gcuendet
Contributor

gcuendet commented May 3, 2023

Thanks for the answer, @gs-olive. I am trying to reduce our failing model to something minimal that I can share, but I am not sure I'll be able to 😞. If I manage to, I'll share it, of course!

@gs-olive gs-olive added the WIP Work is in progress, pull request should not be merged yet label May 3, 2023
@gs-olive gs-olive force-pushed the eliminate_exceptions_removal branch from 33f0fa5 to 972a717 Compare May 4, 2023 05:37
@gs-olive
Collaborator Author

gs-olive commented May 4, 2023

Hi @gcuendet - I found that the issue in torch::jit::EliminateExceptions originates from the calls to replaceAllUsesWith, and I've addressed it by switching to the safer alternative replaceFirstUseWith. This resolves the nn.Upsample case on my machine. Please let me know if it resolves that case on your models as well, while still retaining the benefits of the original EliminateExceptions.

@gcuendet
Contributor

gcuendet commented May 4, 2023

Hi @gs-olive. Thanks for the thorough investigation and the nice find. I tried the updated branch, but I still get the following error:

terminate called after throwing an instance of 'c10::Error'
  what():  Expected Tensor but got Uninitialized

More specifically, I get the following (truncated) output:

output
DEBUG: [Torch-TensorRT] - Resolving non-tensor inputs for segmented blocks
DEBUG: [Torch-TensorRT] - Registering input/output torch::jit::Value for segmented graphs
DEBUG: [Torch-TensorRT] - Performing shape analysis for segmented blocks using static shapes for inputs

[...] # Here, I see a bunch of "Running shape analysis on block Segment Block @x" until the last one below:

GRAPH: [Torch-TensorRT] - Running shape analysis on block Segment Block @5:
    Target: Torch

    Graph: graph(%X.1 : Tensor,
      %8 : int,
      %18 : Tensor):
  %17 : str = prim::Constant[value="bilinear"]()
  %16 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]() 
  %14 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]()
  %13 : int = prim::Constant[value=5]()
  %11 : str = prim::Constant[value="builtins.NotImplementedError"]()
  %10 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]()
  %9 : int = prim::Constant[value=3]()
  %6 : float[] = prim::Constant[value=[2., 2.]]()
  %5 : bool = prim::Constant[value=0]()
  %4 : NoneType = prim::Constant()
  %1 : bool = prim::Constant[value=1]()
  %out1.1 : Tensor = prim::If(%1)
    block0():
      %2 : Tensor = aten::upsample_bilinear2d(%X.1, %4, %5, %6)
      -> (%2)
    block1():
      %7 : bool = aten::eq(%8, %9)
       = prim::If(%7)
        block0():
           = prim::RaiseException(%10, %11)
          -> ()
        block1():
          -> ()
      %12 : bool = aten::eq(%8, %13)
       = prim::If(%12)
        block0():
           = prim::RaiseException(%14, %11)
          -> ()
        block1():
          -> ()
      %15 : str = aten::format(%16, %8, %17)
       = prim::RaiseException(%15, %11)
      -> (%18)
  return (%out1.1)


terminate called after throwing an instance of 'c10::Error'
  what():  Expected Tensor but got Uninitialized
Exception raised from reportToTensorTypeError at /home/ubuntu/buildAgent/temp/buildTmp/conan_home/.conan/data/libtorch/1.13.1-0/cognex/stable/build/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/source_subfolder/aten/src/ATen/core/ivalue.cpp:942 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7f9c4ba3936b in /mnt/caches/conan/data/libtorch/1.13.1-0/cognex/stable/package/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xce (0x7f9c4ba34d3e in /mnt/caches/conan/data/libtorch/1.13.1-0/cognex/stable/package/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/lib/libc10.so)
frame #2: c10::IValue::reportToTensorTypeError() const + 0x64 (0x7f9c4ebcb394 in /mnt/caches/conan/data/libtorch/1.13.1-0/cognex/stable/package/b8fcba865a68ae644e8a8bfa868ef00d13dc8a17/lib/libtorch_cpu.so)
frame #3: torch_tensorrt::core::partitioning::getSegmentsOutputByRunning(torch_tensorrt::core::partitioning::SegmentedBlock&, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >&, torch_tensorrt::core::part
itioning::PartitioningInfo const&, torch_tensorrt::core::ir::ShapeMode const&) + 0x106f (0x7f9c561393ff in /mnt/caches/conan/data/torch-tensorrt/972a7172917f69d0b961d5bfa94deb4d4a920c7a-0/cognex/olive-fix/package/504a9adcecd4971b484b990838fccfc5a2dc72bc/lib/libtorchtrt.so)
frame #4: torch_tensorrt::core::partitioning::runShapeAnalysis(torch_tensorrt::core::partitioning::PartitioningCtx*, torch::jit::Block*, std::unordered_map<torch::jit::Value const*, c10::IValue, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, c10::IValue> > >&, torch_tensorrt:
:core::ir::ShapeMode const&) + 0x239 (0x7f9c5613b6b9 in /mnt/caches/conan/data/torch-tensorrt/972a7172917f69d0b961d5bfa94deb4d4a920c7a-0/cognex/olive-fix/package/504a9adcecd4971b484b990838fccfc5a2dc72bc/lib/libtorchtrt.so)
frame #5: torch_tensorrt::core::partitioning::partition(torch_tensorrt::core::partitioning::PartitioningCtx*, bool) + 0xfa4 (0x7f9c56130984 in /mnt/caches/conan/data/torch-tensorrt/972a7172917f69d0b961d5bfa94deb4d4a920c7a-0/cognex/olive-fix/package/504a9adcecd4971b484b990838fccfc5a2dc72bc/lib/libtorchtrt.so)
frame #6: torch_tensorrt::core::BuildHybridGraph(torch::jit::Module&, torch::jit::Block*, torch_tensorrt::core::CompileSpec, std::map<torch::jit::Value*, c10::IValue, std::less<torch::jit::Value*>, std::allocator<std::pair<torch::jit::Value* const, c10::IValue> > >, std::unordered_map<torch::jit::Value const*, std::vector<c10::optional<c10::ScalarType>, std::all
ocator<c10::optional<c10::ScalarType> > >, std::hash<torch::jit::Value const*>, std::equal_to<torch::jit::Value const*>, std::allocator<std::pair<torch::jit::Value const* const, std::vector<c10::optional<c10::ScalarType>, std::allocator<c10::optional<c10::ScalarType> > > > > >, bool) + 0x2a8 (0x7f9c56163d88 in /mnt/caches/conan/data/torch-tensorrt/972a7172917f69
d0b961d5bfa94deb4d4a920c7a-0/cognex/olive-fix/package/504a9adcecd4971b484b990838fccfc5a2dc72bc/lib/libtorchtrt.so)
frame #7: torch_tensorrt::core::CompileGraph(torch::jit::Module const&, torch_tensorrt::core::CompileSpec) + 0x1797 (0x7f9c56167fb7 in /mnt/caches/conan/data/torch-tensorrt/972a7172917f69d0b961d5bfa94deb4d4a920c7a-0/cognex/olive-fix/package/504a9adcecd4971b484b990838fccfc5a2dc72bc/lib/libtorchtrt.so)
frame #8: torch_tensorrt::torchscript::compile(torch::jit::Module const&, torch_tensorrt::torchscript::CompileSpec) + 0xff (0x7f9c55fea45f in /mnt/caches/conan/data/torch-tensorrt/972a7172917f69d0b961d5bfa94deb4d4a920c7a-0/cognex/olive-fix/package/504a9adcecd4971b484b990838fccfc5a2dc72bc/lib/libtorchtrt.so)
frame #9: <unknown function> + 0x3da3 (0x565009113da3 in ./build/bin/interpolate_tensorrt)
frame #10: __libc_start_main + 0xf3 (0x7f9c1382f083 in /lib/x86_64-linux-gnu/libc.so.6)
frame #11: <unknown function> + 0x3f1e (0x565009113f1e in ./build/bin/interpolate_tensorrt)

[1]    3241346 abort (core dumped)  ./build/bin/interpolate_tensorrt torchscripts/upsample_script_pytorch1.13.pt

For reference, here is also the full lowered graph:

Lowered graph
INFO: [Torch-TensorRT] - Lowered Graph: graph(%X.1 : Tensor):
  %2 : bool = prim::Constant[value=0]()
  %3 : float[] = prim::Constant[value=[2., 2.]]()
  %4 : str = prim::Constant[value="bilinear"]() 
  %5 : int = prim::Constant[value=5]()
  %6 : int = prim::Constant[value=3]()
  %7 : int = prim::Constant[value=4]()
  %8 : int = prim::Constant[value=2]()
  %9 : str = prim::Constant[value="Input and scale_factor must have the same number of spatial dimensions, but got input with spatial dimensions of {} and scale_factor of shape {}. Please provide input tensor in (N, C, d1, d2, ...,dK) format and scale_factor in (s1, s2, ...,sK) format."]()
  %10 : str = prim::Constant[value="Got 3D input, but bilinear mode needs 4D input"]()
  %11 : str = prim::Constant[value="Got 5D input, but bilinear mode needs 4D input"]()
  %12 : str = prim::Constant[value="Input Error: Only 3D, 4D and 5D input Tensors supported (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got {})"]()
  %13 : NoneType = prim::Constant()
  %14 : str = prim::Constant[value="builtins.ValueError"]()
  %15 : int = prim::Constant[value=1]()
  %16 : str = prim::Constant[value="builtins.NotImplementedError"]()
  %60 : Tensor = prim::Uninitialized() # :0:0
  %61 : int = aten::dim(%X.1)
  %dim.2 : int = aten::sub(%61, %8)
  %63 : bool = aten::ne(%8, %dim.2)
   = prim::If(%63)
    block0():
      %64 : int[] = aten::size(%X.1)
      %65 : int[] = aten::slice(%64, %8, %13, %15)
      %66 : int[] = aten::list(%65)
      %67 : str = aten::format(%9, %66, %3)
       = prim::RaiseException(%67, %14)
      -> ()
    block1():
      -> ()
  %68 : bool = aten::eq(%61, %7)
   = prim::If(%2)
    block0():
      %69 : int[] = aten::size(%X.1)
      %70 : int[] = aten::slice(%69, %8, %13, %15)
      %71 : int[] = aten::list(%70)
      %72 : str = aten::format(%9, %71, %3)
       = prim::RaiseException(%72, %14)
      -> ()
    block1():
      -> ()
   = prim::If(%68)
    block0():
      -> ()
    block1():
      %73 : bool = aten::eq(%61, %6)
       = prim::If(%73)
        block0():
           = prim::RaiseException(%10, %16)
          -> ()
        block1():
          -> ()
      %74 : bool = aten::eq(%61, %5)
       = prim::If(%74)
        block0():
           = prim::RaiseException(%11, %16)
          -> ()
        block1():
          -> ()
      %75 : str = aten::format(%12, %61, %4)
       = prim::RaiseException(%75, %16)
      -> ()
  %out1.1 : Tensor = prim::If(%87)
    block0():
      %27 : Tensor = aten::upsample_bilinear2d(%X.1, %13, %2, %3)
      -> (%27)
    block1():
      %83 : bool = aten::eq(%61, %6)
       = prim::If(%83)
        block0():
           = prim::RaiseException(%10, %16)
          -> ()
        block1():
          -> ()
      %84 : bool = aten::eq(%61, %5)
       = prim::If(%84)
        block0():
           = prim::RaiseException(%11, %16)
          -> ()
        block1():
          -> ()
      %85 : str = aten::format(%12, %61, %4)
       = prim::RaiseException(%85, %16)
      -> (%60)
  %87 : bool = prim::Constant[value=1]()
  return (%out1.1)

From my limited understanding, looking at it naively, it seems that the top-level node %out1.1 : Tensor = prim::If(%87) (in the full lowered graph) is the cause of the problem:
The "condition false" block1() of that prim::If node will never be executed, since %87 : bool = prim::Constant[value=1]() is a constant true value. But the block still needs to define an output compatible with the output of the prim::If, in this case a Tensor, and when the shape analysis checks its dimensions, it fails, since that Tensor is just %60 : Tensor = prim::Uninitialized() # :0:0 (cf. the full lowered graph)...

And that's probably where a proper solution for #1842 would be useful. I.e., if that top-level prim::If node were removed and replaced by just the contents of block0() (in this case, the actual aten::upsample_bilinear2d computation), then there would be no %60 : Tensor = prim::Uninitialized() left to make the shape analysis fail. Is that correct?
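That reasoning can be sketched with a toy model (plain Python, not Torch-TensorRT or TorchScript internals; the Node class and all names are illustrative): when a prim::If condition is a known constant, splicing the taken block into the parent and discarding the other block means the dead branch returning prim::Uninitialized never survives to shape analysis.

```python
from dataclasses import dataclass, field
from typing import List, Optional

# Minimal stand-in for a TorchScript node (illustrative only).
@dataclass
class Node:
    kind: str                                                 # e.g. "prim::If"
    inputs: List[str] = field(default_factory=list)
    blocks: List[List["Node"]] = field(default_factory=list)  # [true_block, false_block]
    output: Optional[str] = None

def fold_constant_if(nodes: List[Node], constants: dict) -> List[Node]:
    """Replace any prim::If whose condition is a known constant with the
    contents of the taken block, discarding the untaken block entirely."""
    folded = []
    for n in nodes:
        if n.kind == "prim::If" and n.inputs[0] in constants:
            taken = n.blocks[0] if constants[n.inputs[0]] else n.blocks[1]
            folded.extend(fold_constant_if(taken, constants))
        else:
            folded.append(n)
    return folded

# Mirrors the lowered graph above: %87 is constant-true, and block1()
# would raise an exception and return the uninitialized %60.
graph = [
    Node("prim::Constant[value=1]", output="%87"),
    Node("prim::If", inputs=["%87"], output="%out1.1", blocks=[
        [Node("aten::upsample_bilinear2d", inputs=["%X.1"], output="%27")],
        [Node("prim::RaiseException"), Node("prim::Uninitialized", output="%60")],
    ]),
]

folded = fold_constant_if(graph, {"%87": True})
```

After folding, only the constant and the real aten::upsample_bilinear2d computation remain, so there is no prim::Uninitialized left for the shape analysis to trip over.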

@gs-olive
Collaborator Author

gs-olive commented May 4, 2023

Thank you for the detailed response! I agree with your assessment that a proper solution for #1842 would address this issue. Another alternative is to support nodes such as %60 : Tensor = prim::Uninitialized(), since these can also show up in other contexts and should not fail. @bowang007 is working on something similar now, and that fix, in conjunction with this PR, should help resolve the issues relating to exceptions and uninitialized tensors while #1842 is pending.

@gcuendet
Contributor

gcuendet commented May 5, 2023

Do you plan to open a PR in pytorch with that fix?

@gs-olive
Collaborator Author

gs-olive commented May 5, 2023

I intend to open an issue with PyTorch first to assess whether the proposed fix is a viable solution to the issue and whether the intended behavior of that function is retained, then I will open a PR for it.

Update: Filed pytorch/pytorch#100730

@gs-olive gs-olive changed the title fix: Remove EliminateExceptions lowering pass fix: Replace EliminateExceptions lowering pass Jun 1, 2023
@gs-olive gs-olive force-pushed the eliminate_exceptions_removal branch from 972a717 to b6c3b2b Compare June 1, 2023 18:59
@github-actions github-actions bot added the component: tests Issues re: Tests label Jun 1, 2023
@gs-olive gs-olive removed component: tests Issues re: Tests WIP Work is in progress, pull request should not be merged yet labels Jun 1, 2023
@gs-olive gs-olive requested review from narendasan and removed request for peri044 June 1, 2023 19:00
@gs-olive
Collaborator Author

gs-olive commented Jun 1, 2023

Hi @gcuendet - I've made some updates to this lowering pass and added a few test cases. I have also verified that for the case in #1823, the model no longer halts mid-compilation (and compiled successfully on my machine, using this PR). Could you please verify on the model with the new updates as well?

if (certainlyThrows(true_block)) {
  n->input(0)->replaceAllUsesDominatedByNodeWith(n, false_const);
} else if (certainlyThrows(false_block)) {
  n->input(0)->replaceAllUsesDominatedByNodeWith(n, true_const);
Collaborator

@bowang007 bowang007 Jun 6, 2023


This does not make sense to me.
If you check what's actually going on for prim::If here, you will find that:

  1. every prim::If node has blocks and attributes
  2. for a prim::If node, IfNode->input(0) is the condition

So, in PyTorch's pass implementation here, what happens is that the pass goes through the two blocks of the prim::If, and if one of those blocks contains a prim::RaiseException node, it RESETS the prim::If input (the condition) to force execution down the other path.
For example, if the original condition is true, execution should go to the first (true) block; but if the pass finds that the first block contains a prim::RaiseException node, it rewrites the condition to false to make sure execution never enters the first block, avoiding the exception.

Moreover, if you go to the definition of replaceAllUsesDominatedByNodeWith here, you will find that it replaces only those uses of the value that are dominated by the node. In graph theory, "A dominates B" means that control must flow through A to reach B; for example, node A dominating value x means the graph must flow from node A to value x. So the line n->input(0)->replaceAllUsesDominatedByNodeWith(n, true_const); replaces only those uses of the prim::If condition value that are dominated by that prim::If node. However, the condition value flows into the prim::If node, not the other way around, so n->input(0) is never dominated by n, and this line ends up replacing nothing.
So this line is basically doing nothing here.
The above is my understanding; please correct me if I'm wrong.
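The dominance argument above can be sketched with a toy check (illustrative only, assuming straight-line code where dominance reduces to positional ordering; real dominator analysis must also account for branches):

```python
def uses_dominated_by(node_pos: int, use_positions: list) -> list:
    """In a straight-line block, a use is dominated by the node only if it
    occurs strictly after the node in execution order."""
    return [p for p in use_positions if p > node_pos]

# The condition's only use is the one flowing *into* the prim::If itself
# (same position as the If), so no use is dominated by the If:
no_uses = uses_dominated_by(5, [5])

# A hypothetical later consumer of the same condition *would* qualify:
later_uses = uses_dominated_by(5, [5, 9])
```

This is why the replaceAllUsesDominatedByNodeWith call leaves the prim::If's own condition input untouched.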

Collaborator Author

@gs-olive gs-olive Jun 7, 2023


Hi @bowang007 - thank you for the detailed review and comment. I agree with your observation that replacing all uses of the condition dominated by n is effectively doing nothing here. I've run and added some additional tests to the PR, and I have now updated the call to replaceFirstUseWith, which I believe is the correct replacement in this case.

Based on your comments, I now think the reason for which the original halting bug in #1823 was occurring is the following:

%out1.1 : Tensor = prim::If(%71)
    block0():
      ...
      -> (%73)
    block1():
         ...
        block0():
          ...
          -> ()
        block1():
          -> ()
        ...
       = prim::RaiseException(%76, %59)
      -> (%63)
  %out2.1 : Tensor = prim::If(%71)
    block0():
      ...
      -> (%82)
    block1():
      ...

In the graph above, the condition %71 appears as the predicate of two distinct prim::If nodes. When we call replaceAllUsesWith, as torch::jit::passes::EliminateExceptions does, the condition is replaced in downstream prim::If users as well, which seems to cause an infinite loop here. Replacing only the first use of the condition, however, ensures that only the condition of the prim::If being directly analyzed is replaced, avoiding the loop. What do you think of the new changes?

I've verified the new implementation is functionally equivalent to torch::jit::passes::EliminateExceptions on the samples in the test cases, and does not halt for #1823.
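The difference between the two replacement strategies can be sketched with a toy model (plain Python, not the torch::jit API; the dict-based "nodes" and names follow the graph excerpt above):

```python
def replace_all_uses(if_nodes, old_cond, new_cond):
    """Mimics replaceAllUsesWith: every prim::If guarded by old_cond is rewritten."""
    for n in if_nodes:
        if n["cond"] == old_cond:
            n["cond"] = new_cond

def replace_direct_input(if_nodes, index, new_cond):
    """Mimics direct node-input replacement: only the node under analysis changes."""
    if_nodes[index]["cond"] = new_cond

# %71 guards both prim::If nodes, as in the graph above.
ifs_a = [{"name": "%out1.1", "cond": "%71"}, {"name": "%out2.1", "cond": "%71"}]
replace_all_uses(ifs_a, "%71", "%false_const")   # rewrites *both* Ifs

ifs_b = [{"name": "%out1.1", "cond": "%71"}, {"name": "%out2.1", "cond": "%71"}]
replace_direct_input(ifs_b, 0, "%false_const")   # only %out1.1's condition changes
```

With the all-uses strategy, %out2.1's predicate is rewritten as a side effect even though that prim::If was never analyzed; the scoped replacement leaves it alone.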

Collaborator


Hey @gs-olive, thanks for your detailed reply. Several items I want to discuss with you and @narendasan:

  1. I'm not sure we should have this pass at all. The explanation for this pass here also mentions that This pass is illegal in general case as the modified graph might not throw an exception that the original graph would throw. Generally speaking, this pass cleans up every prim::If branch that throws an exception, so it looks like it could produce unexpected results if a user passes in invalid input.
  2. This exception-elimination pass was introduced to fix the uninitialized error when we were dealing with customer bugs (see this PR). Since we now have a better fix for that uninitialized-type bug (Fix: fix the bug that uninitialized tensor cannot be found #1933), maybe it's fine to remove this pass.
  3. Replacing only the first use of the condition seems a little hacky, since we never know where that condition value is first used. For example, for this graph:

%66 : bool = aten::ne(%48, %dim.2)
%67 : some operations consumes %66
   = prim::If(%66)
  block0:
  ...
  block1:
...

it's not actually doing what we want.
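The pitfall in point 3 can be sketched with a toy model (plain Python, not the torch::jit API; "aten::some_op" stands in for the hypothetical earlier consumer of %66):

```python
class Use:
    """A recorded use of a value by some node (toy model)."""
    def __init__(self, user: str, value: str):
        self.user = user      # kind of the node consuming the value
        self.value = value

def replace_first_use(uses, old, new):
    """Mimics replaceFirstUseWith: rewrite only the first recorded use of
    `old`, returning the kind of the node where the replacement landed."""
    for u in uses:
        if u.value == old:
            u.value = new
            return u.user
    return None

# %66 is consumed by some other op *before* it reaches the prim::If,
# so the "first use" is not the prim::If condition at all.
uses = [Use("aten::some_op", "%66"), Use("prim::If", "%66")]
landed_in = replace_first_use(uses, "%66", "%true_const")
```

The replacement lands in the unrelated op while the prim::If condition is left untouched, which is exactly the failure mode described above.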

Collaborator


Not too much to add here; the discussion seems accurate. Here is the official documentation on block and conditional semantics in PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/OVERVIEW.md#block

@gs-olive gs-olive force-pushed the eliminate_exceptions_removal branch from b6c3b2b to 3c2e02e Compare June 7, 2023 03:24
@github-actions github-actions bot added the component: tests Issues re: Tests label Jun 7, 2023
@@ -169,3 +173,189 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
}
EXPECT_EQ(1, if_count);
}

TEST(LoweringPasses, EliminateExceptionsSafeIfBlock) {
Collaborator


For these test cases, what's the difference between the graph before and after that pass?

Collaborator Author

@gs-olive gs-olive Jul 10, 2023


The difference is that the graph is collapsed, since the exception is removed from it. This matches the graph-level effect of the torch::jit::EliminateExceptions lowering pass. For instance, the change would be:

Original Graph:

    graph(%x, %y):
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          = prim::RaiseException(%45)
          -> (%x)
        block1():
          %res = aten::mul(%x, %y)
          -> (%res)
      return (%4)

New Graph:

    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::mul(%x, %y)
      return (%6)


@gs-olive
Collaborator Author

Hi @bowang007 - thanks for the comment, this makes sense.

  1. I agree that the pass can be risky in the sense that it removes exceptions. The main rationale for removing them is that our converters already handle the edge cases these exceptions check for, so we shouldn't need the extra logic and conditionals they bring with them. For instance, the conditional shown in the comment above could cause unnecessary fallback, which can affect performance. TRT and our converters already apply a large number of input-validation checks covering these cases.
  2. I agree that removing the torch::jit::EliminateExceptions pass entirely could also be a viable way to resolve the compilation halting here.
  3. This is a good point, thanks for bringing it up. I've updated the implementation to directly modify the prim::If node's input, to address the case you mentioned.

@gs-olive gs-olive force-pushed the eliminate_exceptions_removal branch from e15be6b to 879da55 Compare July 10, 2023 21:42
@gs-olive gs-olive force-pushed the eliminate_exceptions_removal branch from 879da55 to 047b8b4 Compare July 10, 2023 22:08
Comment on lines +142 to +160
Block* true_block = n->blocks()[0];
Block* false_block = n->blocks()[1];
bool removed_exception = false;
Value* input_value_replacement;

// If the block throws an exception, replace input with logical opposite
if (certainlyThrows(true_block)) {
  removed_exception = true;
  input_value_replacement = false_const;
} else if (certainlyThrows(false_block)) {
  removed_exception = true;
  input_value_replacement = true_const;
}

// Log node and perform input replacement
if (removed_exception) {
  LOG_WARNING("Detected and removing exception in TorchScript IR for node: " << util::node_info(n));
  n->insertInput(0, input_value_replacement);
  n->removeInput(1);
Collaborator Author


Directly modifies inputs to the prim::If node only, to avoid incorrect boolean modifications elsewhere.

Logs a warning informing the user that an exception was automatically removed from the TorchScript IR.

@gs-olive gs-olive requested a review from bowang007 July 10, 2023 23:22
@narendasan
Collaborator

Have we done testing on what happens if we just remove the pass? IIRC, we treat exceptions as compile-time errors through the evaluator; the original purpose was to have fewer no-ops in the graph and remove unnecessary conditionals.

@gs-olive
Collaborator Author

I haven't done specific testing (aside from CI) on removing the pass, but I would expect a performance degradation on models which have conditional exception logic, since the prim::If would cause fallback in the graph, increasing graph segmentation.

- Remove Torch `EliminateExceptions` lowering pass
- Add new `EliminateExceptionsSafe` lowering pass, which has
functionally the same task as that of `EliminateExceptions`, but with a
safer replacement scheme
- Update EliminateExceptions to use direct node input replacement
instead of `replaceAllUsesWith` to avoid issue with invalid IR causing
program halting
- Add testing for new lowering pass
@gs-olive gs-olive force-pushed the eliminate_exceptions_removal branch from 047b8b4 to 5f6214c Compare July 14, 2023 21:35
Collaborator

@bowang007 bowang007 left a comment


LGTM

@gs-olive gs-olive merged commit d4c9c06 into pytorch:main Jul 15, 2023
@gs-olive gs-olive deleted the eliminate_exceptions_removal branch July 15, 2023 01:33