Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use spawn in _compat_test.py to avoid fork problems #6374

Merged
merged 2 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 45 additions & 38 deletions cirq-core/cirq/_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_wrap_module():


def test_deprecate_attributes_assert_attributes_in_sys_modules():
subprocess_context(_test_deprecate_attributes_assert_attributes_in_sys_modules)()
run_in_subprocess(_test_deprecate_attributes_assert_attributes_in_sys_modules)


def _test_deprecate_attributes_assert_attributes_in_sys_modules():
Expand Down Expand Up @@ -635,42 +635,49 @@ def _type_repr_in_deprecated_module():
] + _deprecation_origin


def _trace_unhandled_exceptions(*args, queue: 'multiprocessing.Queue', func: Callable, **kwargs):
def _trace_unhandled_exceptions(*args, queue: 'multiprocessing.Queue', func: Callable):
try:
func(*args, **kwargs)
func(*args)
queue.put(None)
except BaseException as ex:
msg = str(ex)
queue.put((type(ex).__name__, msg, traceback.format_exc()))


def subprocess_context(test_func):
"""Ensures that sys.modules changes in subprocesses won't impact the parent process."""
def run_in_subprocess(test_func, *args):
"""Run a function in a subprocess.

This ensures that sys.modules changes in subprocesses won't impact the parent process.

Args:
test_func: The function to be run in a subprocess.
*args: Positional args to pass to the function.
"""

assert callable(test_func), (
"subprocess_context expects a function. Did you call the function instead of passing "
"run_in_subprocess expects a function. Did you call the function instead of passing "
"it to this method?"
)

ctx = multiprocessing.get_context('spawn' if os.name == 'nt' else 'fork')

exception = ctx.Queue()
# Use spawn to ensure subprocesses are isolated.
# See https://github.com/quantumlib/Cirq/issues/6373
ctx = multiprocessing.get_context('spawn')

def isolated_func(*args, **kwargs):
kwargs['queue'] = exception
kwargs['func'] = test_func
p = ctx.Process(target=_trace_unhandled_exceptions, args=args, kwargs=kwargs)
p.start()
p.join()
result = exception.get()
if result: # pragma: no cover
ex_type, msg, ex_trace = result
if ex_type == "Skipped":
warnings.warn(f"Skipping: {ex_type}: {msg}\n{ex_trace}")
pytest.skip(f'{ex_type}: {msg}\n{ex_trace}')
else:
pytest.fail(f'{ex_type}: {msg}\n{ex_trace}')
queue = ctx.Queue()

return isolated_func
p = ctx.Process(
target=_trace_unhandled_exceptions, args=args, kwargs={'queue': queue, 'func': test_func}
)
p.start()
p.join()
result = queue.get()
if result: # pragma: no cover
ex_type, msg, ex_trace = result
if ex_type == "Skipped":
warnings.warn(f"Skipping: {ex_type}: {msg}\n{ex_trace}")
pytest.skip(f'{ex_type}: {msg}\n{ex_trace}')
else:
pytest.fail(f'{ex_type}: {msg}\n{ex_trace}')


@mock.patch.dict(os.environ, {"CIRQ_FORCE_DEDUPE_MODULE_DEPRECATION": "1"})
Expand Down Expand Up @@ -698,7 +705,7 @@ def isolated_func(*args, **kwargs):
],
)
def test_deprecated_module(outdated_method, deprecation_messages):
subprocess_context(_test_deprecated_module_inner)(outdated_method, deprecation_messages)
run_in_subprocess(_test_deprecated_module_inner, outdated_method, deprecation_messages)


def _test_deprecated_module_inner(outdated_method, deprecation_messages):
Expand Down Expand Up @@ -736,7 +743,7 @@ def test_same_name_submodule_earlier_in_subtree():
cirq.ops.engine.calibration packages. The wrong resolution resulted in false circular
imports!
"""
subprocess_context(_test_same_name_submodule_earlier_in_subtree_inner)()
run_in_subprocess(_test_same_name_submodule_earlier_in_subtree_inner)


def _test_same_name_submodule_earlier_in_subtree_inner():
Expand All @@ -748,7 +755,7 @@ def _test_same_name_submodule_earlier_in_subtree_inner():
def test_metadata_search_path():
# to cater for metadata path finders
# https://docs.python.org/3/library/importlib.metadata.html#extending-the-search-algorithm
subprocess_context(_test_metadata_search_path_inner)()
run_in_subprocess(_test_metadata_search_path_inner)


def _test_metadata_search_path_inner(): # pragma: no cover
Expand All @@ -760,7 +767,7 @@ def _test_metadata_search_path_inner(): # pragma: no cover


def test_metadata_distributions_after_deprecated_submodule():
subprocess_context(_test_metadata_distributions_after_deprecated_submodule)()
run_in_subprocess(_test_metadata_distributions_after_deprecated_submodule)


def _test_metadata_distributions_after_deprecated_submodule():
Expand All @@ -779,7 +786,7 @@ def _test_metadata_distributions_after_deprecated_submodule():


def test_parent_spec_after_deprecated_submodule():
subprocess_context(_test_parent_spec_after_deprecated_submodule)()
run_in_subprocess(_test_parent_spec_after_deprecated_submodule)


def _test_parent_spec_after_deprecated_submodule():
Expand All @@ -791,7 +798,7 @@ def _test_parent_spec_after_deprecated_submodule():
def test_type_repr_in_new_module():
# to cater for metadata path finders
# https://docs.python.org/3/library/importlib.metadata.html#extending-the-search-algorithm
subprocess_context(_test_type_repr_in_new_module_inner)()
run_in_subprocess(_test_type_repr_in_new_module_inner)


def _test_type_repr_in_new_module_inner():
Expand Down Expand Up @@ -849,19 +856,19 @@ def _test_broken_module_3_inner():


def test_deprecated_module_error_handling_1():
subprocess_context(_test_broken_module_1_inner)()
run_in_subprocess(_test_broken_module_1_inner)


def test_deprecated_module_error_handling_2():
subprocess_context(_test_broken_module_2_inner)()
run_in_subprocess(_test_broken_module_2_inner)


def test_deprecated_module_error_handling_3():
subprocess_context(_test_broken_module_3_inner)()
run_in_subprocess(_test_broken_module_3_inner)


def test_new_module_is_top_level():
subprocess_context(_test_new_module_is_top_level_inner)()
run_in_subprocess(_test_new_module_is_top_level_inner)


def _test_new_module_is_top_level_inner():
Expand All @@ -877,7 +884,7 @@ def _test_new_module_is_top_level_inner():


def test_import_deprecated_with_no_attribute():
subprocess_context(_test_import_deprecated_with_no_attribute_inner)()
run_in_subprocess(_test_import_deprecated_with_no_attribute_inner)


def _test_import_deprecated_with_no_attribute_inner():
Expand Down Expand Up @@ -970,23 +977,23 @@ def module_repr(self, module: ModuleType) -> str:

def test_subprocess_test_failure():
with pytest.raises(Failed, match='ValueError.*this fails'):
subprocess_context(_test_subprocess_test_failure_inner)()
run_in_subprocess(_test_subprocess_test_failure_inner)


def _test_subprocess_test_failure_inner():
raise ValueError('this fails')


def test_dir_is_still_valid():
subprocess_context(_dir_is_still_valid_inner)()
run_in_subprocess(_dir_is_still_valid_inner)


def _dir_is_still_valid_inner():
"""to ensure that create_attribute=True keeps the dir(module) intact"""

import cirq.testing._compat_test_data as mod

for m in ['fake_a', 'info', 'module_a', 'sys']:
for m in ['fake_a', 'logging', 'module_a']:
assert m in dir(mod)


Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/testing/_compat_test_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
See cirq/_compat_test.py for the tests.
This module contains example deprecations for modules.
"""
import sys
from logging import info
import logging

from cirq import _compat

info("init:compat_test_data")
logging.info("init:compat_test_data")

# simulates a rename of a child module
# fake_a -> module_a
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
"""module_a for module deprecation tests"""

from logging import info
import logging

from cirq.testing._compat_test_data.module_a import module_b

Expand All @@ -11,4 +11,4 @@

MODULE_A_ATTRIBUTE = "module_a"

info("init:module_a")
logging.info("init:module_a")
Loading