Skip to content

Commit

Permalink
signal+tests: Add all possible overloads for gausspulse.
Browse files Browse the repository at this point in the history
`t` can be either an array or a scalar or "cutoff". Add corresponding
tests
  • Loading branch information
pavyamsiri committed Nov 24, 2024
1 parent 2fd3a6e commit 4edcf25
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 65 deletions.
186 changes: 136 additions & 50 deletions scipy-stubs/signal/_waveforms.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ _Truthy: TypeAlias = Literal[1, True]
_Falsy: TypeAlias = Literal[0, False]
_ArrayLikeFloat: TypeAlias = onp.ToFloat | onp.ToFloatND
_Array_f8: TypeAlias = onp.ArrayND[np.float64]
_GaussPulseTime: TypeAlias = _ArrayLikeFloat | Literal["cutoff"]

# Type vars to annotate `chirp`
_NBT1 = TypeVar("_NBT1", bound=npt.NBitBase)
Expand All @@ -32,9 +31,127 @@ def sawtooth(t: _ArrayLikeFloat, width: _ArrayLikeFloat = 1) -> _Array_f8: ...
def square(t: _ArrayLikeFloat, duty: _ArrayLikeFloat = 0.5) -> _Array_f8: ...

#
@overload # Static type checking for float values
def chirp(
t: _ChirpTime[_NBT1],
f0: _ChirpScalar[_NBT2],
t1: _ChirpScalar[_NBT3],
f1: _ChirpScalar[_NBT4],
method: _ChirpMethod = "linear",
phi: _ChirpScalar[_NBT5] = 0,
vertex_zero: op.CanBool = True,
) -> onp.ArrayND[np.floating[_NBT1 | _NBT2 | _NBT3 | _NBT4 | _NBT5]]: ...
@overload # Other dtypes default to np.float64
def chirp(
t: onp.ToFloatND | _NestedSequence[float],
f0: onp.ToFloat,
t1: onp.ToFloat,
f1: onp.ToFloat,
method: _ChirpMethod = "linear",
phi: onp.ToFloat = 0,
vertex_zero: op.CanBool = True,
) -> _Array_f8: ...

#
def sweep_poly(
t: _ArrayLikeFloat,
poly: onp.ToFloatND | np.poly1d,
phi: onp.ToFloat = 0,
) -> _Array_f8: ...

#
@overload # dtype is not given
def unit_impulse(
shape: AnyShape,
idx: op.CanIndex | Iterable[op.CanIndex] | Literal["mid"] | None = None,
dtype: type[float] = ...,
) -> _Array_f8: ...
@overload # dtype is given
def unit_impulse(
shape: AnyShape,
idx: op.CanIndex | Iterable[op.CanIndex] | Literal["mid"] | None,
dtype: _DTypeLike[_SCT],
) -> npt.NDArray[_SCT]: ...

# Overloads for gausspulse when `t` is scalar
@overload # retquad: False = ..., retenv: False = ...
def gausspulse(
t: onp.ToFloat,
fc: onp.ToFloat = 1000,
bw: onp.ToFloat = 0.5,
bwr: onp.ToFloat = -6,
tpr: onp.ToFloat = -60,
retquad: _Falsy = False,
retenv: _Falsy = False,
) -> np.float64: ...
@overload # retquad: False = ..., retenv: True (keyword)
def gausspulse(
t: onp.ToFloat,
fc: onp.ToFloat = 1000,
bw: onp.ToFloat = 0.5,
bwr: onp.ToFloat = -6,
tpr: onp.ToFloat = -60,
retquad: _Falsy = False,
*,
retenv: _Truthy,
) -> tuple[np.float64, np.float64]: ...
@overload # retquad: False (positional), retenv: False (positional)
def gausspulse(
t: onp.ToFloat,
fc: onp.ToFloat,
bw: onp.ToFloat,
bwr: onp.ToFloat,
tpr: onp.ToFloat,
retquad: _Falsy,
retenv: _Truthy,
) -> tuple[np.float64, np.float64]: ...
@overload # retquad: True (positional), retenv: False = ...
def gausspulse(
t: onp.ToFloat,
fc: onp.ToFloat,
bw: onp.ToFloat,
bwr: onp.ToFloat,
tpr: onp.ToFloat,
retquad: _Truthy,
retenv: _Falsy = False,
) -> tuple[np.float64, np.float64]: ...
@overload # retquad: True (keyword), retenv: False = ...
def gausspulse(
t: onp.ToFloat,
fc: onp.ToFloat = 1000,
bw: onp.ToFloat = 0.5,
bwr: onp.ToFloat = -6,
tpr: onp.ToFloat = -60,
*,
retquad: _Truthy,
retenv: _Falsy = False,
) -> tuple[np.float64, np.float64]: ...
@overload # retquad: True (positional), retenv: True (positional/keyword)
def gausspulse(
t: onp.ToFloat,
fc: onp.ToFloat,
bw: onp.ToFloat,
bwr: onp.ToFloat,
tpr: onp.ToFloat,
retquad: _Truthy,
retenv: _Truthy,
) -> tuple[np.float64, np.float64, np.float64]: ...
@overload # retquad: True (keyword), retenv: True
def gausspulse(
t: onp.ToFloat,
fc: onp.ToFloat = 1000,
bw: onp.ToFloat = 0.5,
bwr: onp.ToFloat = -6,
tpr: onp.ToFloat = -60,
*,
retquad: _Truthy,
retenv: _Truthy,
) -> tuple[np.float64, np.float64, np.float64]: ...

# Overloads for `gausspulse` when `t` is a non-scalar array like
@overload # retquad: False = ..., retenv: False = ...
def gausspulse(
t: _GaussPulseTime,
t: onp.ToFloatND,
fc: onp.ToFloat = 1000,
bw: onp.ToFloat = 0.5,
bwr: onp.ToFloat = -6,
Expand All @@ -44,7 +161,7 @@ def gausspulse(
) -> _Array_f8: ...
@overload # retquad: False = ..., retenv: True (keyword)
def gausspulse(
t: _GaussPulseTime,
t: onp.ToFloatND,
fc: onp.ToFloat = 1000,
bw: onp.ToFloat = 0.5,
bwr: onp.ToFloat = -6,
Expand All @@ -55,7 +172,7 @@ def gausspulse(
) -> tuple[_Array_f8, _Array_f8]: ...
@overload # retquad: False (positional), retenv: False (positional)
def gausspulse(
t: _GaussPulseTime,
t: onp.ToFloatND,
fc: onp.ToFloat,
bw: onp.ToFloat,
bwr: onp.ToFloat,
Expand All @@ -65,7 +182,7 @@ def gausspulse(
) -> tuple[_Array_f8, _Array_f8]: ...
@overload # retquad: True (positional), retenv: False = ...
def gausspulse(
t: _GaussPulseTime,
t: onp.ToFloatND,
fc: onp.ToFloat,
bw: onp.ToFloat,
bwr: onp.ToFloat,
Expand All @@ -75,7 +192,7 @@ def gausspulse(
) -> tuple[_Array_f8, _Array_f8]: ...
@overload # retquad: True (keyword), retenv: False = ...
def gausspulse(
t: _GaussPulseTime,
t: onp.ToFloatND,
fc: onp.ToFloat = 1000,
bw: onp.ToFloat = 0.5,
bwr: onp.ToFloat = -6,
Expand All @@ -86,7 +203,7 @@ def gausspulse(
) -> tuple[_Array_f8, _Array_f8]: ...
@overload # retquad: True (positional), retenv: True (positional/keyword)
def gausspulse(
t: _GaussPulseTime,
t: onp.ToFloatND,
fc: onp.ToFloat,
bw: onp.ToFloat,
bwr: onp.ToFloat,
Expand All @@ -96,7 +213,7 @@ def gausspulse(
) -> tuple[_Array_f8, _Array_f8, _Array_f8]: ...
@overload # retquad: True (keyword), retenv: True
def gausspulse(
t: _GaussPulseTime,
t: onp.ToFloatND,
fc: onp.ToFloat = 1000,
bw: onp.ToFloat = 0.5,
bwr: onp.ToFloat = -6,
Expand All @@ -106,45 +223,14 @@ def gausspulse(
retenv: _Truthy,
) -> tuple[_Array_f8, _Array_f8, _Array_f8]: ...

#
@overload # Static type checking for float values
def chirp(
t: _ChirpTime[_NBT1],
f0: _ChirpScalar[_NBT2],
t1: _ChirpScalar[_NBT3],
f1: _ChirpScalar[_NBT4],
method: _ChirpMethod = "linear",
phi: _ChirpScalar[_NBT5] = 0,
vertex_zero: op.CanBool = True,
) -> onp.ArrayND[np.floating[_NBT1 | _NBT2 | _NBT3 | _NBT4 | _NBT5]]: ...
@overload # Other dtypes default to np.float64
def chirp(
t: onp.ToFloatND | _NestedSequence[float],
f0: onp.ToFloat,
t1: onp.ToFloat,
f1: onp.ToFloat,
method: _ChirpMethod = "linear",
phi: onp.ToFloat = 0,
vertex_zero: op.CanBool = True,
) -> _Array_f8: ...

#
def sweep_poly(
t: _ArrayLikeFloat,
poly: onp.ToFloatND | np.poly1d,
phi: onp.ToFloat = 0,
) -> _Array_f8: ...

#
@overload # dtype is not given
def unit_impulse(
shape: AnyShape,
idx: op.CanIndex | Iterable[op.CanIndex] | Literal["mid"] | None = None,
dtype: type[float] = ...,
) -> _Array_f8: ...
@overload # dtype is given
def unit_impulse(
shape: AnyShape,
idx: op.CanIndex | Iterable[op.CanIndex] | Literal["mid"] | None,
dtype: _DTypeLike[_SCT],
) -> npt.NDArray[_SCT]: ...
# Overloads for gausspulse when `t` is `"cutoff"`
@overload # retquad: False = ..., retenv: False = ...
def gausspulse(
t: Literal["cutoff"],
fc: onp.ToFloat = 1000,
bw: onp.ToFloat = 0.5,
bwr: onp.ToFloat = -6,
tpr: onp.ToFloat = -60,
retquad: op.CanBool = False,
retenv: op.CanBool = False,
) -> np.float64: ...
Loading

0 comments on commit 4edcf25

Please sign in to comment.