diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index e24dcc08d662..759fa22843cc 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -924,8 +924,8 @@ struct ProviderHost { #endif #if defined(USE_CUDA) || defined(USE_ROCM) - virtual PhiloxGenerator& PhiloxGenerator__Default() = 0; +#endif #ifdef ENABLE_TRAINING_TORCH_INTEROP virtual void contrib__PythonOpBase__Init(contrib::PythonOpBase* p, const OpKernelInfo& info) = 0; @@ -940,7 +940,6 @@ struct ProviderHost { virtual language_interop_ops::torch::RefCountTracker& GetRefCountTrackerInstance() = 0; virtual void RefCountTracker__DumpDetails(const language_interop_ops::torch::RefCountTracker* p, const std::string& phase_name) = 0; #endif -#endif #if defined(USE_CANN) virtual RandomGenerator& RandomGenerator__Default() = 0; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index c79de8105c03..6b2fad8441d8 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1068,8 +1068,8 @@ struct ProviderHostImpl : ProviderHost { #endif #if defined(USE_CUDA) || defined(USE_ROCM) - PhiloxGenerator& PhiloxGenerator__Default() override { return PhiloxGenerator::Default(); } +#endif #ifdef ENABLE_TRAINING_TORCH_INTEROP void contrib__PythonOpBase__Init(contrib::PythonOpBase* p, const OpKernelInfo& info) override { p->PythonOpBase::Init(info); } @@ -1092,7 +1092,6 @@ struct ProviderHostImpl : ProviderHost { return p->language_interop_ops::torch::RefCountTracker::DumpDetails(phase_name); } #endif -#endif #if defined(USE_CANN) RandomGenerator& RandomGenerator__Default() override { return RandomGenerator::Default(); } diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py index 4f2ec745199b..718ee84cf72a 100644 --- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py @@ -18,9 +18,7 @@ import torch.onnx import torch.onnx._onnx_supported_ops from torch._decomp import decomposition_table -from torch._dynamo.utils import detect_fake_mode from torch._subclasses.fake_tensor import FakeTensor -from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport @@ -182,8 +180,8 @@ def _get_support_dictionaries_and_decomposition_tables() -> ( ( _SUPPORT_DICT, _EXTRA_SUPPORT_DICT, - _ATEN2ATEN_DECOMP, - _ATEN2PRIM_DECOMP, + ATEN2ATEN_DECOMP, + ATEN2PRIM_DECOMP, ) = _get_support_dictionaries_and_decomposition_tables() @@ -628,15 +626,8 @@ def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphMod if graph_module in self._partitioner_cache: partitioned_prim_graph_module = self._partitioner_cache[graph_module] else: - prim_graph_module = make_fx( - graph_module, tracing_mode="fake", _allow_non_fake_inputs=True, decomposition_table=_ATEN2ATEN_DECOMP - )(*args) + prim_graph_module = graph_module # TODO(wechi): this is required for removing aten::_to_copy in _replace_to_copy_with_to. - # We need input and output tensors' devices to decide if aten::_to_copy is just a Cast. - fake_mode = detect_fake_mode(args) - if not fake_mode: - fake_mode = torch._subclasses.FakeTensorMode() - FakeTensorProp(prim_graph_module, mode=fake_mode).propagate(*args) _replace_to_copy_with_to(prim_graph_module) partitioner = CapabilityBasedPartitioner( prim_graph_module, self._supported_ops, allows_single_node_partition=False diff --git a/orttraining/orttraining/python/training/torchdynamo/register_backend.py b/orttraining/orttraining/python/training/torchdynamo/register_backend.py index 6f6c0f6575b0..ae9a1522a354 100644 --- a/orttraining/orttraining/python/training/torchdynamo/register_backend.py +++ b/orttraining/orttraining/python/training/torchdynamo/register_backend.py @@ -6,7 +6,7 @@ from functorch.compile import min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd -from .ort_backend import OrtBackend +from .ort_backend import ATEN2ATEN_DECOMP, OrtBackend # This should be the underlying compiler for ALL graphs if # the user uses ORT to accelerate PyTorch via Dynamo. @@ -28,7 +28,9 @@ # compiled_model = torch._dynamo.optimize(aot_ort)(model) # result = compiled_model(torch.rand(2, 2, dtype=torch.float) # result.sum().backward() -aot_ort = aot_autograd(fw_compiler=DEFAULT_BACKEND, partition_fn=min_cut_rematerialization_partition) +aot_ort = aot_autograd( + fw_compiler=DEFAULT_BACKEND, partition_fn=min_cut_rematerialization_partition, decompositions=ATEN2ATEN_DECOMP +) # Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called). #