Make CUDA OpenXLA fallback the default. #7630
Conversation
- Please add XLA_FALLBACK_CPU to https://github.com/pytorch/xla/blob/master/configuration.yaml and explain its current/latest function.
- Have you tested this PR on your full torchbench runs to verify functionality?

Approving to unblock - let's confirm the above items, please.
@@ -11,6 +11,8 @@ static void fail(const char* name) {

namespace c10::cuda {

DeviceIndex device_count() noexcept { return 0; }
IIRC, we unit test this function in the Python layer - do we need a unit test for the C++ layer?
Do we? Can you point me to the source location? These functions are supposed to be implemented by PyTorch when it is compiled with CUDA support. This file provides an implementation for the case when PyTorch is compiled without CUDA support; it needs to be supplied because we would otherwise get an undefined reference.
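For context, a minimal sketch of what such a stub might look like when PyTorch is built without CUDA support (the include and exact file contents here are assumptions for illustration, not the literal code added by this PR):

#include <c10/core/Device.h>

namespace c10::cuda {

// Phony implementation used only when PyTorch is compiled without CUDA
// support: report zero devices so callers see "no CUDA available" instead of
// the build failing with an undefined reference to device_count().
DeviceIndex device_count() noexcept { return 0; }

}  // namespace c10::cuda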
@@ -55,6 +55,10 @@ std::string DeviceType::toString() const {
  return absl::StrCat(type_name_, ":");
}

XlaDeviceType DeviceType::getType() const {
ditto
Do we need a unit test for this function? It only casts an integer into an enum.
cc @zpcore
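As a rough, self-contained sketch of what a cast-only accessor like this looks like (the enum values and the member name type_ below are placeholders, not the actual torch_xla definitions):

#include <cstdint>

// Simplified stand-in types, only to illustrate the "cast an integer into an
// enum" shape of the accessor under discussion.
enum class XlaDeviceType : int8_t { CPU = 0, CUDA = 1, TPU = 2 };

struct DeviceType {
  int8_t type_;
  XlaDeviceType getType() const { return static_cast<XlaDeviceType>(type_); }
};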
@@ -51,8 +52,33 @@ std::vector<std::string> GetFallbackOperations() {
// Before each modified function below, we shall specify what has changed,
// if there was any.
Nit (we can fix in a follow-up): we should rename this file to aten_fallback.cpp.
Thanks, looks like the fallback is enabled by default. We should see the performance update tomorrow.
As @miladm pointed out, I will still run benchmarks with this PR. So, this might get landed only tomorrow.
Actually, I think I will wait for #7647. It looks like a relevant issue.
torch_xla/csrc/aten_cpu_fallback.cpp (outdated)
bool UseCUDAFallback() {
  return runtime::sys_util::GetEnvBool("XLA_FALLBACK_CUDA", false);

// Decide whether to run OpenXLA fallback operations on CUDA.
bool UseCUDAFallback(const c10::OperatorHandle& op) {
Since we have another type of CUDA fallback, how about the name "OpenXLAFallbackOnCUDA"?
@@ -11,6 +11,8 @@ static void fail(const char* name) {

namespace c10::cuda {

DeviceIndex device_count() noexcept { return 0; }
It returns 0 because it's a phony implementation?
Also, could you add a comment in this file describing its phony implementation and what the purpose of this file is?
LGTM with minor comments
Force-pushed from 462a877 to 34a3ac1.
// List of operations that should be fallbacked to CPU instead of GPU.
static std::unordered_set<std::string> _force_fallback_on_cpu{
    // This operation is a simple memory access that transforms the given
    // 1-element tensor into a Scalar.
    //
    // Although it makes sense to run this operation on CPU (since the
    // output will get copied back to CPU anyway), this also fixes a
    // particular issue with moco benchmark.
    // More details: https://github.com/pytorch/xla/issues/7647
    "aten::_local_scalar_dense",
};
Just to be completely transparent: aten::_local_scalar_dense is here for 2 reasons:

- It just makes sense to run it on CPU, since the output (a Scalar) also lives on CPU.
- As a temporary fix to "[torchbench] moco fails to run with CUDA OpenXLA fallback" #7647.
@miladm @JackCaoG @vanbasten23 @zpcore In case you want to take another look at this PR, I added the following changes:
I will merge this one tomorrow.
Partially fix: #7342

This PR changes the default device for running OpenXLA fallback operations from CPU to CUDA. So, instead of specifying XLA_FALLBACK_CUDA=1, in order to run fallback operations on CUDA the user must make sure that the following is true:

- XLA_FALLBACK_CPU (newly introduced) is not set
- DeviceType of the current ComputationClient is CUDA

I have also changed the test_fallback function so that it won't (un)set the XLA_FALLBACK_CPU flag. Instead, it just runs the fallback operation. From this PR onwards, PyTorch/XLA should be able to detect automatically when it can't do CUDA OpenXLA fallback.

That said, one can force CPU fallback execution with the newly introduced XLA_FALLBACK_CPU environment variable.

cc @miladm @vanbasten23 @JackCaoG