diff --git a/docs/changelog.md b/docs/changelog.md index a81531e9..8be30875 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +- Fix binary operations involving unions wrapped in `Annotated` (#779) - Fix various issues with Python 3.13 and 3.14 support (#773) - Improve `ParamSpec` support (#772, #777) - Fix handling of stub functions with positional-only parameters with diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 1f39d109..855b14ae 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -5078,10 +5078,13 @@ def _check_dunder_call( args: Iterable[Composite], allow_call: bool = False, ) -> Tuple[Value, bool]: - if isinstance(callee_composite.value, MultiValuedValue): + val = callee_composite.value + if isinstance(val, AnnotatedValue): + val = val.value + if isinstance(val, MultiValuedValue): composites = [ - Composite(val, callee_composite.varname, callee_composite.node) - for val in callee_composite.value.vals + Composite(subval, callee_composite.varname, callee_composite.node) + for subval in val.vals ] with qcore.override(self, "in_union_decomposition", True): values_and_exists = [ diff --git a/pyanalyze/test_config.py b/pyanalyze/test_config.py index ad2a3fc6..4d918d86 100644 --- a/pyanalyze/test_config.py +++ b/pyanalyze/test_config.py @@ -7,6 +7,8 @@ from pathlib import Path from typing import Dict, Optional, Tuple +from pyanalyze.annotated_types import Gt + from . import tests, value from .arg_spec import ArgSpecCache from .error_code import ErrorCode, register_error_code @@ -59,14 +61,25 @@ def get_constructor(cls: type) -> Optional[Signature]: return None +def _make_union_in_annotated_impl(ctx: CallContext) -> value.Value: + return value.AnnotatedValue( + value.TypedValue(float) | value.TypedValue(int), + [value.CustomCheckExtension(Gt(0))], + ) + + @used # in test.toml def get_known_signatures( arg_spec_cache: ArgSpecCache, ) -> Dict[object, ConcreteSignature]: failing_impl_sig = arg_spec_cache.get_argspec(tests.FailingImpl, impl=_failing_impl) custom_sig = arg_spec_cache.get_argspec(tests.custom_code, impl=_custom_code_impl) + union_in_anno_sig = arg_spec_cache.get_argspec( + tests.make_union_in_annotated, impl=_make_union_in_annotated_impl + ) assert isinstance(custom_sig, Signature), custom_sig assert isinstance(failing_impl_sig, Signature), failing_impl_sig + assert isinstance(union_in_anno_sig, Signature), union_in_anno_sig return { tests.takes_kwonly_argument: Signature.make( [ @@ -81,6 +94,7 @@ def get_known_signatures( ), tests.FailingImpl: failing_impl_sig, tests.custom_code: custom_sig, + tests.make_union_in_annotated: union_in_anno_sig, tests.overloaded: OverloadedSignature( [ Signature.make([], value.TypedValue(int), callable=tests.overloaded), diff --git a/pyanalyze/test_operations.py b/pyanalyze/test_operations.py index 0df5a0cc..b7cffb69 100644 --- a/pyanalyze/test_operations.py +++ b/pyanalyze/test_operations.py @@ -275,3 +275,42 @@ def capybara(): def test_text_and_bytes(self): def capybara(): return "foo" + b"bar" # E: unsupported_operation + + +class TestAnnotated(TestNameCheckVisitorBase): + + @assert_passes() + def test_union(self): + from typing import Union + + from pyanalyze.annotated_types import Gt as OurGt + from pyanalyze.tests import make_union_in_annotated + from pyanalyze.value import ( + AnnotatedValue, + CustomCheckExtension, + TypedValue, + assert_is_value, + ) + + def want_float_or_int(x: Union[float, int]) -> None: + pass + + def capybara() -> None: + # unite_values() distributes Annotated metadata over the union, + # so we have to really careful to get a type that is Annotated[x | y, ...] + # rather than Annotated[x, ...] | Annotated[y, ...] + assert_is_value( + make_union_in_annotated(), + AnnotatedValue( + TypedValue(float) | TypedValue(int), + [CustomCheckExtension(OurGt(0))], + ), + ) + y = 10000 * make_union_in_annotated() + x = make_union_in_annotated() + z = x * 1000000 + assert_is_value(y, TypedValue(float) | TypedValue(int)) + assert_is_value(z, TypedValue(float) | TypedValue(int)) + + want_float_or_int(make_union_in_annotated()) + want_float_or_int(x) diff --git a/pyanalyze/tests.py b/pyanalyze/tests.py index e2f81b4b..aea71abe 100644 --- a/pyanalyze/tests.py +++ b/pyanalyze/tests.py @@ -238,3 +238,7 @@ def assert_never(arg: NoReturn) -> NoReturn: def make_simple_sequence(typ: type, vals: Sequence[Value]) -> SequenceValue: return SequenceValue(typ, [(False, val) for val in vals]) + + +def make_union_in_annotated() -> object: + return 42