Fixes around Strategy.set_world_ranks (#16966)
* don't call set_world_ranks in xla strategy

* update

* fabric and other strategies

* CHANGELOG

* Typos

* Reuse test

---------

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
awaelchli and carmocca authored Apr 13, 2023
1 parent 17548d5 commit 50662eb
Showing 13 changed files with 50 additions and 49 deletions.
6 changes: 6 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
 
+
+- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
+
 ### Deprecated
 
 -
@@ -39,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed issue where running on TPUs would select the wrong device index ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
 
 
+- Fixed issue where Fabric would not initialize the global rank, world size, and rank-zero-only rank after initialization and before launch ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
+
+
 ## [2.0.1.post0] - 2023-04-11
 
 No changes
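
Editor's note: as an illustration of the Fabric fix listed above, here is a minimal sketch (not part of this commit, assuming the lightning 2.0 Fabric API with this change applied) showing that rank/world-size info is populated right after construction, before launch:

# Minimal sketch (assumes lightning with this fix applied): rank/world-size info is
# available right after Fabric() is constructed, before any processes are launched.
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", strategy="ddp_spawn", devices=2)

# Previously these still held the defaults (world_size=1, global_rank=0) until launch;
# with `_set_world_ranks()` called by the connector they already reflect the configuration.
print(fabric.world_size)   # 2, mirroring the `world_size == 2` assertion in the test below
print(fabric.global_rank)  # 0 in the main process, before launch
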
4 changes: 2 additions & 2 deletions src/lightning/fabric/connector.py
@@ -507,8 +507,8 @@ def _lazy_init_strategy(self) -> None:
                 self.strategy.parallel_devices = self._parallel_devices
         if hasattr(self.strategy, "num_nodes"):
             self.strategy._num_nodes = self._num_nodes_flag
-        if hasattr(self.strategy, "set_world_ranks"):
-            self.strategy.set_world_ranks()
+        if hasattr(self.strategy, "_set_world_ranks"):
+            self.strategy._set_world_ranks()
         self.strategy._configure_launcher()
 
         if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
6 changes: 3 additions & 3 deletions src/lightning/fabric/plugins/environments/lsf.py
@@ -88,7 +88,7 @@ def world_size(self) -> int:
         if world_size is None:
             raise ValueError(
                 "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
-                "Make sure you run your executable with `jsrun`."
+                " Make sure you run your executable with `jsrun`."
             )
         return int(world_size)
 
@@ -101,7 +101,7 @@ def global_rank(self) -> int:
         if global_rank is None:
             raise ValueError(
                 "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
-                "Make sure you run your executable with `jsrun`."
+                " Make sure you run your executable with `jsrun`."
             )
         return int(global_rank)
 
@@ -114,7 +114,7 @@ def local_rank(self) -> int:
         if local_rank is None:
             raise ValueError(
                 "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
-                "Make sure you run your executable with `jsrun`."
+                " Make sure you run your executable with `jsrun`."
            )
         return int(local_rank)
12 changes: 6 additions & 6 deletions src/lightning/fabric/strategies/ddp.py
@@ -177,7 +177,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
 
     def _setup_distributed(self) -> None:
         self._set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
@@ -186,11 +185,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def _set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
     def _determine_ddp_device_ids(self) -> Optional[List[int]]:
         if self.root_device.type == "cpu":
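
Editor's note: the comment introduced in `_set_world_ranks` above alludes to cluster environments whose rank setters are no-ops because a scheduler dictates the rank. A hypothetical, simplified sketch of that situation (the class and environment variable below are illustrative, not the actual Lightning implementations):

import os


class SchedulerManagedEnvironment:
    """Stand-in for a cluster environment whose rank is assigned externally by a scheduler."""

    def set_global_rank(self, rank: int) -> None:
        # No-op: the scheduler already fixed the rank, so the setter cannot be trusted
        # to have any effect. This is why the strategy reads the getter afterwards.
        pass

    def global_rank(self) -> int:
        return int(os.environ.get("EXAMPLE_RANK", "0"))  # illustrative variable name


env = SchedulerManagedEnvironment()
env.set_global_rank(5)    # silently ignored
print(env.global_rank())  # still the scheduler-provided value (0 here by default)
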
3 changes: 1 addition & 2 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -33,7 +33,7 @@
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
-from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
 
@@ -580,7 +580,6 @@ def _setup_distributed(self) -> None:
             )
         reset_seed()
         self._set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._init_deepspeed_distributed()
         if not self._config_initialized:
             self._format_config()
12 changes: 6 additions & 6 deletions src/lightning/fabric/strategies/fsdp.py
@@ -320,7 +320,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
     def _setup_distributed(self) -> None:
         reset_seed()
         self._set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
@@ -329,11 +328,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def _set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
 
 def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
6 changes: 0 additions & 6 deletions src/lightning/fabric/strategies/xla.py
@@ -93,7 +93,6 @@ def _configure_launcher(self) -> None:
 
     def setup_environment(self) -> None:
         self._launched = True
-        self._set_world_ranks()
         rank_zero_only.rank = self.global_rank
         super().setup_environment()
 
@@ -203,8 +202,3 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register("xla", cls, description=cls.__class__.__name__)
-
-    def _set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        rank_zero_only.rank = self.cluster_environment.global_rank()
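
Editor's note: to illustrate why the XLA strategy no longer touches rank state before launch, here is a hypothetical, simplified sketch (names are illustrative, not the Lightning implementation). Querying the real rank would initialize the PJRT computation client, so the strategy returns a safe default until it runs inside a spawned process:

class LazyRankStrategy:
    """Toy example of deferring rank queries until after processes are launched."""

    def __init__(self) -> None:
        self._launched = False  # flipped only inside the spawned worker processes

    @property
    def global_rank(self) -> int:
        if not self._launched:
            # Asking the runtime here (in the main process) would initialize the
            # XLA/PJRT client prematurely, so return a default instead.
            return 0
        return self._query_runtime_rank()

    def setup_environment(self) -> None:
        # Called inside each launched process, where querying the runtime is safe.
        self._launched = True
        print(f"running with rank {self.global_rank}")

    def _query_runtime_rank(self) -> int:
        return 0  # placeholder for the real runtime query (e.g. the XLA ordinal)
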
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
 
+
+- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
+
 ### Deprecated
 
 -
12 changes: 6 additions & 6 deletions src/lightning/pytorch/strategies/ddp.py
@@ -183,7 +183,6 @@ def setup_distributed(self) -> None:
         log.debug(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
         self.set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
@@ -192,11 +191,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
     def _register_ddp_hooks(self) -> None:
         log.debug(f"{self.__class__.__name__}: registering ddp hooks")
3 changes: 1 addition & 2 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -43,7 +43,7 @@
 from lightning.pytorch.utilities import GradClipAlgorithmType
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn, WarningCache
+from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
 from lightning.pytorch.utilities.types import LRSchedulerConfig, STEP_OUTPUT
 
 log = logging.getLogger(__name__)
@@ -326,7 +326,6 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
     def setup_distributed(self) -> None:
         reset_seed()
         self.set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._init_deepspeed_distributed()
         if not self._config_initialized:
             self._format_config()
14 changes: 6 additions & 8 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -178,9 +178,6 @@ def setup_environment(self) -> None:
         # determine which process we are and world size
         self.set_world_ranks()
 
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend)
@@ -190,11 +187,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
     def _configure_launcher(self) -> None:
         assert self.cluster_environment is not None
8 changes: 4 additions & 4 deletions src/lightning/pytorch/strategies/xla.py
@@ -201,13 +201,13 @@ def reduce(
 
     def setup_distributed(self) -> None:
         self._launched = True
-        self.set_world_ranks()
         rank_zero_only.rank = self.global_rank
 
     def set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        # accessing global_rank will initialize the XLA computation client. since this is called outside of the spawned
+        # processes (by the accelerator connector), we cannot run the code that would normally be here.
+        # instead it's done in `setup_distributed`
+        pass
 
     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         assert self.model is not None
10 changes: 6 additions & 4 deletions tests/tests_fabric/test_connector.py
@@ -128,6 +128,8 @@ def creates_processes_externally(self) -> bool:
     assert isinstance(connector.accelerator, CPUAccelerator)
     assert isinstance(connector.strategy, DDPStrategy)
     assert isinstance(connector.strategy.cluster_environment, CustomCluster)
+    # this checks that `strategy._set_world_ranks` was called by the connector
+    assert connector.strategy.world_size == 2
 
 
 @RunIf(mps=False)
@@ -230,10 +232,10 @@ class Strat(DDPStrategy):
 @mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
 def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_environment):
     with mock.patch.dict(os.environ, env_vars, clear=True):
-        trainer = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
-        assert isinstance(trainer.accelerator, CPUAccelerator)
-        assert isinstance(trainer.strategy, DDPStrategy)
-        assert isinstance(trainer.strategy.cluster_environment, expected_environment)
+        connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
+        assert isinstance(connector.accelerator, CPUAccelerator)
+        assert isinstance(connector.strategy, DDPStrategy)
+        assert isinstance(connector.strategy.cluster_environment, expected_environment)
 
 
 @RunIf(mps=False)
