Skip to content

Commit

Permalink
Add type stubs for Cython files (#1594)
Browse files Browse the repository at this point in the history
* Check for missing imports in River

* Add type stubs for expected_mutual_info.pyx

* Add type stubs for efficient_rollingrocauc.pyx

* Add type stub for adwin_c

* Add basic type stubs for vectordict

* Remove modulo from Vectordict.__pow__

This brings the signature in line with how Python defines the operator.
The modulo parameter was not used.

* Check the stubs files in pre-commit
  • Loading branch information
e10e3 authored Aug 22, 2024
1 parent 71a127f commit bd0e31b
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ repos:
- id: ruff
name: ruff
language: python
types: [python]
types: [python, pyi]
entry: ruff
args:
- --fix

- id: mypy
name: mypy
language: python
types: [python]
types: [python, pyi]
entry: mypy --implicit-optional
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ files = "river"

[[tool.mypy.overrides]]
module = [
"river.*",
"mmh3.*",
"numpy.*",
"sklearn.*",
Expand Down
27 changes: 27 additions & 0 deletions river/drift/adwin_c.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
class AdaptiveWindowing:
def __init__(
self,
delta: float = 0.002,
clock: int = 32,
max_buckets: int = 5,
min_window_length: int = 5,
grace_period: int = 10,
) -> None: ...
def get_n_detections(self) -> int: ...
def get_width(self) -> float: ...
def get_total(self) -> float: ...
def get_variance(self) -> float: ...
@property
def variance_in_window(self) -> float: ...
def update(self, value: float) -> bool: ...

class Bucket:
def __init__(self, max_size: int) -> None: ...
def clear_at(self, index: int) -> None: ...
def insert_data(self, value: float, variance: float) -> None: ...
def remove(self) -> None: ...
def compress(self, n_elements: int) -> None: ...
def get_total_at(self, index: int) -> float: ...
def get_variance_at(self, index: int) -> float: ...
def set_total_at(self, value: float, index: int) -> None: ...
def set_variance_at(self, value: float, index: int) -> None: ...
12 changes: 12 additions & 0 deletions river/metrics/efficient_rollingrocauc/efficient_rollingrocauc.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from collections.abc import Sequence
from typing import Any

class EfficientRollingROCAUC:
def __cinit__(self, positiveLabel: int, windowSize: int) -> None: ...
def __dealloc__(self) -> None: ...
def update(self, label: bool, score: bool | float | dict[bool, float]) -> None: ...
def revert(self, label: bool, score: bool | float | dict[bool, float]) -> None: ...
def get(self) -> float: ...
def __getnewargs_ex__(self) -> tuple[tuple[int, int], dict[str, Any]]: ...
def __getstate__(self) -> tuple[Sequence[int], Sequence[float]]: ...
def __setstate__(self, state: tuple[Sequence[int], Sequence[float]]) -> None: ...
3 changes: 3 additions & 0 deletions river/metrics/expected_mutual_info.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from river import metrics

def expected_mutual_info(confusion_matrix: metrics.ConfusionMatrix) -> float: ...
56 changes: 56 additions & 0 deletions river/utils/vectordict.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from collections.abc import Callable

import numpy

def get_union_keys(left: VectorDict, right: VectorDict): ...
def get_intersection_keys(left: VectorDict, right: VectorDict): ...

class VectorDict:
def __init__(
self,
data: VectorDict | dict | None = None,
default_factory: Callable | None = None,
mask: VectorDict | set | None = None,
copy: bool = False,
) -> None: ...
def with_mask(self, mask, copy=False): ...
def to_dict(self): ...
def to_numpy(self, fields) -> numpy.ndarray: ...
def __contains__(self, key): ...
def __delitem__(self, key): ...
def __format__(self, format_spec): ...
def __getitem__(self, key): ...
def __iter__(self): ...
def __len__(self): ...
def __repr__(self): ...
def __setitem__(self, key, value): ...
def __str__(self): ...
def clear(self): ...
def get(self, key, *args, **kwargs): ...
def items(self): ...
def keys(self): ...
def pop(self, *args, **kwargs): ...
def popitem(self): ...
def setdefault(self, key, *args, **kwargs): ...
def update(self, *args, **kwargs): ...
def values(self): ...
def __eq__(left, right): ...
def __add__(left, right): ...
def __iadd__(self, other): ...
def __sub__(left, right): ...
def __isub__(self, other): ...
def __mul__(left, right): ...
def __imul__(self, other): ...
def __truediv__(left, right): ...
def __itruediv__(self, other): ...
def __pow__(left, right): ...
def __ipow__(self, other): ...
def __matmul__(left, right): ...
def __neg__(self): ...
def __pos__(self): ...
def __abs__(self): ...
def abs(self): ...
def min(self): ...
def max(self): ...
def minimum(self, other): ...
def maximum(self, other): ...
4 changes: 2 additions & 2 deletions river/utils/vectordict.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,8 @@ cdef class VectorDict:
return NotImplemented
return self

def __pow__(left, right, modulo):
if not isinstance(left, VectorDict) or modulo is not None:
def __pow__(left, right):
if not isinstance(left, VectorDict):
return NotImplemented
left_ = <VectorDict> left
res = left_._to_dict(force_copy=True)
Expand Down

0 comments on commit bd0e31b

Please sign in to comment.