From e76d4c65a05ce17a2666da52b74c284f9fca391f Mon Sep 17 00:00:00 2001
From: Danny McCormick
Date: Thu, 26 Sep 2024 20:18:47 -0400
Subject: [PATCH] use member functions instead of inheritance

GlobalWindow now derives from BoundedWindow instead of IntervalWindow.

IntervalWindowCoderImpl converts a GlobalWindow into an IntervalWindow
with the same bounds before encoding it or estimating its size
(IntervalWindow.try_from_global_window). After decoding, a window whose
bounds equal the global window's is mapped back to GlobalWindow
(IntervalWindow.try_to_global_window); any other decoded window is
returned unchanged.
---
 sdks/python/apache_beam/coders/coder_impl.py | 21 +++++++-------
 sdks/python/apache_beam/transforms/window.py | 30 +++++++++++---------
 2 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 1b33f2365ead1..6c664e88fe6ff 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -74,7 +74,6 @@
 if TYPE_CHECKING:
   import proto
   from apache_beam.transforms import userstate
-  from apache_beam.transforms.window import GlobalWindow
   from apache_beam.transforms.window import IntervalWindow
 
 try:
@@ -807,7 +806,6 @@ def estimate_size(self, unused_value, nested=False):
 
 if not TYPE_CHECKING:
   IntervalWindow = None
-  GlobalWindow = None
 
 
 class IntervalWindowCoderImpl(StreamCoderImpl):
@@ -824,7 +822,11 @@ def _from_normal_time(self, value):
 
   def encode_to_stream(self, value, out, nested):
     # type: (IntervalWindow, create_OutputStream, bool) -> None
-    typed_value = value
+    if not TYPE_CHECKING:
+      global IntervalWindow  # pylint: disable=global-variable-not-assigned
+      if IntervalWindow is None:
+        from apache_beam.transforms.window import IntervalWindow
+    typed_value = IntervalWindow.try_from_global_window(value)
     span_millis = (
         typed_value._end_micros // 1000 - typed_value._start_micros // 1000)
     out.write_bigendian_uint64(
@@ -836,7 +838,6 @@ def decode_from_stream(self, in_, nested):
     if not TYPE_CHECKING:
       global IntervalWindow  # pylint: disable=global-variable-not-assigned
       if IntervalWindow is None:
-        from apache_beam.transforms.window import GlobalWindow
         from apache_beam.transforms.window import IntervalWindow
     # instantiating with None is not part of the public interface
     typed_value = IntervalWindow(None, None)  # type: ignore[arg-type]
@@ -844,17 +845,17 @@ def decode_from_stream(self, in_, nested):
         1000 * self._to_normal_time(in_.read_bigendian_uint64()))
     typed_value._start_micros = (
         typed_value._end_micros - 1000 * in_.read_var_int64())
-    gw = GlobalWindow()
-    if typed_value == gw:
-      return gw
-
-    return typed_value
+    return typed_value.try_to_global_window()
 
   def estimate_size(self, value, nested=False):
     # type: (Any, bool) -> int
     # An IntervalWindow is context-insensitive, with a timestamp (8 bytes)
     # and a varint timespam.
-    typed_value = value
+    if not TYPE_CHECKING:
+      global IntervalWindow  # pylint: disable=global-variable-not-assigned
+      if IntervalWindow is None:
+        from apache_beam.transforms.window import IntervalWindow
+    typed_value = IntervalWindow.try_from_global_window(value)
     span_millis = (
         typed_value._end_micros // 1000 - typed_value._start_micros // 1000)
     return 8 + get_varint_size(span_millis)
diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py
index 7da58c293932e..691b2b567e658 100644
--- a/sdks/python/apache_beam/transforms/window.py
+++ b/sdks/python/apache_beam/transforms/window.py
@@ -261,21 +261,25 @@ def __lt__(self, other):
       return self.end < other.end
     return hash(self) < hash(other)
 
-  def __eq__(self, other):
-    return (
-        self is other or (
-            type(self) is type(other) and self.end == other.end and
-            self.start == other.start))
-
-  def __hash__(self):
-    return hash((self.start, self.end))
-
   def intersects(self, other: 'IntervalWindow') -> bool:
     return other.start < self.end or self.start < other.end
 
   def union(self, other: 'IntervalWindow') -> 'IntervalWindow':
     return IntervalWindow(
         min(self.start, other.start), max(self.end, other.end))
+
+  @staticmethod
+  def try_from_global_window(value) -> 'IntervalWindow':
+    gw = GlobalWindow()
+    if gw == value:
+      return IntervalWindow(gw.start, GlobalWindow._getTimestampFromProto())
+    return value
+
+  def try_to_global_window(self) -> BoundedWindow:
+    gw = GlobalWindow()
+    if self.start == gw.start and self.end == GlobalWindow._getTimestampFromProto():
+      return gw
+    return self
 
 
 V = TypeVar("V")
@@ -309,7 +313,7 @@ def __lt__(self, other):
     return self.timestamp < other.timestamp
 
 
-class GlobalWindow(IntervalWindow):
+class GlobalWindow(BoundedWindow):
   """The default window into which all data is placed (via GlobalWindows)."""
   _instance: Optional['GlobalWindow'] = None
 
@@ -319,7 +323,7 @@ def __new__(cls):
     return cls._instance
 
   def __init__(self) -> None:
-    super().__init__(MIN_TIMESTAMP, GlobalWindow._getTimestampFromProto())
+    super().__init__(GlobalWindow._getTimestampFromProto())
 
   def __repr__(self):
     return 'GlobalWindow'
@@ -328,9 +332,7 @@ def __hash__(self):
     return hash(type(self))
 
   def __eq__(self, other):
-    return (
-        self is other or type(self) is type(other) or
-        (type(other) is IntervalWindow and other.__eq__(self)))
+    return self is other or type(self) is type(other)
 
   @property
   def start(self) -> Timestamp: