diff --git a/sdks/python/apache_beam/runners/worker/statesampler.py b/sdks/python/apache_beam/runners/worker/statesampler.py index b9c75f4de93d..fb5c9f81b764 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler.py +++ b/sdks/python/apache_beam/runners/worker/statesampler.py @@ -29,6 +29,7 @@ from apache_beam.utils.counters import Counter from apache_beam.utils.counters import CounterFactory from apache_beam.utils.counters import CounterName +from apache_beam.utils.threads import ParentAwareThread try: from apache_beam.runners.worker import statesampler_fast as statesampler_impl # type: ignore @@ -40,18 +41,39 @@ if TYPE_CHECKING: from apache_beam.metrics.execution import MetricsContainer -_STATE_SAMPLERS = threading.local() +# Global dictionary to store state samplers keyed by thread id. +_STATE_SAMPLERS = {} +_STATE_SAMPLERS_LOCK = threading.Lock() def set_current_tracker(tracker): - _STATE_SAMPLERS.tracker = tracker + """Sets state tracker for the calling thread.""" + with _STATE_SAMPLERS_LOCK: + if (tracker is None): + _STATE_SAMPLERS.pop(threading.get_ident()) + return + + _STATE_SAMPLERS[threading.get_ident()] = tracker def get_current_tracker(): - try: - return _STATE_SAMPLERS.tracker - except AttributeError: - return None + """Retrieve state tracker for the calling thread. + + If the thread is a ParentAwareThread (child thread that work was handed off + to) it attempts to retrieve the tracker associated with its parent thread. + + """ + current_thread_id = threading.get_ident() + + with _STATE_SAMPLERS_LOCK: + if current_thread_id in _STATE_SAMPLERS: + return _STATE_SAMPLERS[current_thread_id] + + current_thread = threading.current_thread() + if isinstance(current_thread, ParentAwareThread + ) and current_thread.parent_thread_id in _STATE_SAMPLERS: + return _STATE_SAMPLERS.get(current_thread.parent_thread_id) + return None _INSTRUCTION_IDS = threading.local() diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index d7415e8d8135..9c700bc97ba9 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -2592,17 +2592,29 @@ def __getattribute__(self, name): return getattr(self._fn, name) def process(self, *args, **kwargs): - if self._pool is None: - self._pool = concurrent.futures.ThreadPoolExecutor(10) - # Ensure we iterate over the entire output list in the given amount of time. - try: - return self._pool.submit( - lambda: list(self._fn.process(*args, **kwargs))).result( - self._timeout) - except TimeoutError: - self._pool.shutdown(wait=False) - self._pool = None - raise + from apache_beam.utils.threads import ParentAwareThread + + results = [] + exception = None + + def run_process(): + try: + results.extend(self._fn.process(*args, **kwargs)) + except Exception as e: + nonlocal exception + exception = e + + thread = ParentAwareThread(target=run_process) + thread.start() + thread.join(self._timeout) + + if thread.is_alive(): + raise TimeoutError() + + if exception is not None: + raise exception + + return results def teardown(self): try: diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index a51d5cd83d26..a026ced6e98a 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -2780,6 +2780,31 @@ def test_timeout(self): ('slow', 'TimeoutError()')]), label='CheckBad') + def test_increment_counter(self): + # Counters are not currently supported for + # ParDo#with_exception_handling(use_subprocess=True). + if (self.use_subprocess): + return + + class CounterDoFn(beam.DoFn): + def __init__(self): + self.records_counter = Metrics.counter(self.__class__, 'recordsCounter') + + def process(self, element): + self.records_counter.inc() + + with TestPipeline() as p: + _, _ = ( + (p | beam.Create([1,2,3])) | beam.ParDo(CounterDoFn()) + .with_exception_handling( + use_subprocess=self.use_subprocess, timeout=1)) + results = p.result + metric_results = results.metrics().query( + MetricsFilter().with_name("recordsCounter")) + records_counter = metric_results['counters'][0] + self.assertEqual(records_counter.key.metric.name, 'recordsCounter') + self.assertEqual(records_counter.result, 3) + def test_lifecycle(self): die = type(self).die diff --git a/sdks/python/apache_beam/utils/threads.py b/sdks/python/apache_beam/utils/threads.py new file mode 100644 index 000000000000..c269bb45db6a --- /dev/null +++ b/sdks/python/apache_beam/utils/threads.py @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import threading + + +class ParentAwareThread(threading.Thread): + """ + A thread subclass that is aware of its parent thread. + + This is useful in scenarios where work is executed in a child thread + (e.g. ParDo#with_exception_handling(timeout)) and the child thread requires + access to parent thread scoped state variables (e.g. state sampler). + + Attributes: + parent_thread_id (int): The identifier of the parent thread that created + this thread instance. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.parent_thread_id = threading.current_thread().ident diff --git a/sdks/python/apache_beam/utils/threads_test.py b/sdks/python/apache_beam/utils/threads_test.py new file mode 100644 index 000000000000..77bad8946586 --- /dev/null +++ b/sdks/python/apache_beam/utils/threads_test.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for thread utilities.""" + +import unittest +import threading + +from apache_beam.utils.threads import ParentAwareThread + + +class ParentAwareThreadTest(unittest.TestCase): + def test_child_tread_can_access_parent_thread_id(self): + expected_parent_thread_id = threading.get_ident() + actual_parent_thread_id = None + + def get_parent_thread_id(): + nonlocal actual_parent_thread_id + actual_parent_thread_id = threading.current_thread().parent_thread_id + + thread = ParentAwareThread(target=get_parent_thread_id) + thread.start() + thread.join() + + self.assertEqual(expected_parent_thread_id, actual_parent_thread_id)