Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce CUDA OpenXLA fallback. #7318

Merged
merged 27 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
091fb6b
Initial implementation.
ysiraichi Apr 4, 2024
1c8b479
Replace usage of `xla_cpu_fallback`.
ysiraichi Jun 6, 2024
a908ed4
Remove 'fallback' suffix.
ysiraichi Jun 11, 2024
2d4bb59
Scratch
ysiraichi Jun 17, 2024
aae7197
Deal with undefined tensors + bring back step 3.
ysiraichi Jun 18, 2024
bbb86a8
Refactor test.
ysiraichi Jun 18, 2024
540275a
Refactor fallback.
ysiraichi Jun 19, 2024
868333a
Refactor tests.
ysiraichi Jun 20, 2024
e9fb89d
Fix lint issues.
ysiraichi Jun 20, 2024
e1e1da1
Forward declare CUDA functions.
ysiraichi Jun 20, 2024
5413a74
Lazy initialize CUDA.
ysiraichi Jun 20, 2024
1304dc7
Fix lint.
ysiraichi Jun 20, 2024
e4642f9
Fix compilation.
ysiraichi Jun 20, 2024
2f32ba8
Address reviews.
ysiraichi Jun 21, 2024
8c0dddd
Create fallback implementation for tests.
ysiraichi Jun 24, 2024
9c2e28b
Fix lint issue.
ysiraichi Jun 24, 2024
ea33a74
Fix test builds and add `libc10_cuda.so` dependency.
ysiraichi Jun 24, 2024
0ef4e1b
Fix C++ test dependencies.
ysiraichi Jun 24, 2024
d035935
Build a second library for conditionally loading CUDA functions.
ysiraichi Jun 25, 2024
4129e86
Add Python module initialization.
ysiraichi Jun 25, 2024
9a22be7
Fix lib and `dlopen` options.
ysiraichi Jun 27, 2024
7d5063a
Fix lint issues.
ysiraichi Jun 27, 2024
75ffb01
Skip CUDA fallback tests on CPU envs.
ysiraichi Jun 27, 2024
0dd29da
Fix lint issues.
ysiraichi Jun 27, 2024
2c2b165
Clean up.
ysiraichi Jun 27, 2024
4686238
Make sure only one XLA device is used.
ysiraichi Jul 2, 2024
8d5fe8f
Fix compilation issues.
ysiraichi Jul 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ cc_binary(
]),
)

cc_binary(
name = "_XLAC_cuda_functions.so",
copts = [
"-fopenmp",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder how the copts and linkopts are determined

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I just copied them from _XLAC. I guess I could get rid of them, though. What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, if it works now, feel free to keep it :p

"-fPIC",
],
linkopts = [
"-Wl,-soname,_XLAC_cuda_functions.so",
],
linkshared = 1,
visibility = ["//visibility:public"],
deps = [
"//torch_xla/csrc:aten_cuda_functions",
],
)

test_suite(
name = "cpp_tests",
# testonly = True,
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def run(self):
packages=find_packages(include=['torch_xla*']),
ext_modules=[
BazelExtension('//:_XLAC.so'),
BazelExtension('//:_XLAC_cuda_functions.so'),
],
install_requires=[
'absl-py>=1.0.0',
Expand Down
6 changes: 6 additions & 0 deletions test/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ ptxla_cc_test(
":cpp_test_util",
":torch_xla_test",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
],
)
Expand All @@ -63,6 +64,7 @@ ptxla_cc_test(
deps = [
":torch_xla_test",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
"@xla//xla:shape_util",
],
Expand All @@ -77,6 +79,7 @@ ptxla_cc_test(
"//torch_xla/csrc/runtime:runtime",
"//torch_xla/csrc/runtime:debug_macros",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"//torch_xla/csrc:thread_pool",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest_main",
Expand All @@ -95,6 +98,7 @@ ptxla_cc_test(
":cpp_test_util",
":torch_xla_test",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
],
)
Expand All @@ -119,6 +123,7 @@ ptxla_cc_test(
"//torch_xla/csrc/runtime:env_vars",
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
"@xla//xla:xla_data_proto_cc",
"@tsl//tsl/profiler/utils:session_manager",
Expand All @@ -137,6 +142,7 @@ ptxla_cc_test(
":torch_xla_test",
"//torch_xla/csrc/runtime:metrics",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using the fallback implementation for solving the undefined references in C++ tests. I think this should be reasonable, since we don't test fallback on C++ tests.

"@com_google_googletest//:gtest_main",
"@xla//xla:permutation_util",
],
Expand Down
Loading
Loading