diff --git a/test/test_profiler.py b/test/test_profiler.py
index fb04c59bc05..2f6dce2ebcc 100644
--- a/test/test_profiler.py
+++ b/test/test_profiler.py
@@ -14,6 +14,31 @@
 import torch_xla.utils.utils as xu
 
 
+# This function must remain a top-level function. Using spawn
+# as the fork method requires this function to be pickle-able.
+def train_worker(port, training_started):
+  flags = args_parse.parse_common_options(
+      datadir='/tmp/mnist-data',
+      batch_size=16,
+      momentum=0.5,
+      lr=0.01,
+      num_epochs=10)
+  flags.fake_data = True
+  flags.profiler_port = port
+
+  # Disable programmatic profiling
+  flags.profile_step = -1
+  flags.profile_epoch = -1
+  flags.profile_logdir = None
+  flags.profile_duration_ms = -1
+
+  test_profile_mp_mnist.train_mnist(
+      flags,
+      training_started=training_started,
+      dynamic_graph=True,
+      fetch_often=True)
+
+
 class ProfilerTest(unittest.TestCase):
 
   def setUp(self):
@@ -51,33 +76,15 @@ def _check_trace_namespace_exists(self, path):
         f'Expected "build_graph" trace in: {path}')
 
   def test_trace_and_metrics(self):
+    # Create a new context for forking processes with the spawn method.
+    # This is necessary so as to avoid CUDA initialization issues when
+    # both PyTorch and PyTorch/XLA were compiled with CUDA support.
+    context = multiprocessing.get_context("spawn")
     port = xu.get_free_tcp_ports()[0]
-    training_started = multiprocessing.Event()
-
-    def train_worker():
-      flags = args_parse.parse_common_options(
-          datadir='/tmp/mnist-data',
-          batch_size=16,
-          momentum=0.5,
-          lr=0.01,
-          num_epochs=10)
-      flags.fake_data = True
-      flags.profiler_port = port
-
-      # Disable programmatic profiling
-      flags.profile_step = -1
-      flags.profile_epoch = -1
-      flags.profile_logdir = None
-      flags.profile_duration_ms = -1
-
-      test_profile_mp_mnist.train_mnist(
-          flags,
-          training_started=training_started,
-          dynamic_graph=True,
-          fetch_often=True)
-
-    p = multiprocessing.Process(target=train_worker, daemon=True)
+    training_started = context.Event()
+    p = context.Process(
+        target=train_worker, args=(port, training_started), daemon=True)
     p.start()
     training_started.wait(60)
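
For context, the sketch below (not part of the patch; the worker name and port value are illustrative) isolates the pattern the patch adopts: the worker must be a module-level function because the spawn start method pickles its target, and the Event and Process objects come from the spawn context rather than the module-level multiprocessing API, so the child starts a fresh interpreter instead of forking the CUDA-initialized parent.

import multiprocessing


# Must live at module level: the "spawn" start method pickles the target
# callable, and nested or local functions cannot be pickled.
def worker(port, training_started):
  training_started.set()
  print(f'worker would serve a profiler on port {port}')


if __name__ == '__main__':
  # A fresh spawn context gives the child a clean interpreter, avoiding
  # CUDA state inherited from a fork of the parent process.
  context = multiprocessing.get_context('spawn')
  training_started = context.Event()
  p = context.Process(target=worker, args=(9012, training_started), daemon=True)
  p.start()
  training_started.wait(60)
  p.terminate()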