Skip to content

Commit

Permalink
Introduce CUDA OpenXLA fallback. (#7318)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored Jul 3, 2024
1 parent 172097b commit c782e0d
Show file tree
Hide file tree
Showing 14 changed files with 1,037 additions and 513 deletions.
16 changes: 16 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ cc_binary(
]),
)

cc_binary(
name = "_XLAC_cuda_functions.so",
copts = [
"-fopenmp",
"-fPIC",
],
linkopts = [
"-Wl,-soname,_XLAC_cuda_functions.so",
],
linkshared = 1,
visibility = ["//visibility:public"],
deps = [
"//torch_xla/csrc:aten_cuda_functions",
],
)

test_suite(
name = "cpp_tests",
# testonly = True,
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def run(self):
packages=find_packages(include=['torch_xla*']),
ext_modules=[
BazelExtension('//:_XLAC.so'),
BazelExtension('//:_XLAC_cuda_functions.so'),
],
install_requires=[
'absl-py>=1.0.0',
Expand Down
6 changes: 6 additions & 0 deletions test/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ ptxla_cc_test(
":cpp_test_util",
":torch_xla_test",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
],
)
Expand All @@ -63,6 +64,7 @@ ptxla_cc_test(
deps = [
":torch_xla_test",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
"@xla//xla:shape_util",
],
Expand All @@ -77,6 +79,7 @@ ptxla_cc_test(
"//torch_xla/csrc/runtime:runtime",
"//torch_xla/csrc/runtime:debug_macros",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"//torch_xla/csrc:thread_pool",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest_main",
Expand All @@ -95,6 +98,7 @@ ptxla_cc_test(
":cpp_test_util",
":torch_xla_test",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
],
)
Expand All @@ -119,6 +123,7 @@ ptxla_cc_test(
"//torch_xla/csrc/runtime:env_vars",
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
"@xla//xla:xla_data_proto_cc",
"@tsl//tsl/profiler/utils:session_manager",
Expand All @@ -137,6 +142,7 @@ ptxla_cc_test(
":torch_xla_test",
"//torch_xla/csrc/runtime:metrics",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
"@xla//xla:permutation_util",
],
Expand Down
Loading

0 comments on commit c782e0d

Please sign in to comment.