
Commit

update
awaelchli committed Jun 24, 2022
1 parent 2f23530 commit b059f73
Showing 3 changed files with 52 additions and 48 deletions.
37 changes: 21 additions & 16 deletions src/pytorch_lightning/strategies/ddp_spawn.py
@@ -60,19 +60,20 @@ class DDPSpawnStrategy(ParallelStrategy):
def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
- # parallel_devices: Optional[List[torch.device]] = None,
+ parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
process_group_backend: Optional[str] = None,
+ start_method: str = "spawn",
**kwargs: Any,
):
super().__init__(
accelerator=accelerator,
- # parallel_devices=parallel_devices,
+ parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
@@ -84,6 +85,7 @@ def __init__(
self._ddp_comm_wrapper = ddp_comm_wrapper
self._local_rank = 0
self._process_group_backend: Optional[str] = process_group_backend
+ self._start_method = start_method

@property
def num_nodes(self) -> int:
@@ -100,11 +102,11 @@ def local_rank(self) -> int:

@property
def root_device(self):
- return torch.device("cuda", self.local_rank)
+ return self.parallel_devices[self.local_rank]

@property
def num_processes(self):
- return 2
+ return len(self.parallel_devices)

@property
def distributed_sampler_kwargs(self):
@@ -120,7 +122,7 @@ def process_group_backend(self) -> Optional[str]:
return self._process_group_backend

def _configure_launcher(self):
- self._launcher = _SpawnLauncher(self)
+ self._launcher = _SpawnLauncher(self, start_method=self._start_method)

def setup(self, trainer: "pl.Trainer") -> None:
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
@@ -270,17 +272,20 @@ def post_training_step(self):

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
- strategy_registry.register(
-     "ddp_spawn_find_unused_parameters_false",
-     cls,
-     description="DDPSpawn Strategy with `find_unused_parameters` as False",
-     find_unused_parameters=False,
- )
- strategy_registry.register(
-     cls.strategy_name,
-     cls,
-     description=f"{cls.__class__.__name__}",
- )
+ for start_method in ("spawn", "fork"):
+     strategy_registry.register(
+         f"ddp_{start_method}_find_unused_parameters_false",
+         cls,
+         description="DDPSpawn Strategy with `find_unused_parameters` as False",
+         find_unused_parameters=False,
+         start_method=start_method,
+     )
+     strategy_registry.register(
+         f"ddp_{start_method}",
+         cls,
+         description=f"{cls.__class__.__name__}",
+         start_method=start_method,
+     )

def teardown(self) -> None:
log.detail(f"{self.__class__.__name__}: tearing down strategy")
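
With the registration loop above, the registry now exposes both spawn- and fork-based aliases: `ddp_spawn`, `ddp_fork`, and their `_find_unused_parameters_false` variants. A minimal sketch of how the new names could be selected, assuming a Trainer from this era with the `accelerator`/`devices`/`strategy` flags (the flag combinations are illustrative and not part of this commit):

    import pytorch_lightning as pl

    # Fork-based DDP: the only start method usable inside Jupyter notebooks.
    trainer = pl.Trainer(accelerator="cpu", devices=2, strategy="ddp_fork")

    # Spawn-based DDP with `find_unused_parameters` disabled.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp_spawn_find_unused_parameters_false",
    )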
15 changes: 7 additions & 8 deletions src/pytorch_lightning/strategies/launchers/spawn.py
@@ -46,17 +46,16 @@ class _SpawnLauncher(_Launcher):
strategy: A reference to the strategy that is used together with this launcher.
"""

- def __init__(self, strategy: Strategy) -> None:
+ def __init__(self, strategy: Strategy, start_method: str = "spawn") -> None:
self._strategy = strategy
- self._start_method = "fork"
+ self._start_method = start_method

@property
def is_interactive_compatible(self) -> bool:
- # The start method 'spawn' is currently the only one that works with DDP and CUDA support
- # The start method 'fork' is the only one supported in Jupyter environments but not compatible with CUDA
- # For more context, see https://github.com/Lightning-AI/lightning/issues/7550
- # return self._start_method == "fork" and self._strategy.root_device.type != "cuda"
- return True
+ # The start method 'spawn' is not supported in interactive environments
+ # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA
+ # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550
+ return self._start_method == "fork"

def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
"""Spawns processes that run the given function in parallel.
@@ -81,7 +80,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
self._wrapping_function,
args=(trainer, function, args, kwargs, return_queue),
nprocs=self._strategy.num_processes,
start_method="fork",
start_method=self._start_method,
)
spawn_output = return_queue.get()
if trainer is None:
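
The launcher now honours whatever start method the owning strategy passes in, and it reports interactive compatibility only for "fork". A minimal sketch of the resulting behaviour, using only names visible in this diff (calling the private members directly is for illustration only):

    import torch
    from pytorch_lightning.strategies import DDPSpawnStrategy

    strategy = DDPSpawnStrategy(parallel_devices=[torch.device("cpu")] * 2, start_method="fork")
    strategy._configure_launcher()

    # The launcher inherits the strategy's start method ...
    assert strategy._launcher._start_method == "fork"
    # ... and only the fork-based launcher is flagged as usable in notebooks.
    assert strategy._launcher.is_interactive_compatible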
48 changes: 24 additions & 24 deletions src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -351,23 +351,23 @@ def _check_config_and_set_final_flags(
else:
self._cluster_environment_flag = getattr(self._strategy_flag, "cluster_environment")

- # if hasattr(self._strategy_flag, "parallel_devices"):
- #     if self._strategy_flag.parallel_devices:
- #         if self._strategy_flag.parallel_devices[0].type == "cpu":
- #             if self._accelerator_flag and self._accelerator_flag not in ("auto", "cpu"):
- #                 raise MisconfigurationException(
- #                     f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
- #                     f" but accelerator set to {self._accelerator_flag}, please choose one device type"
- #                 )
- #             self._accelerator_flag = "cpu"
- #         if self._strategy_flag.parallel_devices[0].type == "cuda":
- #             if self._accelerator_flag and self._accelerator_flag not in ("auto", "gpu"):
- #                 raise MisconfigurationException(
- #                     f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
- #                     f" but accelerator set to {self._accelerator_flag}, please choose one device type"
- #                 )
- #             self._accelerator_flag = "gpu"
- #         self._parallel_devices = self._strategy_flag.parallel_devices
+ if hasattr(self._strategy_flag, "parallel_devices"):
+     if self._strategy_flag.parallel_devices:
+         if self._strategy_flag.parallel_devices[0].type == "cpu":
+             if self._accelerator_flag and self._accelerator_flag not in ("auto", "cpu"):
+                 raise MisconfigurationException(
+                     f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
+                     f" but accelerator set to {self._accelerator_flag}, please choose one device type"
+                 )
+             self._accelerator_flag = "cpu"
+         if self._strategy_flag.parallel_devices[0].type == "cuda":
+             if self._accelerator_flag and self._accelerator_flag not in ("auto", "gpu"):
+                 raise MisconfigurationException(
+                     f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
+                     f" but accelerator set to {self._accelerator_flag}, please choose one device type"
+                 )
+             self._accelerator_flag = "gpu"
+         self._parallel_devices = self._strategy_flag.parallel_devices

amp_type = amp_type if isinstance(amp_type, str) else None
self._amp_type_flag = AMPType.from_str(amp_type)
@@ -521,8 +521,8 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
self._set_devices_flag_if_auto_select_gpus_passed()

self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
- # if not self._parallel_devices:
- #     self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)
+ if not self._parallel_devices:
+     self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)

def _set_devices_flag_if_auto_passed(self) -> None:
if self._devices_flag == "auto" or self._devices_flag is None:
@@ -758,11 +758,11 @@ def _lazy_init_strategy(self) -> None:
self.strategy.checkpoint_io = self.checkpoint_io
if hasattr(self.strategy, "cluster_environment"):
self.strategy.cluster_environment = self.cluster_environment
- # if hasattr(self.strategy, "parallel_devices"):
- #     if self.strategy.parallel_devices:
- #         self._parallel_devices = self.strategy.parallel_devices
- #     else:
- #         self.strategy.parallel_devices = self._parallel_devices
+ if hasattr(self.strategy, "parallel_devices"):
+     if self.strategy.parallel_devices:
+         self._parallel_devices = self.strategy.parallel_devices
+     else:
+         self.strategy.parallel_devices = self._parallel_devices
if hasattr(self.strategy, "num_nodes"):
self.strategy._num_nodes = self._num_nodes_flag
if hasattr(self.strategy, "_layer_sync"):
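
Re-enabling these blocks restores the two-way sync between a user-provided strategy and the accelerator connector: devices passed through the strategy set the accelerator flag and device list, and conversely the connector fills in `parallel_devices` on a strategy that was created without them. A minimal sketch of the first direction, assuming the surrounding Trainer wiring on this branch stays as shown above:

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPSpawnStrategy

    # Devices supplied on the strategy object; per the re-enabled block, the
    # connector should infer accelerator="cpu" and adopt this device list.
    strategy = DDPSpawnStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])
    trainer = Trainer(strategy=strategy)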
