aot_autograd: factor out runtime epilogue from aot_dispatch_base (#100586)

Pull Request resolved: #100586
Approved by: https://github.com/ezyang
bdhirsh authored and pytorchmergebot committed May 15, 2023
1 parent a4830bd commit bba12a4
Showing 1 changed file with 6 additions and 1 deletion.

torch/_functorch/aot_autograd.py
@@ -1330,7 +1330,7 @@ def call_func_with_args(f, args, steal_args=False, disable_amp=False):
             del guard
     return out
 
-def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
+def aot_dispatch_base_graph(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
     # aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case.
     # The cases that aot_dispatch_base doesn't need to handle include:
     # - outputs that are aliases of graph intermediates
@@ -1366,6 +1366,11 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
     if aot_config.enable_log:
         aot_graphs_log.info("%s", lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id))
 
+    return fw_module
+
+def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
+    fw_module = aot_dispatch_base_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
+
     disable_amp = torch._C._is_any_autocast_enabled()
     context = disable_autocast_manager if disable_amp else nullcontext
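For readers skimming the diff: the change splits the old aot_dispatch_base into a pure graph-construction step (aot_dispatch_base_graph) and a wrapper that keeps the runtime epilogue (autocast handling, compilation, runtime wrapping). Below is a minimal, self-contained sketch of that pattern; build_graph and dispatch are hypothetical stand-ins for the two functions, not PyTorch APIs, and the bodies are simplified so the sketch runs without torch.

from typing import Callable, List


def build_graph(fn: Callable, example_args: List) -> Callable:
    # Stand-in for aot_dispatch_base_graph: in AOTAutograd this step
    # functionalizes and traces fn into a forward FX graph. Here we
    # just return fn so the sketch stays runnable on its own.
    return fn


def dispatch(fn: Callable, example_args: List) -> Callable:
    # Stand-in for aot_dispatch_base: first build the graph, then
    # attach the runtime epilogue (in the real code: autocast checks,
    # compiling the graph, and the runtime wrappers).
    graph = build_graph(fn, example_args)

    def compiled_fn(*args):
        # The runtime epilogue lives here, cleanly separated from tracing.
        return graph(*args)

    return compiled_fn


f = dispatch(lambda x, y: x + y, [1.0, 2.0])
assert f(1.0, 2.0) == 3.0

The payoff of the split is that a caller can obtain the traced forward graph on its own, without also running the compile-and-wrap epilogue.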
