The _setup_model method of DDPStrategy triggers the exception on non-CUDA devices because torch.cuda.stream is hardcoded whenever device_ids are passed. I've reproduced the snippet below, but here is a permalink.
```python
@override
def _setup_model(self, model: Module) -> DistributedDataParallel:
    """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
    device_ids = self.determine_ddp_device_ids()
    log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
    # https://pytorch.org/docs/stable/notes/cuda.html#id5
    ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
    with ctx:
        return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
```
A potential fix would be to check the target device type, or simply to add torch.cuda.is_available() to the condition. Removing the torch.cuda.Stream() call and always using nullcontext() makes everything else work perfectly fine.
The reproduction snippet below relies on an XPUAccelerator registered here, but I would expect other non-CUDA accelerators to hit the same error.
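A minimal sketch of the device-type check suggested above, written as a standalone helper so it can be exercised without a GPU. The name select_ddp_setup_context and its parameters are hypothetical and not part of Lightning. Inside DDPStrategy, the device type would presumably come from something like self.root_device.type, which is an assumption here.

```python
from contextlib import AbstractContextManager, nullcontext
from typing import Callable, Optional, Sequence


def select_ddp_setup_context(
    device_ids: Optional[Sequence[int]],
    device_type: str,
    make_cuda_stream_ctx: Callable[[], AbstractContextManager],
) -> AbstractContextManager:
    """Hypothetical helper: pick the context to enter while wrapping the model in DDP.

    A side CUDA stream is only used when DDP receives device_ids AND the
    target device is a CUDA device. For any other device type (e.g. "xpu"),
    a no-op nullcontext() is returned instead.
    """
    if device_ids is not None and device_type == "cuda":
        # The factory is called lazily, so torch.cuda.Stream() is never
        # constructed when the target device does not use the torch.cuda API.
        return make_cuda_stream_ctx()
    return nullcontext()


# Assumed usage inside a patched _setup_model (not the library's actual code):
# ctx = select_ddp_setup_context(
#     device_ids,
#     self.root_device.type,
#     lambda: torch.cuda.stream(torch.cuda.Stream()),
# )
# with ctx:
#     return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
```

Passing the stream as a factory rather than a ready-made object matters: building torch.cuda.Stream() eagerly is exactly what fails on non-CUDA devices. The simpler torch.cuda.is_available() condition would also avoid the crash on XPU-only machines, but on a machine that has both CUDA and another accelerator it would still enter a CUDA stream context while DDP targets the other device.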
Bug description
When configuring a DDPStrategy with multiple devices that do not use the torch.cuda API, we trigger an exception in DDPStrategy._setup_model.
What version are you seeing the problem on?
v2.1, v2.2
How to reproduce the bug
Error messages and logs
Environment
Current environment
More info
No response
cc @justusschock @awaelchli