
Eager mode #635

Merged (1 commit) on Mar 22, 2022

Conversation


@makslevental makslevental commented Mar 2, 2022

This PR implements an eager mode backend for PyTorch through the torch-mlir framework. This is accomplished by overriding the `__torch_dispatch__` class method on the wrapper subclass `TorchMLIRTensor(torch.Tensor)`.

Effectively, this mode works by compiling op by op as the NN is eagerly executed by PyTorch. That compilation entails building a representation of the op that can be `torch.jit.script`ed, importing it using `ModuleBuilder`, and then executing it (e.g., with `RefBackendLinalgOnTensorsBackend`). This mode includes a fallback to conventional PyTorch if anything in the torch-mlir compilation process fails (e.g., an unsupported op).

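For readers unfamiliar with the mechanism, a minimal sketch of this wrapper-subclass pattern is shown below. It is illustrative only, not the PR's actual code; `compile_and_run_with_torch_mlir` is a hypothetical stand-in for the script/import/compile pipeline described above.

```python
# Illustrative sketch of a __torch_dispatch__ wrapper subclass with a PyTorch
# fallback. Not the PR's implementation; compile_and_run_with_torch_mlir is a
# hypothetical placeholder for the torch-mlir compile-and-execute path.
import torch
from torch.utils._pytree import tree_map


class WrapperTensorSketch(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # Create a wrapper that mirrors elem's metadata but delegates every op.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device,
            requires_grad=elem.requires_grad)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda x: x.elem if isinstance(x, WrapperTensorSketch) else x
        wrap = lambda x: WrapperTensorSketch(x) if isinstance(x, torch.Tensor) else x
        uargs, ukwargs = tree_map(unwrap, args), tree_map(unwrap, kwargs)
        try:
            # Compile this single op through torch-mlir and execute it (hypothetical).
            out = compile_and_run_with_torch_mlir(func, uargs, ukwargs)
        except Exception:
            # Fallback: execute the op with ordinary eager PyTorch.
            out = func(*uargs, **ukwargs)
        return tree_map(wrap, out)
```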
Currently, all e2e tests pass except for two that involve an upstream PyTorch bug (pytorch/pytorch#74400).

High priority next steps:

  1. A compile cache in order to speed up reruns of the same NN (a rough sketch follows after this list).
  2. Integration with IREE (though not in this repo).
  3. Integration with torch.distributed.

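A rough illustration of what the compile cache in item 1 might look like (all names below are hypothetical and not part of this PR): compiled artifacts are memoized on the op overload plus the shapes and dtypes of its tensor arguments, so repeated iterations of the same NN reuse earlier compilations instead of re-invoking the whole pipeline.

```python
# Hypothetical per-op compile cache keyed on the op and the metadata of its
# tensor arguments; compile_op stands in for the script/import/compile pipeline.
import torch

_COMPILE_CACHE = {}

def _meta(x):
    # Reduce each argument to something hashable that determines the compilation.
    if isinstance(x, torch.Tensor):
        return ("Tensor", tuple(x.shape), str(x.dtype))
    if isinstance(x, (list, tuple)):
        return tuple(_meta(v) for v in x)
    return x

def compile_with_cache(func, args, kwargs):
    key = (str(func), _meta(args), _meta(sorted(kwargs.items())))
    if key not in _COMPILE_CACHE:
        _COMPILE_CACHE[key] = compile_op(func, args, kwargs)  # hypothetical compile step
    return _COMPILE_CACHE[key]
```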
@makslevental makslevental marked this pull request as ready for review March 2, 2022 23:03
@makslevental makslevental mentioned this pull request Mar 2, 2022

silvasean commented Mar 2, 2022

Can you add a test config for the e2e test framework?

https://github.com/llvm/torch-mlir/tree/main/python/torch_mlir_e2e_test/torchscript/configs

It should be possible to pass all the tests if you implement the fallback correctly.

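For context, a config in that framework decides how each test program is prepared and how its recorded trace is replayed. Below is a very rough, hypothetical sketch of an eager-mode config; the `compile`/`run` split and the trace item fields are assumptions about the framework's interface, not the code this PR actually adds, and `WrapperTensorSketch` refers to the illustrative subclass sketched earlier.

```python
# Hypothetical sketch of an eager-mode e2e test config. The compile/run hooks
# and trace item fields are assumptions about the test framework's interface.
import torch

class EagerModeTestConfigSketch:
    def compile(self, program: torch.nn.Module):
        # Nothing to do ahead of time: compilation happens op by op at runtime.
        return program

    def run(self, artifact: torch.nn.Module, trace):
        results = []
        for item in trace:
            # Wrap inputs so every op dispatches through the eager backend.
            wrapped = [WrapperTensorSketch(t) for t in item.inputs]
            out = getattr(artifact, item.symbol)(*wrapped)
            # Unwrap so results can be compared against the native PyTorch run.
            results.append(out.elem if hasattr(out, "elem") else out)
        return results
```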

@silvasean silvasean left a comment

drive-by comment. will need a few passes here on the review.

@makslevental makslevental force-pushed the eager_clean branch 11 times, most recently from 433330b to 47a6a3d on March 14, 2022 16:42
@makslevental makslevental force-pushed the eager_clean branch 4 times, most recently from 5aecdb6 to 8a0be06 on March 14, 2022 19:02
@makslevental makslevental force-pushed the eager_clean branch 5 times, most recently from 8bb719a to dadb46f on March 22, 2022 16:41

@silvasean silvasean left a comment

mostly nits.

@makslevental makslevental force-pushed the eager_clean branch 3 times, most recently from bbc8e27 to cf86e01 on March 22, 2022 19:05
@makslevental makslevental changed the title from "[TORCH][MLIR] Eager mode" to "Eager mode" on Mar 22, 2022
@makslevental makslevental requested a review from silvasean March 22, 2022 19:07
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend


class TorchMLIRTensor(torch.Tensor):
A reviewer (Contributor) commented:

Is there any upstream documentation describing the extension point you are using here (`_make_wrapper_subclass`, `__torch_dispatch__`, etc.)? It would be good to link it in if it exists.

@makslevental (Collaborator, Author) replied:

I spoke with Brian Hirsch and he said the best documentation for how to use wrapper subclasses is https://github.com/albanD/subclass_zoo. I can add it as a comment.

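As a point of reference, using such a wrapper subclass looks roughly like the following (illustrative only, reusing the hypothetical `WrapperTensorSketch` sketched earlier in this thread):

```python
# Hypothetical usage: wrap the inputs and run an ordinary nn.Module eagerly;
# every aten op the model executes is routed through __torch_dispatch__.
import torch

with torch.no_grad():
    model = torch.nn.Linear(4, 2)
    x = WrapperTensorSketch(torch.randn(3, 4))
    y = model(x)          # each op tries torch-mlir first, falls back to PyTorch
print(type(y).__name__)   # WrapperTensorSketch
```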
@silvasean silvasean merged commit fe8ac57 into llvm:main Mar 22, 2022