
Remove or simplify hardcoded lists of device types #6235

Merged: 18 commits into master on Jan 10, 2024

Conversation


will-cromar (Collaborator) commented Dec 26, 2023

Allows new device plugins without registering the device type in our code! See #6242 for broader context of this cleanup.

  • Add placeholder PLUGIN device type for unknown devices
    • Store the actual device name (according to PjRtClient::platform_name) in the DeviceType, since we still need to pass the device name string across Python/C++ boundary
    • Replace explicit lists of devices with patterns
  • Deprecate devkind argument, since we only support one PJRT backend at a time
  • Remove explicit references to device types that do not have any special cases (XPU, ROCM)
  • Remove deprecated GPU device type
  • Simplify several test skip conditions that had lists of device types

Tested with CPU plugin in #6253 where I change the platform name to TEST to simulate an unknown device type. test_operations.py passes with that plugin.
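The placeholder scheme described above can be sketched in Python (hypothetical names; the actual implementation lives in the C++ runtime): known backends keep their enum value, while any unrecognized `platform_name` maps to a `PLUGIN` placeholder but retains its real name string.

```python
from enum import Enum

class XlaDeviceType(Enum):
    CPU = 'CPU'
    TPU = 'TPU'
    CUDA = 'CUDA'
    PLUGIN = 'PLUGIN'

class DeviceType:
    """Hypothetical sketch: unknown platforms become PLUGIN,
    but the actual platform name string is preserved."""

    def __init__(self, platform_name: str):
        name = platform_name.upper()
        try:
            self.type = XlaDeviceType(name)
        except ValueError:
            self.type = XlaDeviceType.PLUGIN  # unknown device plugin
        self.type_name = name  # real name survives for round-tripping

d = DeviceType('test')        # simulated unknown device, as in #6253
print(d.type.name, d.type_name)  # PLUGIN TEST
```

This mirrors how the CPU plugin test renames its platform to TEST: the device still works end to end because only the name string crosses the Python/C++ boundary, not a member of a hardcoded list.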

will-cromar changed the title from "[WIP] Remove or simplify hardcoded lists of device types" to "Remove or simplify hardcoded lists of device types" on Jan 9, 2024
will-cromar marked this pull request as ready for review January 9, 2024 00:21
@@ -71,6 +71,11 @@ class IfrtComputationClient : public ComputationClient {

std::string GetDefaultDevice() const override;

torch_xla::DeviceType GetDeviceType() const override {
  return torch_xla::DeviceType(
      absl::AsciiStrToUpper(client_->platform_name()));
}
Collaborator:
The platform_name will be something like CUDA or TPU?

will-cromar (Collaborator Author):

Yeah

vanbasten23 (Collaborator) left a comment:

LGTM. Thanks for the simplification!

@@ -72,7 +72,7 @@ def is_xla_tensor(tensor):


def parse_xla_device(device):
-    m = re.match(r'(CPU|TPU|GPU|ROCM|CUDA|XPU|NEURON):(\d+)$', device)
+    m = re.match(r'([A-Z]+):(\d+)$', device)
Collaborator:

Trying to think: where would it fail if the user has a typo in PJRT_DEVICE?

will-cromar (Collaborator Author):

The device name here comes from the PjRtClient::platform_name, which is not guaranteed to match the PJRT_DEVICE name.

If there's a typo in PJRT_DEVICE, the runtime will fail to initialize and throw an error.
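For illustration, the relaxed pattern accepts any uppercase platform name followed by an ordinal, so a plugin-provided name parses the same way as a built-in one (a minimal standalone sketch of the pattern from the diff above):

```python
import re

# Relaxed device pattern from this PR: any uppercase platform name
# followed by ':' and an ordinal, e.g. "TPU:0" or a plugin's "TEST:1".
_DEVICE_RE = re.compile(r'([A-Z]+):(\d+)$')

def parse_xla_device(device: str):
    """Return (device_type, ordinal), or None if the string does not match."""
    m = _DEVICE_RE.match(device)
    if m is None:
        return None
    return m.group(1), int(m.group(2))

print(parse_xla_device('TPU:0'))   # ('TPU', 0)
print(parse_xla_device('TEST:3'))  # ('TEST', 3)
print(parse_xla_device('tpu:0'))   # None: lowercase does not match
```

Note the pattern no longer rejects unknown names; as discussed above, a bad PJRT_DEVICE value fails earlier, at runtime initialization, not here.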

} // namespace

std::string DeviceType::XlaDeviceTypeToString(XlaDeviceType hw_type) {
XLA_CHECK(hw_type != XlaDeviceType::PLUGIN) << "PLUGIN type name unknown";
Collaborator:
When will this function be called? Are we not expecting this to be called on a plugin?

will-cromar (Collaborator Author):

This function is used when constructing a device from the XlaDeviceType enum type rather than the device name in a string. E.g. here:

default_device_type_ =
std::make_shared<DeviceType>(static_cast<XlaDeviceType>(type));

If all we get is the placeholder XlaDeviceType::PLUGIN type, then we don't know the actual platform_name of the device type for toString. I think this is a relatively rare case in our code, so I can try to remove it. I have a more ambitious refactoring effort going in #6261.
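The failure mode being guarded against can be sketched in Python (a hypothetical mirror of the C++ check, not the actual implementation): once only the bare enum value is available, the PLUGIN placeholder carries no platform name to stringify.

```python
class PluginNameUnknown(Exception):
    pass

# Concrete backends map enum -> name trivially; PLUGIN cannot,
# because the real platform name was only ever held as a string.
_ENUM_TO_NAME = {'CPU': 'CPU', 'TPU': 'TPU', 'CUDA': 'CUDA'}

def xla_device_type_to_string(hw_type: str) -> str:
    if hw_type == 'PLUGIN':
        raise PluginNameUnknown('PLUGIN type name unknown')
    return _ENUM_TO_NAME[hw_type]

print(xla_device_type_to_string('TPU'))  # TPU
```

This is why the enum-to-DeviceType construction path is the rare case worth removing in the follow-up refactor mentioned above.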

type_name_(type_name) {}

// TODO(wcromar): do we even want this default constructor?
DeviceType() : DeviceType(XlaDeviceType::CPU){};
Collaborator:

I feel like we should eventually let PyTorch/XLA auto-detect the device based on some rules (libtpu, CUDA version, etc.). I can't really think of a case where we would want to default-construct a CPU device...

will-cromar (Collaborator Author):

I'm 90% sure this constructor is never called. Logically, the "default" device type is managed by xla_backend_impl.

will-cromar (Collaborator Author):

Actually, we definitely don't need this. Upstream provides a better default constructor: https://github.com/pytorch/pytorch/blob/16d69290c6d037a25e32220b9517597d04dbd0bf/torch/csrc/lazy/backend/backend_device.cpp#L14-L15

I'm just going to delete this now instead of leaving the TODO.

@will-cromar will-cromar merged commit 050a240 into master Jan 10, 2024
20 checks passed