Skip to content

Commit

Permalink
[red-knot] condense int literals (#11784)
Browse files Browse the repository at this point in the history
Display `(Literal[1] | Literal[2])` as `Literal[1, 2]`, and `(Literal[1]
| Literal[2] | OtherType)` as `(Literal[1, 2] | OtherType)`.

Fixes #11782

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
  • Loading branch information
carljm and AlexWaygood authored Jun 6, 2024
1 parent b2fc0df commit cd101c8
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 20 deletions.
42 changes: 32 additions & 10 deletions crates/red_knot/src/semantic/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,16 +746,42 @@ pub(crate) struct UnionType {

impl UnionType {
fn display(&self, f: &mut std::fmt::Formatter<'_>, store: &TypeStore) -> std::fmt::Result {
f.write_str("(")?;
let (int_literals, other_types): (Vec<Type>, Vec<Type>) = self
.elements
.iter()
.copied()
.partition(|ty| matches!(ty, Type::IntLiteral(_)));
let mut first = true;
for ty in &self.elements {
if !int_literals.is_empty() {
f.write_str("Literal[")?;
let mut nums: Vec<i64> = int_literals
.into_iter()
.filter_map(|ty| {
if let Type::IntLiteral(n) = ty {
Some(n)
} else {
None
}
})
.collect();
nums.sort_unstable();
for num in nums {
if !first {
f.write_str(", ")?;
}
write!(f, "{num}")?;
first = false;
}
f.write_str("]")?;
}
for ty in other_types {
if !first {
f.write_str(" | ")?;
};
first = false;
write!(f, "{}", ty.display(store))?;
}
f.write_str(")")
Ok(())
}
}

Expand All @@ -775,7 +801,6 @@ pub(crate) struct IntersectionType {

impl IntersectionType {
fn display(&self, f: &mut std::fmt::Formatter<'_>, store: &TypeStore) -> std::fmt::Result {
f.write_str("(")?;
let mut first = true;
for (neg, ty) in self
.positive
Expand All @@ -792,7 +817,7 @@ impl IntersectionType {
};
write!(f, "{}", ty.display(store))?;
}
f.write_str(")")
Ok(())
}
}

Expand Down Expand Up @@ -857,7 +882,7 @@ mod tests {
elems.into_iter().collect::<FxIndexSet<_>>()
);
let union = Type::Union(id);
assert_eq!(format!("{}", union.display(&store)), "(C1 | C2)");
assert_eq!(format!("{}", union.display(&store)), "C1 | C2");
}

#[test]
Expand All @@ -880,9 +905,6 @@ mod tests {
neg.into_iter().collect::<FxIndexSet<_>>()
);
let intersection = Type::Intersection(id);
assert_eq!(
format!("{}", intersection.display(&store)),
"(C1 & C2 & ~C3)"
);
assert_eq!(format!("{}", intersection.display(&store)), "C1 & C2 & ~C3");
}
}
38 changes: 28 additions & 10 deletions crates/red_knot/src/semantic/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ mod tests {
",
)?;

assert_public_type(&case, "a", "x", "(Literal[1] | Literal[2])")
assert_public_type(&case, "a", "x", "Literal[1, 2]")
}

#[test]
Expand All @@ -450,7 +450,7 @@ mod tests {
",
)?;

assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3])")
assert_public_type(&case, "a", "x", "Literal[2, 3]")
}

#[test]
Expand All @@ -467,7 +467,7 @@ mod tests {
",
)?;

assert_public_type(&case, "a", "x", "(Unbound | Literal[1])")
assert_public_type(&case, "a", "x", "Literal[1] | Unbound")
}

#[test]
Expand All @@ -492,7 +492,7 @@ mod tests {
",
)?;

assert_public_type(&case, "a", "x", "(Literal[3] | Literal[4] | Literal[5])")?;
assert_public_type(&case, "a", "x", "Literal[3, 4, 5]")?;
assert_public_type(&case, "a", "r", "Literal[2]")?;
assert_public_type(&case, "a", "s", "Literal[5]")
}
Expand All @@ -515,7 +515,7 @@ mod tests {
",
)?;

assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3] | Literal[4])")
assert_public_type(&case, "a", "x", "Literal[2, 3, 4]")
}

#[test]
Expand Down Expand Up @@ -569,7 +569,7 @@ mod tests {
",
)?;

assert_public_type(&case, "a", "x", "(Literal[1] | Literal[2])")
assert_public_type(&case, "a", "x", "Literal[1, 2]")
}

#[test]
Expand All @@ -587,9 +587,9 @@ mod tests {
",
)?;

assert_public_type(&case, "a", "x", "(Literal[1] | Literal[2])")?;
assert_public_type(&case, "a", "a", "(Literal[1] | Literal[0])")?;
assert_public_type(&case, "a", "b", "(Literal[0] | Literal[2])")
assert_public_type(&case, "a", "x", "Literal[1, 2]")?;
assert_public_type(&case, "a", "a", "Literal[0, 1]")?;
assert_public_type(&case, "a", "b", "Literal[0, 2]")
}

#[test]
Expand All @@ -606,7 +606,25 @@ mod tests {
",
)?;

assert_public_type(&case, "a", "a", "(Literal[1] | Literal[2])")
assert_public_type(&case, "a", "a", "Literal[1, 2]")
}

#[test]
fn ifexpr_nested() -> anyhow::Result<()> {
let case = create_test()?;

write_to_path(
&case,
"a.py",
"
class C1: pass
class C2: pass
class C3: pass
x = C1 if flag else C2 if flag2 else C3
",
)?;

assert_public_type(&case, "a", "x", "Literal[C1] | Literal[C2] | Literal[C3]")
}

#[test]
Expand Down

0 comments on commit cd101c8

Please sign in to comment.