Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref: device parser #3400

Merged
merged 9 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,4 @@ Indices and tables
api/pytorch_lightning.profiler
api/pytorch_lightning.trainer
api/pytorch_lightning.utilities
api/pytorch_lightning.tuner
270 changes: 0 additions & 270 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,16 @@

"""

from contextlib import ExitStack
import os
from abc import ABC, abstractmethod
import time
import random
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import move_data_to_device, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only

try:
from apex import amp
except ImportError:
amp = None

try:
import torch_xla.core.xla_model as xm
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True

try:
import horovod.torch as hvd
Expand Down Expand Up @@ -82,30 +62,10 @@ class TrainerDPMixin(ABC):
logger: ...
amp_backend: AMPType

@abstractmethod
def call_setup_hook(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def init_optimizers(self, *args) -> Tuple[List, List, List]:
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def get_model(self) -> LightningModule:
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def reinit_scheduler_properties(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def setup(self, *args) -> None:
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def is_function_implemented(self, *args) -> bool:
"""Warning: this is just empty shell for code implemented in other class."""

def copy_trainer_model_properties(self, model):
if isinstance(model, LightningDataParallel):
ref_model = model.module
Expand Down Expand Up @@ -152,233 +112,3 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
if model is not None:
return model.transfer_batch_to_device(batch, device)
return move_data_to_device(batch, device)


def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
if isinstance(s, str):
if s == '-1':
return -1
else:
return [int(x.strip()) for x in s.split(',') if len(x) > 0]
else:
return s


def get_all_available_gpus() -> List[int]:
"""
Returns:
a list of all available gpus
"""
return list(range(torch.cuda.device_count()))


def _check_data_type(device_ids: Any) -> None:
"""
Checks that the device_ids argument is one of: None, Int, String or List.
Raises a MisconfigurationException otherwise.

Args:
device_ids: gpus/tpu_cores parameter as passed to the Trainer
"""
if device_ids is not None and (not isinstance(device_ids, (int, str, MutableSequence)) or isinstance(device_ids, bool)):
raise MisconfigurationException("Device ID's (GPU/TPU) must be int, string or sequence of ints or None.")


def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int]]) -> Optional[List[int]]:
assert gpus is not None
if isinstance(gpus, MutableSequence):
return list(gpus)

# must be an int
if not gpus: # gpus==0
return None
if gpus == -1:
return get_all_available_gpus()

return list(range(gpus))


def sanitize_gpu_ids(gpus: List[int]) -> List[int]:
"""
Checks that each of the GPUs in the list is actually available.
Raises a MisconfigurationException if any of the GPUs is not available.

Args:
gpus: list of ints corresponding to GPU indices

Returns:
unmodified gpus variable
"""
all_available_gpus = get_all_available_gpus()
misconfig = False
for gpu in gpus:
if gpu not in all_available_gpus:
misconfig = True

if misconfig:
# sometimes auto ddp might have different flags
# but this is not what the user intended
# correct for the user
if len(gpus) == len(all_available_gpus):
gpus = all_available_gpus
else:
raise MisconfigurationException(f"""
You requested GPUs: {gpus}
But your machine only has: {all_available_gpus}
""")
return gpus


def _parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
"""
Parses the GPU ids given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer`.

Args:
gpus: An int -1 or string '-1' indicate that all available GPUs should be used.
A list of ints or a string containing list of comma separated integers
indicates specific GPUs to use.
An int 0 means that no GPUs should be used.
Any int N > 0 indicates that GPUs [0..N) should be used.

Returns:
a list of gpus to be used or ``None`` if no GPUs were requested

If no GPUs are available but the value of gpus variable indicates request for GPUs
then a MisconfigurationException is raised.
"""

# nothing was passed into the GPUs argument
if callable(gpus):
return None

# Check that gpus param is None, Int, String or List
_check_data_type(gpus)

# Handle the case when no gpus are requested
if gpus is None or isinstance(gpus, int) and gpus == 0:
return None

# We know user requested GPUs therefore if some of the
# requested GPUs are not available an exception is thrown.

gpus = _normalize_parse_gpu_string_input(gpus)
gpus = _normalize_parse_gpu_input_to_list(gpus)
if not gpus:
raise MisconfigurationException("GPUs requested but none are available.")
gpus = sanitize_gpu_ids(gpus)

return gpus


def determine_root_gpu_device(gpus: List[int]) -> Optional[int]:
"""
Args:
gpus: non-empty list of ints representing which gpus to use

Returns:
designated root GPU device id
"""
if gpus is None:
return None

assert isinstance(gpus, list), "gpus should be a list"
assert len(gpus) > 0, "gpus should be a non empty list"

# set root gpu
root_gpu = gpus[0]

return root_gpu


def retry_jittered_backoff(func: Callable, num_retries: int = 5, cap_delay: float = 1.0, base_delay: float = 0.01):
"""Retry jittered backoff.

Based on:
https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

Args:
func: tested function
num_retries: number of tries
cap_delay: max sleep time
base_delay: initial sleep time is 10ms
"""
sleep_delay = base_delay # initial sleep time is 10ms

for i in range(num_retries):
try:
return func()
except RuntimeError as err:
if i == num_retries - 1:
raise err
else:
continue
time.sleep(sleep_delay)
sleep_delay = min(cap_delay, random.uniform(base_delay, sleep_delay * 3))


def _parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int], int]]:
"""
Parses the tpu_cores given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer`.

Args:
tpu_cores: An int 1 or string '1' indicate that 1 core with multi-processing should be used
An int 8 or string '8' indicate that all 8 cores with multi-processing should be used
A list of int or a string containing list of comma separated integer
indicates specific TPU core to use.

Returns:
a list of tpu_cores to be used or ``None`` if no TPU cores were requested
"""

if callable(tpu_cores):
return None

_check_data_type(tpu_cores)

if isinstance(tpu_cores, str):
tpu_cores = _parse_tpu_cores_str(tpu_cores.strip())

if not _tpu_cores_valid(tpu_cores):
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")

return tpu_cores


def _tpu_cores_valid(tpu_cores):
return tpu_cores in (1, 8, None) or (
isinstance(tpu_cores, (list, tuple, set)) and
len(tpu_cores) == 1 and
tpu_cores[0] in range(1, 9)
)


def _parse_tpu_cores_str(tpu_cores):
if tpu_cores in ('1', '8'):
tpu_cores = int(tpu_cores)
else:
tpu_cores = [int(x.strip()) for x in tpu_cores.split(',') if len(x) > 0]
return tpu_cores


def pick_single_gpu(exclude_gpus: list):
for i in range(torch.cuda.device_count()):
if i in exclude_gpus:
continue
# Try to allocate on device:
device = torch.device(f"cuda:{i}")
try:
torch.ones(1).to(device)
except RuntimeError:
continue
return i
raise RuntimeError("No GPUs available.")


def pick_multiple_gpus(nb):
picked = []
for _ in range(nb):
picked.append(pick_single_gpu(exclude_gpus=picked))

return picked
44 changes: 22 additions & 22 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import (TrainerDPMixin, _parse_gpu_ids, _parse_tpu_cores,
determine_root_gpu_device, pick_multiple_gpus)
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.trainer.distrib_parts import (TrainerDPMixin)
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
Expand All @@ -56,9 +56,9 @@
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
from pytorch_lightning.trainer.training_loop_temp import TrainLoop
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning import _logger as log
from pytorch_lightning.trainer.tuning import Tuner
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities.model_utils import is_overridden

# warnings to ignore in trainer
Expand Down Expand Up @@ -373,6 +373,20 @@ def __init__(
if 'LOCAL_RANK' in os.environ:
rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

# tracks internal state for debugging
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
self.data_connector = DataConnector(self)
self.lr_scheduler_connector = LRSchedulerConnector(self)
self.accelerator_connector = AcceleratorConnector(self)
self.logger_connector = LoggerConnector(self)
self.tuner = Tuner(self)
self.accelerator_backend = None

# loops
self.evaluation_loop = EvaluationLoop(self)
self.train_loop = TrainLoop(self)

# training bookeeping
self.total_batch_idx = 0
self.running_loss = TensorRunningAccum(window_length=20)
Expand Down Expand Up @@ -449,7 +463,7 @@ def __init__(
raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
self.track_grad_norm = float(track_grad_norm)

self.tpu_cores = _parse_tpu_cores(tpu_cores)
self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
self.on_tpu = self.tpu_cores is not None

self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores, list) else None
Expand Down Expand Up @@ -507,12 +521,12 @@ def __init__(

# for gpus allow int, string and gpu list
if auto_select_gpus and isinstance(gpus, int):
self.gpus = pick_multiple_gpus(gpus)
self.gpus = self.tuner.pick_multiple_gpus(gpus)
else:
self.gpus = gpus

self.data_parallel_device_ids = _parse_gpu_ids(self.gpus)
self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
self.data_parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
self.root_gpu = device_parser.determine_root_gpu_device(self.data_parallel_device_ids)
self.root_device = torch.device("cpu")

self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False
Expand Down Expand Up @@ -605,20 +619,6 @@ def __init__(

self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

# tracks internal state for debugging
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
self.data_connector = DataConnector(self)
self.lr_scheduler_connector = LRSchedulerConnector(self)
self.accelerator_connector = AcceleratorConnector(self)
self.logger_connector = LoggerConnector(self)
self.tuner = Tuner(self)
self.accelerator_backend = None

# loops
self.evaluation_loop = EvaluationLoop(self)
self.train_loop = TrainLoop(self)

# Callback system
self.on_init_end()

Expand Down
Empty file.
Loading