We have moved TorchDynamo to pytorch/pytorch:
- import torchdynamo is now import torch._dynamo
- import torchinductor is now import torch._inductor
For instructions on porting PRs, or for more details on the move, see issue 1588.
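As a rough illustration of the rename, the sketch below rewrites the old module names to the new ones in a piece of source text. The helper name port_source is hypothetical and is not part of TorchDynamo or of the porting instructions in the issue; it assumes that whole-word occurrences of torchdynamo and torchinductor are the only things that need to change, so it would also rewrite those words inside comments, strings, or URLs.

```python
import re

# Old top-level module names and their new locations inside PyTorch.
RENAMES = {
    "torchdynamo": "torch._dynamo",
    "torchinductor": "torch._inductor",
}

# Match the old names only as whole words, so identifiers that merely
# contain them (e.g. my_torchdynamo_helper) are left untouched.
_PATTERN = re.compile(r"\b(torchdynamo|torchinductor)\b")


def port_source(text: str) -> str:
    """Return text with the old module names replaced by the new ones."""
    return _PATTERN.sub(lambda m: RENAMES[m.group(1)], text)


old = "import torchdynamo\nfrom torchinductor import config\n"
ported = port_source(old)
print(ported)
```

Review the result by hand before committing it, since a plain text substitution cannot tell code apart from prose.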
This repository still contains:
- An alias to the new location
- Issues: we will continue using this project for issue tracking
- Documentation that needs to be ported over/updated
TorchDynamo makes it easy to experiment with different compiler backends and make PyTorch code faster with a single-line decorator, torch._dynamo.optimize().
TorchDynamo supports arbitrary PyTorch code, control flow, mutation and dynamic shapes.
You can follow our nightly benchmarks here
TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. TorchDynamo hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of PyTorch operations into an FX Graph which is then just-in-time compiled with a customizable backend. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends to get the best of both worlds: usability and performance.
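To get a feel for the raw material TorchDynamo works with, the sketch below uses only the standard library's dis module to list the bytecode instructions of a small function. It shows what Python bytecode looks like and is not TorchDynamo's actual analysis. As assumptions for the illustration, math.cos and math.sin stand in for PyTorch operations so the snippet runs without PyTorch installed, and the helper name opnames is made up here. The exact opcode names differ between Python versions.

```python
import dis
import math


def fn(x, y):
    a = math.cos(x)
    b = math.sin(y)
    return a + b


def opnames(func):
    """Return the opcode names of func's bytecode, in order."""
    return [ins.opname for ins in dis.get_instructions(func)]


names = opnames(fn)

# Print one instruction per line: offset, opcode name, and argument.
for ins in dis.get_instructions(fn):
    print(ins.offset, ins.opname, ins.argrepr)
```

The call instructions in this listing correspond to the kind of operations that TorchDynamo would try to lift into an FX Graph when they are PyTorch calls, while the surrounding instructions remain ordinary Python.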
For more on TorchDynamo you can read our posts on PyTorch dev-discuss or watch a deep-dive video.
This repository also hosts TorchInductor, which is a TorchDynamo backend able to translate an FX Graph into Triton for GPUs or C++/OpenMP for CPUs. We have a training performance dashboard comparing the performance of different training backends. You can read more in the TorchInductor post on PyTorch dev-discuss.
TorchDynamo is experimental and under active development. You are welcome to try it out and contribute, but should expect to find bugs and rough edges.
Python 3.8 is recommended. Python 3.7 through 3.10 are supported and tested. Make sure you also have a development version of Python installed locally.
TorchDynamo is included in the nightly binaries of PyTorch; for reference, see https://pytorch.org/get-started/locally/
To use GPU backends (and in particular Triton), please make sure that the CUDA version you have installed locally matches the PyTorch version you are running.
pip3 install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"
For reference, the nightly version of GPU PyTorch, which includes GPU TorchDynamo, can be installed with the following command (it requires CUDA 11.7):
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu117
There are no additional requirements for CPU TorchDynamo. It is included in the nightly versions of PyTorch, which, for reference, can be installed with:
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
Alternatively, build PyTorch from source (https://github.com/pytorch/pytorch#from-source), which includes TorchDynamo.
To check that TorchDynamo is installed correctly, run the following commands from the PyTorch repo root directory; they run minimal examples:
cd tools/dynamo
python verify_dynamo.py
Here is a basic example of how to use TorchDynamo. You can decorate a function or a method with torch._dynamo.optimize() and pass in the name of a compiler, e.g. inductor, and your code should run faster:
import torch
import torch._dynamo as dynamo

@dynamo.optimize("inductor")
def fn(x, y):
    a = torch.cos(x)
    b = torch.sin(y)
    return a + b
It's also easy to define your own compiler backend in pure Python (see the custom backend documentation).
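A minimal sketch of such a custom backend follows, assuming a PyTorch build that ships torch._dynamo. The backend is a plain Python callable that receives the captured torch.fx.GraphModule and example inputs and returns something callable. The names my_compiler and captured_graphs are made up for this illustration, and the backend does no real compilation: it prints the graph and runs it eagerly.

```python
from typing import List

import torch
import torch._dynamo as dynamo

# Hypothetical bookkeeping list so we can see what the backend received.
captured_graphs = []


def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """Toy backend: record and print the FX graph, then run it unchanged."""
    captured_graphs.append(gm)
    print(gm.graph)
    # Returning gm.forward means the extracted graph runs in eager mode.
    return gm.forward


@dynamo.optimize(my_compiler)
def fn(x, y):
    a = torch.cos(x)
    b = torch.sin(y)
    return a + b


x, y = torch.randn(4), torch.randn(4)
out = fn(x, y)  # first call triggers graph capture and calls my_compiler
```

A real backend would replace the last line of my_compiler with a call into an actual compiler and return the compiled callable instead.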
TorchDynamo has a growing list of backends, which can be found in backends.py or by calling torch._dynamo.list_backends(), each with its own optional dependencies.
Some of the most commonly used backends are:
Debugging backends:
- dynamo.optimize("eager"): Uses PyTorch to run the extracted GraphModule. This is quite useful for debugging TorchDynamo issues.
- dynamo.optimize("aot_eager"): Uses AotAutograd with no compiler, i.e., just PyTorch eager for AotAutograd's extracted forward and backward graphs. This is useful for debugging and unlikely to give speedups.
Training & inference backends:
- dynamo.optimize("inductor"): Uses the TorchInductor backend with AotAutograd and cudagraphs, leveraging code-generated Triton kernels.
- dynamo.optimize("nvfuser"): nvFuser with TorchScript.
- dynamo.optimize("aot_nvfuser"): nvFuser with AotAutograd.
- dynamo.optimize("aot_cudagraphs"): cudagraphs with AotAutograd.
Inference-only backends:
- dynamo.optimize("ofi"): Uses TorchScript optimize_for_inference.
- dynamo.optimize("fx2trt"): Uses NVIDIA TensorRT for inference optimizations.
- dynamo.optimize("onnxrt"): Uses ONNXRT for inference on CPU/GPU.
- dynamo.optimize("ipex"): Uses IPEX for inference on CPU.
- torch.jit.trace() is silently wrong if it cannot trace, e.g. during control flow.
- torch.jit.script() requires modifications to user or library code by adding type annotations and removing non-PyTorch code.
- torch.fx.symbolic_trace() either traces correctly or gives a hard error, but it is limited to traceable code, so it still can't handle control flow.
- torch._dynamo works out of the box and produces partial graphs. It still has the option of producing a single graph with nopython=True, which is needed in some situations, but it allows a smoother transition in which partial graphs can be optimized without code modification.
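The "silently wrong" failure mode of tracing can be illustrated without PyTorch. The sketch below is a toy tracer written for this README, not torch.jit.trace() itself: the class Traced and the helpers trace and replay are hypothetical names. It records arithmetic on a proxy value, but a comparison returns a plain bool, so whichever branch the example input takes is baked into the recording.

```python
class Traced:
    """Toy proxy that records the arithmetic applied to it on a tape."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __mul__(self, k):
        self.tape.append(("mul", k))
        return Traced(self.value * k, self.tape)

    def __add__(self, k):
        self.tape.append(("add", k))
        return Traced(self.value + k, self.tape)

    def __gt__(self, k):
        # Returns a plain bool, so the branch decision never reaches the tape.
        return self.value > k


def fn(x):
    if x > 0:
        return x * 2
    return x + 100


def trace(func, example):
    """Run func once on a proxy and return the recorded operations."""
    tape = []
    func(Traced(example, tape))
    return tape


def replay(tape, x):
    """Re-apply the recorded operations to a new input."""
    for op, k in tape:
        x = x * k if op == "mul" else x + k
    return x


tape = trace(fn, 3)              # only the x > 0 branch is recorded
print(replay(tape, 3), fn(3))    # 6 6: matches on the traced branch
print(replay(tape, -1), fn(-1))  # -2 99: silently wrong on the other branch
```

A partial-graph approach avoids this by falling back to ordinary Python at the branch instead of freezing one side of it.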
TorchDynamo has a BSD-style license, as found in the LICENSE file.