fixes ddp bugs #1819

Merged: 37 commits, May 13, 2020

All 37 commits are by williamFalcon, dated May 13, 2020, each with the message "debug":
c5f63df, 54f510c, 47d0161, b91d072, 83d709d, ba8d8b5, c1a5c4a, d734ba1, b17efc9, f5ccb64, c7da23b, 93bd9da, f181c44, a0ef963, 0493245, 725231b, 7b3bce1, ad10a55, 458e724, f49412b, b593873, 807ed4b, 92b6ca0, 348982a, 2843073, 4b40b52, 8ae37a8, 24931c9, f1a969c, 8506a16, 75d0131, 95aead2, f54d2a1, dfb6050, dbb5a59, 82aaf41, 5c56747

pytorch_lightning/callbacks/model_checkpoint.py (10 additions, 0 deletions)

@@ -149,6 +149,10 @@ def check_monitor_top_k(self, current):
             return True
 
         if not isinstance(current, torch.Tensor):
+            rank_zero_warn(
+                f'{current} is supposed to be a torch.Tensor. Saving checkpoint may not work correctly. '
+                f'HINT: check the value of {self.monitor} in your validation loop', RuntimeWarning
+            )
             current = torch.tensor(current)
 
         monitor_op = {

@@ -223,6 +227,12 @@ def on_validation_end(self, trainer, pl_module):
         if self.save_top_k != -1:
             current = metrics.get(self.monitor)
 
+            if not isinstance(current, torch.Tensor):
+                rank_zero_warn(
+                    f'The metric you returned {current} must be a Torch.Tensor instance, checkpoint not saved '
+                    f'HINT: what is the value of {self.monitor} in validation_end()?', RuntimeWarning
+                )
+
             if current is None:
                 rank_zero_warn(
                     f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning

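Context for the model_checkpoint.py change: check_monitor_top_k compares the monitored value against the saved top-k values with torch.lt / torch.gt, so a plain Python number coming back from the validation loop has to be coerced to a tensor first, which is what the new warnings point at. A minimal standalone sketch of that comparison, with illustrative values rather than Lightning internals:

import torch

current = 0.25  # e.g. a plain float returned as 'val_loss' from validation
if not isinstance(current, torch.Tensor):
    # same coercion the callback applies, now accompanied by a warning
    current = torch.tensor(current)

worst_of_top_k = torch.tensor(0.30)
keep = torch.lt(current, worst_of_top_k)  # 'min' mode: lower is better
print(bool(keep))  # True -> this checkpoint would enter the top-k
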
pytorch_lightning/trainer/data_loading.py (2 additions, 2 deletions)

@@ -111,8 +111,8 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         if not is_dataloader or is_iterable_ds:
             return dataloader
         need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
-        if self.replace_sampler_ddp and need_dist_sampler:
 
+        if self.replace_sampler_ddp and need_dist_sampler:
             skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
 
             dl_args = {

@@ -137,7 +137,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
             }
             sampler = DistributedSampler(
                 dataloader.dataset,
-                num_replicas=world_size.get(self.distributed_backend, 0),
+                num_replicas=world_size[self.distributed_backend],
                 rank=self.proc_rank,
             )
 

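The data_loading.py hunk above swaps world_size.get(self.distributed_backend, 0) for a direct lookup, so an unrecognized backend now fails immediately with a KeyError instead of quietly constructing a sampler with num_replicas=0. A rough standalone sketch of that wiring, with a hypothetical world_size mapping and a toy dataset standing in for trainer state:

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return idx

# hypothetical stand-ins for the trainer attributes used in auto_add_sampler
distributed_backend, num_processes, num_nodes, proc_rank = 'ddp', 2, 1, 0
world_size = {'ddp': num_processes * num_nodes, 'ddp2': num_nodes}

sampler = DistributedSampler(
    ToyDataset(),
    num_replicas=world_size[distributed_backend],  # KeyError on a typo'd backend, not a silent 0
    rank=proc_rank,
)
loader = DataLoader(ToyDataset(), batch_size=4, sampler=sampler)
print(len(list(loader)))  # each rank only iterates over its own shard of the dataset
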
pytorch_lightning/trainer/distrib_data_parallel.py (9 additions, 3 deletions)

@@ -155,6 +155,7 @@ class TrainerDDPMixin(ABC):
     default_root_dir: str
     use_native_amp: bool
     progress_bar_callback: ...
+    num_processes: int
 
     @property
     @abstractmethod

@@ -204,14 +205,17 @@ def set_distributed_mode(self, distributed_backend):
                 rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.'
                                ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
                                ' Setting distributed_backend=ddp for you.')
-                self.use_ddp = True
-        elif distributed_backend == "dp":
+                self.distributed_backend = 'ddp'
+                distributed_backend = 'ddp'
+
+        if distributed_backend == "dp":
             # do nothing if num_gpus == 0
             if self.num_gpus == 1:
                 self.single_gpu = True
                 self.use_dp = True
             elif self.num_gpus > 1:
                 self.use_dp = True
+
         elif distributed_backend == "ddp":
             if self.num_gpus == 0:
                 if self.num_nodes > 1 or self.num_processes > 1:

@@ -222,6 +226,7 @@ def set_distributed_mode(self, distributed_backend):
             elif self.num_gpus > 1:
                 self.use_ddp = True
                 self.num_processes = self.num_gpus
+
         elif distributed_backend == "ddp2":
             # do nothing if num_gpus == 0
             if self.num_gpus >= 1:

@@ -314,7 +319,8 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
                 gpu_str = ','.join([str(x) for x in data_parallel_device_ids])
                 os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str
 
-        log.debug(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
+        # don't make this debug... this is good UX
+        log.info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
 
     def ddp_train(self, process_idx, model):
         """

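The distrib_data_parallel.py hunks change what happens when a user passes multiple GPUs without choosing a backend: instead of only flipping use_ddp, the trainer now rewrites distributed_backend to 'ddp' and falls through the normal 'ddp' branch, which also sets num_processes to the GPU count. A condensed, standalone sketch of that selection logic (not the actual Trainer code, just the decision table the patch implies):

def pick_mode(distributed_backend, num_gpus, num_nodes=1, num_processes=1):
    use_dp = use_ddp = single_gpu = False

    if distributed_backend is None and num_gpus > 1:
        # new behaviour: warn, then proceed as if the user had asked for ddp
        distributed_backend = 'ddp'

    if distributed_backend == 'dp':
        if num_gpus == 1:
            single_gpu, use_dp = True, True
        elif num_gpus > 1:
            use_dp = True
    elif distributed_backend == 'ddp':
        if num_gpus == 0 and (num_nodes > 1 or num_processes > 1):
            use_ddp = True  # ddp over CPU processes
        elif num_gpus == 1:
            single_gpu, use_ddp = True, True
        elif num_gpus > 1:
            use_ddp = True
            num_processes = num_gpus  # one process per GPU

    return dict(use_dp=use_dp, use_ddp=use_ddp, single_gpu=single_gpu,
                num_processes=num_processes)

# mirrors the expectation updated in tests/trainer/test_trainer.py below
assert pick_mode(None, num_gpus=2) == dict(use_dp=False, use_ddp=True,
                                           single_gpu=False, num_processes=2)
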
pytorch_lightning/trainer/distrib_parts.py (3 additions, 1 deletion)

@@ -530,7 +530,9 @@ def tpu_train(self, tpu_core_idx, model):
         # continue training routine
         self.run_pretrain_routine(model)
 
-        self.save_spawn_weights(model)
+        # when training ends on these platforms dump weights to get out of the main process
+        if self.on_colab_kaggle:
+            self.save_spawn_weights(model)
 
     def dp_train(self, model):
 

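The distrib_parts.py hunk only dumps the spawn weights on Colab/Kaggle: training there runs in a spawned child process, so the trained state_dict is written to disk and the parent reloads it, which the trainer.py hunks below pair with load_spawn_weights. A tiny stand-in for that save-in-child / load-in-parent handoff (the checkpoint path and model here are illustrative, not Lightning's own):

import os
import tempfile
import torch
import torch.multiprocessing as mp

# hypothetical path; Lightning manages its own temporary checkpoint name
CKPT = os.path.join(tempfile.gettempdir(), 'spawn_weights.ckpt')


def on_colab_kaggle():
    return bool(os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE'))


def train_in_child(rank):
    # stand-in for ddp_train / tpu_train: the spawned process trains the model
    model = torch.nn.Linear(4, 2)
    if on_colab_kaggle():  # same guard the patch adds around save_spawn_weights
        torch.save(model.state_dict(), CKPT)


if __name__ == '__main__':
    mp.spawn(train_in_child, nprocs=1)
    if on_colab_kaggle():  # ...and around load_spawn_weights in fit()
        model = torch.nn.Linear(4, 2)
        model.load_state_dict(torch.load(CKPT))
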
pytorch_lightning/trainer/trainer.py (5 additions, 3 deletions)

@@ -121,7 +121,7 @@ def __init__(
             print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
             weights_summary: Optional[str] = 'full',
             weights_save_path: Optional[str] = None,
-            num_sanity_val_steps: int = 5,
+            num_sanity_val_steps: int = 2,
             truncated_bptt_steps: Optional[int] = None,
             resume_from_checkpoint: Optional[str] = None,
             profiler: Optional[Union[BaseProfiler, bool]] = None,

@@ -526,6 +526,8 @@ def __init__(
         self.amp_level = amp_level
         self.init_amp(use_amp)
 
+        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
+
         # Callback system
         self.on_init_end()
 

@@ -811,7 +813,7 @@ def fit(
             # train
             mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))
             # load weights if not interrupted
-            if os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE'):
+            if self.on_colab_kaggle:
                 self.load_spawn_weights(model)
             self.model = model
 

@@ -830,7 +832,7 @@ def fit(
             log.info(f'training on {self.num_tpu_cores} TPU cores')
 
             # COLAB_GPU is an env var available by default in Colab environments.
-            start_method = 'fork' if os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') else 'spawn'
+            start_method = 'fork' if self.on_colab_kaggle else 'spawn'
 
             # track for predict
             self.model = model

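The trainer.py hunks cache the Colab/Kaggle detection once as self.on_colab_kaggle and reuse it for the post-spawn weight reload and the TPU start method. A small self-contained sketch of that detection (the env var names come from the diff; the rest is illustrative):

import os

# COLAB_GPU is set by default in Colab runtimes, KAGGLE_URL_BASE in Kaggle kernels
on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

# TPU training in these hosted notebooks needs 'fork'; everywhere else uses 'spawn'
start_method = 'fork' if on_colab_kaggle else 'spawn'
print(bool(on_colab_kaggle), start_method)
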
tests/trainer/test_trainer.py (1 addition, 1 deletion)

@@ -744,7 +744,7 @@ def test_gpu_choice(tmpdir):
     ),
     pytest.param(
         dict(distributed_backend=None, gpus=2),
-        dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1),
+        dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=2),
         marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")]
     ),
     pytest.param(