add set local device
rgao committed Dec 4, 2024
1 parent e11e78e commit b8c3905
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/fairchem/core/common/distutils.py
@@ -80,7 +80,8 @@ def setup(config) -> None:
             assign_device_for_local_rank(config["cpu"], config["local_rank"])
         else:
             # in the old code, all ranks can see all devices but need to be assigned a device equal to their local rank
-            # this is dangerous and should be deprecated
+            # this is dangerous and should be deprecated, however, FSDP still requires backwards compatibility with
+            # initializing this way for now so we need to keep it
             torch.cuda.set_device(config["local_rank"])
 
         dist.init_process_group(
@@ -123,6 +124,11 @@ def setup(config) -> None:
         config["local_rank"] = int(os.environ.get("LOCAL_RANK"))
         if config.get("use_cuda_visibile_devices"):
             assign_device_for_local_rank(config["cpu"], config["local_rank"])
+        else:
+            # in the old code, all ranks can see all devices but need to be assigned a device equal to their local rank
+            # this is dangerous and should be deprecated, however, FSDP still requires backwards compatibility with
+            # initializing this way for now so we need to keep it
+            torch.cuda.set_device(config["local_rank"])
         dist.init_process_group(
             backend=config["distributed_backend"],
             rank=int(os.environ.get("RANK")),
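For context, here is a minimal, self-contained sketch of the device-assignment pattern this commit preserves in both code paths. The config keys follow the diff above (including the `use_cuda_visibile_devices` spelling as it appears in the source); the `assign_device_for_local_rank` stub stands in for fairchem's helper of the same name, and the `set_local_device` wrapper is a hypothetical name taken from the commit title. Only `torch.cuda.set_device` and the `LOCAL_RANK` environment variable are standard PyTorch distributed convention:

    import os
    import torch

    def assign_device_for_local_rank(cpu: bool, local_rank: int) -> None:
        # Stub standing in for fairchem's helper of the same name, which
        # narrows device visibility so each rank only sees its own GPU.
        ...

    def set_local_device(config: dict) -> None:
        # Hypothetical wrapper, named after the commit title, illustrating
        # the two initialization styles the diff above chooses between.
        config["local_rank"] = int(os.environ.get("LOCAL_RANK", "0"))
        if config.get("use_cuda_visibile_devices"):  # key spelled as in the source
            # Preferred path: each rank is restricted to a single visible
            # device, so device index 0 is always this process's GPU.
            assign_device_for_local_rank(config["cpu"], config["local_rank"])
        else:
            # Legacy path kept for FSDP compatibility: every rank sees all
            # devices and must explicitly bind to the device whose index
            # matches its local rank before init_process_group() is called.
            torch.cuda.set_device(config["local_rank"])

The legacy branch is the one the comment flags as dangerous: binding by index assumes ranks and GPUs line up one-to-one on each node, whereas the visible-devices path makes that mapping explicit per process.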
