Modifying PjRt device at runtime doesn't work. #5942

Closed
ysiraichi opened this issue Nov 29, 2023 · 6 comments

@ysiraichi
Collaborator

🐛 Bug

I'm not sure whether this is a bug or intended behavior, but once the PjRt client is initialized with a device, apparently we can't change it. If we try, PyTorch/XLA doesn't complain and executes everything on the device it was initialized with.

```python
import os

import torch
import torch_xla.core.xla_model as xm

for dev in ("CPU", "CUDA", "CPU"):
    os.environ["PJRT_DEVICE"] = dev
    print(f"Supported devices for {dev}:", xm.get_xla_supported_devices(devkind=dev))
    device = xm.xla_device()
    a = torch.rand(5, 5, device=device)
    r = a @ a
    xm.mark_step()
```

```
Supported devices for CPU: ['xla:0']  # executed on CPU:0
Supported devices for CUDA: None      # executed on CPU:0 (no error or warning)
Supported devices for CPU: ['xla:0']  # executed on CPU:0
```

Expected behavior

Either issue a warning (or, even better, an error), or allow changing devices at runtime.

Environment

Additional context

This came up while trying to use the recently upstreamed benchmark. The function `is_xla_device_available` (in `benchmarks/util.py`) is called for each enabled accelerator.
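One way to sidestep this in a benchmark harness is to probe each accelerator in a separate Python process, so every probe starts with a freshly initialized runtime. This is a minimal sketch under assumptions, not what `benchmarks/util.py` actually does: `probe_device` and `PROBE` are hypothetical names, and the probe script only echoes `PJRT_DEVICE` so the sketch runs without `torch_xla`. A real probe would import `torch_xla` and query the device instead.

```python
import os
import subprocess
import sys

# Hypothetical probe script run in a fresh interpreter. A real one would
# import torch_xla and query the device; this one just echoes PJRT_DEVICE.
PROBE = "import os; print(os.environ['PJRT_DEVICE'])"


def probe_device(dev: str, script: str = PROBE) -> str:
    """Run `script` in a child process with PJRT_DEVICE=dev; return its stdout.

    Each child initializes its own runtime, so a client already created in
    the parent process is never reused.
    """
    env = dict(os.environ, PJRT_DEVICE=dev)
    result = subprocess.run(
        [sys.executable, "-c", script],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


if __name__ == "__main__":
    for dev in ("CPU", "CUDA", "CPU"):
        print(dev, "->", probe_device(dev))
```

The trade-off is process start-up cost per probe, in exchange for never depending on environment variables being re-read inside one process.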

@ysiraichi
Collaborator Author

ysiraichi commented Nov 29, 2023

cc @JackCaoG @miladm

@JackCaoG
Collaborator

I think this is intended: we don't really expect users to switch the PJRT device after the program has been initialized. In fact, we don't expect users to change most (if not all) environment variables after the program starts, since most of them are stored in static variables in C++.

@ysiraichi
Collaborator Author

Then, maybe it would be better to error out, wouldn't it?

@JackCaoG
Collaborator

Hmm, that means we would need to check PJRT_DEVICE every time some variant of xla_device is called. From this perspective, using an environment variable is annoying. @will-cromar, any thoughts?
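A minimal sketch of the per-call check described above, assuming a simulated client rather than PyTorch/XLA's real one: every call compares the current `PJRT_DEVICE` against the device recorded at initialization and raises if they differ. The overhead is one environment lookup per call. `xla_device_type` and `DeviceChangedError` are hypothetical names.

```python
import os
from typing import Optional

# Device the simulated client was initialized with (None until first use).
_initialized_device: Optional[str] = None


class DeviceChangedError(RuntimeError):
    """Raised when PJRT_DEVICE no longer matches the initialized client."""


def xla_device_type() -> str:
    """Hypothetical xla_device() variant that validates PJRT_DEVICE on every call."""
    global _initialized_device
    requested = os.environ.get("PJRT_DEVICE", "CPU")
    if _initialized_device is None:
        # First call: "initialize" the client for the requested device.
        _initialized_device = requested
    elif requested != _initialized_device:
        raise DeviceChangedError(
            f"PJRT_DEVICE changed from {_initialized_device!r} to {requested!r} "
            "after the runtime was initialized"
        )
    return _initialized_device
```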

@ysiraichi
Collaborator Author

Ouch. What about having an API for initializing the PjRt client with a given PjRt device? It would make "doing the wrong thing" (i.e. changing the device) very hard.

@will-cromar
Collaborator

> Ouch. What about having an API for initializing the PjRt client with a given PjRt device? It would make "doing the wrong thing" (i.e. changing the device) very hard.

I like that idea. We can add a warning/error to runtime.set_device_type if the runtime is already initialized and discourage using the environment variables directly within code.
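A minimal sketch of that proposal, under stated assumptions: `runtime.set_device_type` is simulated here rather than imported, `init_runtime` is a hypothetical stand-in for client initialization, and the `strict` flag (choosing between an error and a warning) is a hypothetical addition, not an existing parameter.

```python
import os
import warnings
from typing import Optional

# Device frozen by the simulated runtime once it is initialized.
_runtime_device: Optional[str] = None


def set_device_type(device: str, strict: bool = True) -> None:
    """Hypothetical guarded variant of runtime.set_device_type."""
    if _runtime_device is not None and device != _runtime_device:
        msg = (
            f"runtime already initialized with {_runtime_device!r}; "
            f"cannot switch to {device!r}"
        )
        if strict:
            raise RuntimeError(msg)
        warnings.warn(msg)
        return  # leave PJRT_DEVICE unchanged
    os.environ["PJRT_DEVICE"] = device


def init_runtime() -> str:
    """Simulated client initialization: freeze the device type on first call."""
    global _runtime_device
    if _runtime_device is None:
        _runtime_device = os.environ.get("PJRT_DEVICE", "CPU")
    return _runtime_device
```

Callers would select the device once through the function before first use, instead of writing `os.environ` directly, so a later conflicting request fails loudly rather than being ignored.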
