Skip to content

Commit

Permalink
Fix binary operations involving unions wrapped in Annotated (#779)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jun 5, 2024
1 parent 57fb3fd commit e8ed0f3
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Fix binary operations involving unions wrapped in `Annotated` (#779)
- Fix various issues with Python 3.13 and 3.14 support (#773)
- Improve `ParamSpec` support (#772, #777)
- Fix handling of stub functions with positional-only parameters with
Expand Down
9 changes: 6 additions & 3 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5078,10 +5078,13 @@ def _check_dunder_call(
args: Iterable[Composite],
allow_call: bool = False,
) -> Tuple[Value, bool]:
if isinstance(callee_composite.value, MultiValuedValue):
val = callee_composite.value
if isinstance(val, AnnotatedValue):
val = val.value
if isinstance(val, MultiValuedValue):
composites = [
Composite(val, callee_composite.varname, callee_composite.node)
for val in callee_composite.value.vals
Composite(subval, callee_composite.varname, callee_composite.node)
for subval in val.vals
]
with qcore.override(self, "in_union_decomposition", True):
values_and_exists = [
Expand Down
14 changes: 14 additions & 0 deletions pyanalyze/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pathlib import Path
from typing import Dict, Optional, Tuple

from pyanalyze.annotated_types import Gt

from . import tests, value
from .arg_spec import ArgSpecCache
from .error_code import ErrorCode, register_error_code
Expand Down Expand Up @@ -59,14 +61,25 @@ def get_constructor(cls: type) -> Optional[Signature]:
return None


def _make_union_in_annotated_impl(ctx: CallContext) -> value.Value:
return value.AnnotatedValue(
value.TypedValue(float) | value.TypedValue(int),
[value.CustomCheckExtension(Gt(0))],
)


@used # in test.toml
def get_known_signatures(
arg_spec_cache: ArgSpecCache,
) -> Dict[object, ConcreteSignature]:
failing_impl_sig = arg_spec_cache.get_argspec(tests.FailingImpl, impl=_failing_impl)
custom_sig = arg_spec_cache.get_argspec(tests.custom_code, impl=_custom_code_impl)
union_in_anno_sig = arg_spec_cache.get_argspec(
tests.make_union_in_annotated, impl=_make_union_in_annotated_impl
)
assert isinstance(custom_sig, Signature), custom_sig
assert isinstance(failing_impl_sig, Signature), failing_impl_sig
assert isinstance(union_in_anno_sig, Signature), union_in_anno_sig
return {
tests.takes_kwonly_argument: Signature.make(
[
Expand All @@ -81,6 +94,7 @@ def get_known_signatures(
),
tests.FailingImpl: failing_impl_sig,
tests.custom_code: custom_sig,
tests.make_union_in_annotated: union_in_anno_sig,
tests.overloaded: OverloadedSignature(
[
Signature.make([], value.TypedValue(int), callable=tests.overloaded),
Expand Down
39 changes: 39 additions & 0 deletions pyanalyze/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,42 @@ def capybara():
def test_text_and_bytes(self):
def capybara():
return "foo" + b"bar" # E: unsupported_operation


class TestAnnotated(TestNameCheckVisitorBase):

@assert_passes()
def test_union(self):
from typing import Union

from pyanalyze.annotated_types import Gt as OurGt
from pyanalyze.tests import make_union_in_annotated
from pyanalyze.value import (
AnnotatedValue,
CustomCheckExtension,
TypedValue,
assert_is_value,
)

def want_float_or_int(x: Union[float, int]) -> None:
pass

def capybara() -> None:
# unite_values() distributes Annotated metadata over the union,
# so we have to really careful to get a type that is Annotated[x | y, ...]
# rather than Annotated[x, ...] | Annotated[y, ...]
assert_is_value(
make_union_in_annotated(),
AnnotatedValue(
TypedValue(float) | TypedValue(int),
[CustomCheckExtension(OurGt(0))],
),
)
y = 10000 * make_union_in_annotated()
x = make_union_in_annotated()
z = x * 1000000
assert_is_value(y, TypedValue(float) | TypedValue(int))
assert_is_value(z, TypedValue(float) | TypedValue(int))

want_float_or_int(make_union_in_annotated())
want_float_or_int(x)
4 changes: 4 additions & 0 deletions pyanalyze/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,7 @@ def assert_never(arg: NoReturn) -> NoReturn:

def make_simple_sequence(typ: type, vals: Sequence[Value]) -> SequenceValue:
return SequenceValue(typ, [(False, val) for val in vals])


def make_union_in_annotated() -> object:
return 42

0 comments on commit e8ed0f3

Please sign in to comment.