From 643ef2dfd576013e18f5344157b3c2cde5a6ecfc Mon Sep 17 00:00:00 2001 From: laraabastoss <92671491+laraabastoss@users.noreply.github.com> Date: Tue, 25 Jun 2024 10:01:35 +0100 Subject: [PATCH] Ran MyPy --- river/sketch/hierarchical_heavy_hitters.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/river/sketch/hierarchical_heavy_hitters.py b/river/sketch/hierarchical_heavy_hitters.py index 2c88166678..5e725ee6c2 100644 --- a/river/sketch/hierarchical_heavy_hitters.py +++ b/river/sketch/hierarchical_heavy_hitters.py @@ -143,7 +143,7 @@ def __init__(self): self.m_fe = 0 self.children: typing.dict[typing.Hashable, HierarchicalHeavyHitters.Node] = {} - def __init__(self, k: int, epsilon: float, parent_func: typing.Callable[[typing.Hashable, int], typing.Hashable] = None, root_value: typing.Hashable = None): + def __init__(self, k: int, epsilon: float, parent_func: typing.Optional[typing.Callable[[typing.Hashable, int], typing.Hashable]] = None, root_value: typing.Optional[typing.Hashable] = None): self.k = k self.epsilon = epsilon self.bucket_size = math.floor(1 / epsilon) @@ -220,7 +220,7 @@ def _compress_node(self, node: HierarchicalHeavyHitters.Node): node.max_e = max (node.max_e, child_node.ge + child_node.delta_e) del node.children[child_key] - def output(self, phi: float) -> list[typing.tuple[typing.Hashable, int]]: + def output(self, phi: float) -> list[tuple[typing.Hashable, int]]: """Generate a list of heavy hitters with frequency estimates above the given threshold.""" result: list[tuple[typing.Hashable, int]] = [] if self.root: @@ -269,7 +269,7 @@ def __getitem__(self, key: typing.Hashable) -> int: current = current.children[sub_key] - if sub_key == key: + if sub_key == key and current is not None: return current.ge else: return 0