[red-knot] Inference for comparison of union types #13781

Merged: 4 commits, Oct 17, 2024
@@ -0,0 +1,76 @@
# Comparison: Unions

## Union on one side of the comparison

Comparisons on union types need to consider all possible cases:

```py
one_or_two = 1 if flag else 2

reveal_type(one_or_two <= 2) # revealed: Literal[True]
reveal_type(one_or_two <= 1) # revealed: bool
reveal_type(one_or_two <= 0) # revealed: Literal[False]

reveal_type(2 >= one_or_two) # revealed: Literal[True]
reveal_type(1 >= one_or_two) # revealed: bool
reveal_type(0 >= one_or_two) # revealed: Literal[False]

reveal_type(one_or_two < 1) # revealed: Literal[False]
reveal_type(one_or_two < 2) # revealed: bool
reveal_type(one_or_two < 3) # revealed: Literal[True]

reveal_type(one_or_two > 0) # revealed: Literal[True]
reveal_type(one_or_two > 1) # revealed: bool
reveal_type(one_or_two > 2) # revealed: Literal[False]

reveal_type(one_or_two == 3) # revealed: Literal[False]
reveal_type(one_or_two == 1) # revealed: bool

reveal_type(one_or_two != 3) # revealed: Literal[True]
reveal_type(one_or_two != 1) # revealed: bool

a_or_ab = "a" if flag else "ab"

reveal_type(a_or_ab in "ab") # revealed: Literal[True]
reveal_type("a" in a_or_ab) # revealed: Literal[True]

reveal_type("c" not in a_or_ab) # revealed: Literal[True]
reveal_type("a" not in a_or_ab) # revealed: Literal[False]

reveal_type("b" in a_or_ab) # revealed: bool
reveal_type("b" not in a_or_ab) # revealed: bool

one_or_none = 1 if flag else None

reveal_type(one_or_none is None) # revealed: bool
reveal_type(one_or_none is not None) # revealed: bool
```
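
As an illustration of "consider all possible cases", here is a minimal Rust sketch, assuming a toy model in which a union is just a slice of integer literals. The names `Truthiness` and `compare_union` are hypothetical and are not part of red-knot. Each element is compared separately, and the per-element results are joined into `Literal[True]`, `Literal[False]`, or `bool`.

```rust
/// Toy stand-in for the three result types in the tests above
/// (hypothetical; red-knot's real `Type` enum is much richer).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Truthiness {
    AlwaysTrue,  // Literal[True]
    AlwaysFalse, // Literal[False]
    Ambiguous,   // bool
}

impl Truthiness {
    fn from_bool(value: bool) -> Self {
        if value {
            Truthiness::AlwaysTrue
        } else {
            Truthiness::AlwaysFalse
        }
    }

    /// Union of two results: equal results stay as they are, while
    /// `Literal[True] | Literal[False]` widens to `bool`.
    fn join(self, other: Self) -> Self {
        if self == other {
            self
        } else {
            Truthiness::Ambiguous
        }
    }
}

/// Compare every int-literal element of a union against `rhs` and join
/// the per-element results. Returns `None` for an empty union.
fn compare_union(
    elements: &[i64],
    rhs: i64,
    op: impl Fn(i64, i64) -> bool,
) -> Option<Truthiness> {
    elements
        .iter()
        .map(|&element| Truthiness::from_bool(op(element, rhs)))
        .reduce(Truthiness::join)
}
```

With `elements = &[1, 2]`, comparing `<= 2` joins two "true" results, `<= 1` joins one "true" and one "false" result, and `<= 0` joins two "false" results, mirroring the first three `reveal_type` lines.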

## Union on both sides of the comparison

With unions on both sides, we need to consider the full cross product of
options when building the resulting (union) type:

```py
small = 1 if flag_s else 2
large = 2 if flag_l else 3

reveal_type(small <= large) # revealed: Literal[True]
reveal_type(small >= large) # revealed: bool

reveal_type(small < large) # revealed: bool
reveal_type(small > large) # revealed: Literal[False]
```
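
The cross product can be sketched the same way. This is a toy model under stated assumptions: both unions are slices of integer literals, and `Outcome` and `compare_unions` are hypothetical names, not red-knot APIs. Every (left, right) pair is evaluated, and the result only stays a literal if all pairs agree.

```rust
/// Toy result type (hypothetical): a literal `True`/`False` or plain `bool`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Outcome {
    True,
    False,
    Bool,
}

/// Evaluate `op` on the full cross product of `left` and `right`.
/// Returns `None` if either union is empty.
fn compare_unions(left: &[i64], right: &[i64], op: fn(i64, i64) -> bool) -> Option<Outcome> {
    let mut result: Option<Outcome> = None;
    for &l in left {
        for &r in right {
            let this = if op(l, r) { Outcome::True } else { Outcome::False };
            result = Some(match result {
                // First pair: take its outcome as-is.
                None => this,
                // Same outcome as before: nothing changes.
                Some(previous) if previous == this => previous,
                // Disagreement between pairs: widen to `bool`.
                Some(_) => Outcome::Bool,
            });
        }
    }
    result
}
```

For `small = [1, 2]` and `large = [2, 3]`, all four pairs satisfy `<=`, so the result stays `True`, whereas `<` fails only for the pair `(2, 2)` and therefore widens to `Bool`.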

## Unsupported operations

Make sure we emit a diagnostic if *any* of the possible comparisons is
unsupported. For now, we fall back to `bool` for the result type instead of
trying to infer something more precise from the other (supported) variants:

```py
x = [1, 2] if flag else 1

result = 1 in x # error: "Operator `in` is not supported"
reveal_type(result) # revealed: bool
```
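
A rough sketch of the "any unsupported variant produces a diagnostic" rule, assuming a toy type model. `Ty`, `infer_in`, `infer_in_union`, and `check_in` are hypothetical names for illustration only. A single `None` from any union element makes the whole inference fail, and the caller then reports the error and falls back to `bool`.

```rust
/// Toy types (hypothetical): just enough to model `[1, 2] if flag else 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Ty {
    IntLiteral(i64),
    ListOfInt,
    Bool,
}

/// Result type of `<int> in container` for a single, non-union container.
/// `None` means the operation is unsupported.
fn infer_in(container: Ty) -> Option<Ty> {
    match container {
        Ty::ListOfInt => Some(Ty::Bool),
        // Neither `int` nor `bool` supports membership tests.
        Ty::IntLiteral(_) | Ty::Bool => None,
    }
}

/// Infer `<int> in <union>`: fails as soon as one element is unsupported.
fn infer_in_union(elements: &[Ty]) -> Option<Ty> {
    for &element in elements {
        infer_in(element)?;
    }
    Some(Ty::Bool)
}

/// Caller side: record a diagnostic on failure and fall back to `bool`.
fn check_in(elements: &[Ty], diagnostics: &mut Vec<String>) -> Ty {
    infer_in_union(elements).unwrap_or_else(|| {
        diagnostics.push("Operator `in` is not supported".to_string());
        Ty::Bool
    })
}
```

In this model the fallback type is `bool` whether or not the other variants were supported, which matches the "for now" behavior described above rather than a more precise result.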
5 changes: 5 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
@@ -415,6 +415,11 @@ impl<'db> Type<'db> {
(_, Type::Unknown | Type::Any | Type::Todo) => false,
(Type::Never, _) => true,
(_, Type::Never) => false,
(Type::BooleanLiteral(_), Type::Instance(class))
if class.is_known(db, KnownClass::Bool) =>
{
true
}

@carljm (Contributor), Oct 17, 2024, on the `Type::BooleanLiteral` arm above:

Nice catch, I really thought we had this already 🤦

There are unit tests at the bottom of this module for `is_subtype_of`; ideally we'd add one for this case. (Though the union-builder tests are great, too, and do transitively test this.)

(Type::IntLiteral(_), Type::Instance(class)) if class.is_known(db, KnownClass::Int) => {
true
}
6 changes: 6 additions & 0 deletions crates/red_knot_python_semantic/src/types/builder.rs
@@ -374,6 +374,12 @@ mod tests {

let union = UnionType::from_elements(&db, [t0, t1, t2, t3]).expect_union();
assert_eq!(union.elements(&db), &[bool_instance_ty, t3]);

let result_ty = UnionType::from_elements(&db, [bool_instance_ty, t0]);
assert_eq!(result_ty, bool_instance_ty);

let result_ty = UnionType::from_elements(&db, [t0, bool_instance_ty]);
assert_eq!(result_ty, bool_instance_ty);
}

#[test]
23 changes: 21 additions & 2 deletions crates/red_knot_python_semantic/src/types/infer.rs
@@ -57,7 +57,7 @@ use crate::types::{
};
use crate::Db;

- use super::KnownClass;
+ use super::{KnownClass, UnionBuilder};

/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
@@ -2717,6 +2717,21 @@ impl<'db> TypeInferenceBuilder<'db> {
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal
// - `[ast::CompOp::IsNot]`: return `true` if unequal, `bool` if equal
match (left, right) {
(Type::Union(union), other) => {
let mut builder = UnionBuilder::new(self.db);
for element in union.elements(self.db) {
builder = builder.add(self.infer_binary_type_comparison(*element, op, other)?);
}
Some(builder.build())
}
(other, Type::Union(union)) => {
let mut builder = UnionBuilder::new(self.db);
for element in union.elements(self.db) {
builder = builder.add(self.infer_binary_type_comparison(other, op, *element)?);
}
Some(builder.build())

Contributor Author, commenting on lines +2728 to +2732:

Would be nice to avoid the duplication here, but I couldn't think of a good way to do this without introducing much more code.

I also looked for ways to replace the for loop with an iterator expression, but that would require something like `Iterator::try_collect`, which is nightly-only and would require an additional allocation.

Member:

I don't mind the repetition (or use of a for loop, I'm a huge fan of for loops) but maybe Itertools::fold_options is what you want

https://docs.rs/itertools/latest/itertools/trait.Itertools.html#method.fold_options

Member:

It will require refactoring this method (so shouldn't be done in this PR), but in the long term something like #13787 seems like a good idea. I also think we'll want to emit a separate diagnostic on each member of the union that would be invalid for this operation (or one diagnostic that mentions all possibly erroring Union members). So I'd also just leave this as it is for now

@sharkdp (Contributor Author), Oct 17, 2024:

> maybe Itertools::fold_options is what you want
>
> https://docs.rs/itertools/latest/itertools/trait.Itertools.html#method.fold_options

Hm, not quite. `fold_options` short-circuits on `None`s in the input iterator, but we want to short-circuit when the fold function itself returns `None`. So we would need something like `Itertools::fold_while`. That works, but it's completely ridiculous:

union
    .elements(self.db)
    .iter()
    .fold_while(Some(UnionBuilder::new(self.db)), |builder, element| {
        if let Some(ty) = self.infer_binary_type_comparison(other, op, *element) {
            FoldWhile::Continue(builder.map(|b| b.add(ty)))
        } else {
            FoldWhile::Done(None)
        }
    })
    .into_inner()
    .map(UnionBuilder::build)

I'll go with the for loop 😄
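
For comparison, the standard library's stable `Iterator::try_fold` also stops as soon as the closure returns `None`, so it might be another way to express the same short-circuiting. The following is only a sketch under stated assumptions: `ToyUnionBuilder`, `infer_element`, and `infer_union` are hypothetical stand-ins, with plain integers playing the role of types, and nothing here is red-knot's actual API.

```rust
/// Hypothetical stand-in for `UnionBuilder`; "types" are plain integers here.
#[derive(Debug, Default)]
struct ToyUnionBuilder {
    elements: Vec<i64>,
}

impl ToyUnionBuilder {
    /// Add an element, skipping duplicates, and hand the builder back.
    fn add(mut self, ty: i64) -> Self {
        if !self.elements.contains(&ty) {
            self.elements.push(ty);
        }
        self
    }

    fn build(self) -> Vec<i64> {
        self.elements
    }
}

/// Hypothetical fallible per-element inference:
/// negative inputs count as "unsupported".
fn infer_element(element: i64) -> Option<i64> {
    if element >= 0 {
        Some(element % 2)
    } else {
        None
    }
}

/// Fold all elements into a builder, giving up on the first `None`.
fn infer_union(elements: &[i64]) -> Option<Vec<i64>> {
    elements
        .iter()
        .try_fold(ToyUnionBuilder::default(), |builder, &element| {
            // `?` leaves the closure with `None`, which makes `try_fold` stop early.
            Some(builder.add(infer_element(element)?))
        })
        .map(ToyUnionBuilder::build)
}
```

Whether this reads better than the explicit for loop is a matter of taste; the loop keeps the control flow more obvious.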

}

(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)),
@@ -2908,6 +2923,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
}

// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => {
@@ -2917,7 +2933,10 @@ impl<'db> TypeInferenceBuilder<'db> {
_ => Some(Type::Todo),
},
// TODO: handle more types
- _ => Some(Type::Todo),
+ _ => match op {
+ ast::CmpOp::Is | ast::CmpOp::IsNot => Some(KnownClass::Bool.to_instance(self.db)),
+ _ => Some(Type::Todo),
+ },
}
}
