Skip to content

Commit

Permalink
Create and use ParentAwareThread in Pardo#with_exception_handling(tim…
Browse files Browse the repository at this point in the history
…eout)
  • Loading branch information
Claude committed Sep 26, 2024
1 parent 1eddbdc commit 60d7182
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 17 deletions.
34 changes: 28 additions & 6 deletions sdks/python/apache_beam/runners/worker/statesampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from apache_beam.utils.counters import Counter
from apache_beam.utils.counters import CounterFactory
from apache_beam.utils.counters import CounterName
from apache_beam.utils.threads import ParentAwareThread

try:
from apache_beam.runners.worker import statesampler_fast as statesampler_impl # type: ignore
Expand All @@ -40,18 +41,39 @@
if TYPE_CHECKING:
from apache_beam.metrics.execution import MetricsContainer

# NOTE(review): this thread-local binding is immediately rebound to the dict
# below, so it has no effect; it looks like the pre-change line of the diff.
_STATE_SAMPLERS = threading.local()
# Global dictionary to store state samplers keyed by thread id.
_STATE_SAMPLERS = {}
# Held around the reads and writes of _STATE_SAMPLERS performed by
# set_current_tracker() and get_current_tracker().
_STATE_SAMPLERS_LOCK = threading.Lock()


def set_current_tracker(tracker):
  """Sets or clears the state tracker registered for the calling thread.

  Trackers are stored in the module-level _STATE_SAMPLERS dictionary, keyed by
  the identifier of the thread that calls this function.

  Args:
    tracker: the tracker to associate with the calling thread, or None to
      remove whatever tracker is currently registered for it.
  """
  thread_id = threading.get_ident()
  with _STATE_SAMPLERS_LOCK:
    if tracker is None:
      # Pass a default so that clearing a thread that never registered a
      # tracker (or clearing it twice) is a no-op rather than a KeyError.
      _STATE_SAMPLERS.pop(thread_id, None)
    else:
      _STATE_SAMPLERS[thread_id] = tracker


def get_current_tracker():
  """Retrieves the state tracker for the calling thread.

  If the calling thread has no tracker of its own and is a ParentAwareThread
  (a child thread that work was handed off to), the tracker registered for its
  parent thread is returned instead.

  Returns:
    The tracker registered for the calling thread; otherwise the tracker
    registered for its parent thread when the caller is a ParentAwareThread;
    otherwise None.
  """
  thread_id = threading.get_ident()

  with _STATE_SAMPLERS_LOCK:
    try:
      return _STATE_SAMPLERS[thread_id]
    except KeyError:
      pass

    # Only threads that recorded their creator can fall back to it.
    current_thread = threading.current_thread()
    if isinstance(current_thread, ParentAwareThread):
      # .get() yields None when the parent has no tracker registered either.
      return _STATE_SAMPLERS.get(current_thread.parent_thread_id)
    return None


_INSTRUCTION_IDS = threading.local()
Expand Down
34 changes: 23 additions & 11 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2592,17 +2592,29 @@ def __getattribute__(self, name):
return getattr(self._fn, name)

def process(self, *args, **kwargs):
  """Runs the wrapped DoFn's process() in a child thread, bounded by a timeout.

  The work runs in a ParentAwareThread so that the child thread records the
  identifier of the thread that called this method. The entire output is
  materialized inside the child thread so that iterating over it is covered
  by the timeout as well.

  Args:
    *args: positional arguments forwarded to the wrapped process().
    **kwargs: keyword arguments forwarded to the wrapped process().

  Returns:
    A list with every element produced by the wrapped process(); an empty
    list if it returned None.

  Raises:
    TimeoutError: if the child thread has not finished within self._timeout.
      The child thread is not stopped; it is left running.
    BaseException: whatever the wrapped process() raised, re-raised in the
      calling thread.
  """
  from apache_beam.utils.threads import ParentAwareThread

  results = []
  exception = None

  def run_process():
    nonlocal exception
    try:
      output = self._fn.process(*args, **kwargs)
      # A process() that returns None produces no output; extending with
      # None would raise a TypeError instead.
      if output is not None:
        results.extend(output)
    except BaseException as e:  # pylint: disable=broad-except
      # Capture everything, not just Exception subclasses: an error that
      # escapes here would only end the child thread, and the caller would
      # then return an empty result as if nothing had gone wrong.
      exception = e

  thread = ParentAwareThread(target=run_process)
  thread.start()
  thread.join(self._timeout)

  # join() returns silently on timeout, so liveness is the only signal.
  if thread.is_alive():
    raise TimeoutError()

  if exception is not None:
    raise exception

  return results

def teardown(self):
try:
Expand Down
25 changes: 25 additions & 0 deletions sdks/python/apache_beam/transforms/ptransform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2780,6 +2780,31 @@ def test_timeout(self):
('slow', 'TimeoutError()')]),
label='CheckBad')

def test_increment_counter(self):
  """Checks that a counter incremented inside a timeout-wrapped DoFn is reported."""
  # Counters are not currently supported for
  # ParDo#with_exception_handling(use_subprocess=True).
  if self.use_subprocess:
    # Report the case as skipped instead of letting it pass silently.
    self.skipTest(
        'Counters are not supported with '
        'with_exception_handling(use_subprocess=True).')

  class CounterDoFn(beam.DoFn):
    def __init__(self):
      self.records_counter = Metrics.counter(self.__class__, 'recordsCounter')

    def process(self, element):
      self.records_counter.inc()

  with TestPipeline() as p:
    _, _ = (
        (p | beam.Create([1, 2, 3])) | beam.ParDo(CounterDoFn())
        .with_exception_handling(
            use_subprocess=self.use_subprocess, timeout=1))
  results = p.result
  metric_results = results.metrics().query(
      MetricsFilter().with_name("recordsCounter"))
  counters = metric_results['counters']
  # Fail with a readable message rather than an IndexError below.
  self.assertTrue(counters, 'no recordsCounter metric was reported')
  records_counter = counters[0]
  self.assertEqual(records_counter.key.metric.name, 'recordsCounter')
  self.assertEqual(records_counter.result, 3)

def test_lifecycle(self):
die = type(self).die

Expand Down
35 changes: 35 additions & 0 deletions sdks/python/apache_beam/utils/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import threading


class ParentAwareThread(threading.Thread):
  """A thread subclass that is aware of its parent thread.

  This is useful in scenarios where work is executed in a child thread
  (e.g. ParDo#with_exception_handling(timeout)) and the child thread requires
  access to parent thread scoped state variables (e.g. state sampler).

  The parent is the thread that constructs the instance, which is not
  necessarily the thread that later calls start().

  Attributes:
    parent_thread_id (int): The identifier of the parent thread that created
      this thread instance.
  """
  def __init__(self, *args, **kwargs):
    """Creates the thread and records the identifier of the calling thread.

    Args:
      *args: positional arguments forwarded to threading.Thread.
      **kwargs: keyword arguments forwarded to threading.Thread.
    """
    super().__init__(*args, **kwargs)
    # get_ident() reads the calling thread's identifier directly, without
    # going through a Thread object for the caller.
    self.parent_thread_id = threading.get_ident()
39 changes: 39 additions & 0 deletions sdks/python/apache_beam/utils/threads_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for thread utilities."""

import unittest
import threading

from apache_beam.utils.threads import ParentAwareThread


class ParentAwareThreadTest(unittest.TestCase):
  """Tests for ParentAwareThread."""
  def test_child_thread_can_access_parent_thread_id(self):
    """The child thread sees the identifier of the thread that created it."""
    expected_parent_thread_id = threading.get_ident()
    # Collected in a list so that a child that never ran (empty list) is
    # distinguishable from one that observed an unexpected value.
    observed_parent_thread_ids = []

    def record_parent_thread_id():
      observed_parent_thread_ids.append(
          threading.current_thread().parent_thread_id)

    thread = ParentAwareThread(target=record_parent_thread_id)
    thread.start()
    thread.join()

    self.assertEqual([expected_parent_thread_id], observed_parent_thread_ids)

0 comments on commit 60d7182

Please sign in to comment.