Skip to content

Commit

Permalink
Use spawn as the fork method for the profiler test. (#6302)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored and bhavya01 committed Apr 22, 2024
1 parent b5c8a90 commit 184522e
Showing 1 changed file with 32 additions and 25 deletions.
57 changes: 32 additions & 25 deletions test/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,31 @@
import torch_xla.utils.utils as xu


# This function must remain a top-level function. Using spawn
# as the start method requires this function to be pickle-able.
def train_worker(port, training_started):
  """Run MNIST training in a child process with the profiler server on `port`.

  Args:
    port: free TCP port the profiler server in the trainer should listen on.
    training_started: multiprocessing Event the trainer sets once training
      has actually begun, so the parent test can start capturing a trace.
  """
  flags = args_parse.parse_common_options(
      datadir='/tmp/mnist-data',
      batch_size=16,
      momentum=0.5,
      lr=0.01,
      num_epochs=10)
  flags.fake_data = True
  flags.profiler_port = port

  # Disable programmatic profiling; the test drives capture externally
  # through the profiler port instead.
  for attr, value in (
      ('profile_step', -1),
      ('profile_epoch', -1),
      ('profile_logdir', None),
      ('profile_duration_ms', -1),
  ):
    setattr(flags, attr, value)

  test_profile_mp_mnist.train_mnist(
      flags,
      training_started=training_started,
      dynamic_graph=True,
      fetch_often=True)


class ProfilerTest(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -51,33 +76,15 @@ def _check_trace_namespace_exists(self, path):
f'Expected "build_graph" trace in: {path}')

def test_trace_and_metrics(self):
# Create a new context for forking processes with the spawn method.
# This is necessary so as to avoid CUDA initialization issues when
# both PyTorch and PyTorch/XLA were compiled with CUDA support.
context = multiprocessing.get_context("spawn")

port = xu.get_free_tcp_ports()[0]
training_started = multiprocessing.Event()

def train_worker():
flags = args_parse.parse_common_options(
datadir='/tmp/mnist-data',
batch_size=16,
momentum=0.5,
lr=0.01,
num_epochs=10)
flags.fake_data = True
flags.profiler_port = port

# Disable programmatic profiling
flags.profile_step = -1
flags.profile_epoch = -1
flags.profile_logdir = None
flags.profile_duration_ms = -1

test_profile_mp_mnist.train_mnist(
flags,
training_started=training_started,
dynamic_graph=True,
fetch_often=True)

p = multiprocessing.Process(target=train_worker, daemon=True)
training_started = context.Event()
p = context.Process(
target=train_worker, args=(port, training_started), daemon=True)
p.start()
training_started.wait(60)

Expand Down

0 comments on commit 184522e

Please sign in to comment.