[WIP] ddp testing #2856
Closed · wants to merge 18 commits
54 changes: 54 additions & 0 deletions tests/models/data/ddp/train_test_variations.py
@@ -0,0 +1,54 @@
"""
Runs several combinations of `.fit()` and `.test()` on a single node across multiple GPUs.
"""
from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything
from tests.base import EvalModelTemplate


def variation_fit_test(trainer, model):
trainer.fit(model)
trainer.test(model)


def variation_test_fit(trainer, model):
trainer.test(model)
trainer.fit(model)


def variation_test_test(trainer, model):
trainer.test(model)
trainer.test(model)


def variation_test_fit_test(trainer, model):
trainer.test(model)
trainer.fit(model)
trainer.test(model)


def get_variations():
variations = [v for v in globals() if v.startswith("variation")]
return variations


def main():
seed_everything(1234)
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parser)
parser.add_argument('--variation', default=variation_fit_test.__name__)
parser.set_defaults(gpus=2)
parser.set_defaults(distributed_backend="ddp")
args = parser.parse_args()

model = EvalModelTemplate()
trainer = Trainer.from_argparse_args(args)

# run the chosen variation
run_variation = globals()[args.variation]
run_variation(trainer, model)


if __name__ == '__main__':
main()
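
For context, a minimal sketch of how this script can be driven outside pytest, one fresh interpreter per variation; this mirrors the test added in tests/models/test_gpu.py below and assumes a machine with at least two GPUs:

import subprocess
import sys
from pathlib import Path

from tests.models.data.ddp import train_test_variations

script = str(Path(train_test_variations.__file__).absolute())
for variation in train_test_variations.get_variations():
    # check=True raises CalledProcessError on a non-zero exit code
    subprocess.run(
        [sys.executable, script, '--variation', variation,
         '--max_epochs', '1', '--gpus', '2', '--distributed_backend', 'ddp'],
        check=True,
    )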
34 changes: 34 additions & 0 deletions tests/models/test_gpu.py
@@ -1,4 +1,8 @@
import subprocess
import sys
from collections import namedtuple
from pathlib import Path
from unittest import mock

import pytest
import torch
@@ -11,6 +15,7 @@
from pytorch_lightning.trainer.distrib_parts import _parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.models.data.ddp import train_test_variations

PRETEND_N_OF_GPUS = 16

@@ -93,6 +98,35 @@ def test_multi_gpu_model_dp(tmpdir):
memory.get_memory_profile('min_max')


@pytest.mark.parametrize('cli_args', [
pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
])
@pytest.mark.parametrize('variation', train_test_variations.get_variations())
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp(tmpdir, cli_args, variation):
    # run the variation script in its own interpreter: the ddp backend re-launches
    # the running command, which must be this script rather than pytest itself
    file = Path(train_test_variations.__file__).absolute()
    cli_args = cli_args.split(' ') if cli_args else []
    cli_args += ['--default_root_dir', str(tmpdir)]
    command = [sys.executable, str(file), '--variation', variation] + cli_args
    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # communicate() may only be called once per process
    std, err = p.communicate(timeout=60)
    std = std.decode('utf-8').strip()
    err = err.decode('utf-8').strip()
    if p.returncode:
(Review thread on `if p.returncode:`)
Contributor: What about if p.returncode is falsey? It should probably raise something, right?
Author: I think a return code of 0 means success, and > 0 means failure. It's when you do sys.exit(0) that it counts as success.
Contributor: Ahh, makes sense. Maybe p.returncode != 0? That covers the case where it's None, etc.
Author: I will incorporate all your feedback comments in the follow-up PR #2997.
print(std)
print(err)
print(command)
pytest.fail(err)

    # alternative, currently unused: run the variation in-process by patching argv
    # cli_args += ['--variation', variation]
    # from tests.models.data.ddp.train_test_variations import main
    # with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
    #     main()
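
Relatedly, a minimal sketch of the hardening suggested in the review thread above, swapping Popen for subprocess.run (capture_output requires Python 3.7+); run_script is a hypothetical helper, not part of this PR:

import subprocess

def run_script(command, timeout=60):
    # run the command to completion and capture both output streams
    result = subprocess.run(command, capture_output=True, timeout=timeout)
    std = result.stdout.decode('utf-8').strip()
    err = result.stderr.decode('utf-8').strip()
    # an explicit `!= 0` check is clearer than truthiness and also flags
    # negative codes (a process killed by a signal reports -SIGNUM on POSIX)
    if result.returncode != 0:
        raise RuntimeError(f'command {command!r} failed:\n{err}')
    return std, err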


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp_spawn(tmpdir):
tutils.set_random_master_port()