Skip to content

Commit

Permalink
Extended type narrowing for the X is L and X is not L type guard …
Browse files Browse the repository at this point in the history
…pattern. Previously, narrowing was performed only when `L` was an enum or bool literal. Narrowing is also now applied for other literal types but only in the positive (`if`) direction. This addresses #9758. (#9759)
  • Loading branch information
erictraut authored Jan 25, 2025
1 parent 48ea2a0 commit c9d4475
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/type-concepts-advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `x == ...` and `x != ...` (where `...` is an ellipsis token)
* `type(x) is T` and `type(x) is not T`
* `type(x) == T` and `type(x) != T`
* `x is E` and `x is not E` (where E is a literal enum or bool)
* `x is E` and `x is not L` (where L is an expression that evaluates to a literal type)
* `x is C` and `x is not C` (where C is a class)
* `x == L` and `x != L` (where L is an expression that evaluates to a literal type)
* `x.y is None` and `x.y is not None` (where x is a type that is distinguished by a field with a None)
Expand Down
19 changes: 10 additions & 9 deletions packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,8 @@ export function getTypeNarrowingCallback(
const rightTypeResult = evaluator.getTypeOfExpression(testExpression.d.rightExpr);
const rightType = rightTypeResult.type;

// Look for "X is Y" or "X is not Y" where Y is a an enum or bool literal.
if (
isClassInstance(rightType) &&
(ClassType.isEnumClass(rightType) || ClassType.isBuiltIn(rightType, 'bool')) &&
rightType.priv.literalValue !== undefined
) {
// Look for "X is Y" or "X is not Y" where Y is a literal.
if (isClassInstance(rightType) && rightType.priv.literalValue !== undefined) {
return (type: Type) => {
return {
type: narrowTypeForLiteralComparison(
Expand Down Expand Up @@ -2495,10 +2491,15 @@ function narrowTypeForLiteralComparison(
} else if (isClassInstance(subtype) && ClassType.isSameGenericClass(literalType, subtype)) {
if (subtype.priv.literalValue !== undefined) {
const literalValueMatches = ClassType.isLiteralValueSame(subtype, literalType);
if ((literalValueMatches && !isPositiveTest) || (!literalValueMatches && isPositiveTest)) {
return undefined;
if (isPositiveTest) {
return literalValueMatches ? subtype : undefined;
} else {
const isEnumOrBool = ClassType.isEnumClass(literalType) || ClassType.isBuiltIn(literalType, 'bool');

// For negative tests, we can eliminate the literal value if it doesn't match,
// but only for equality tests or for 'is' tests that involve enums or bools.
return literalValueMatches && (isEnumOrBool || !isIsOperator) ? undefined : subtype;
}
return subtype;
} else if (isPositiveTest) {
return literalType;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# types that have enumerated literals (bool and enums).

from enum import Enum
from typing import Any, Literal, Union
from typing import Any, Literal, Union, reveal_type


class SomeEnum(Enum):
Expand All @@ -27,12 +27,10 @@ def func2(a: SomeEnum) -> Literal[3]:
return a.value


def must_be_true(a: Literal[True]):
...
def must_be_true(a: Literal[True]): ...


def must_be_false(a: Literal[False]):
...
def must_be_false(a: Literal[False]): ...


def func3(a: bool):
Expand Down Expand Up @@ -75,3 +73,10 @@ def func7(x: Any):
reveal_type(x, expected_text="Literal[MyEnum.ZERO]")
else:
reveal_type(x, expected_text="Any")


def func8(x: Literal[0, 1] | None):
if x is 1:
reveal_type(x, expected_text="Literal[1]")
else:
reveal_type(x, expected_text="Literal[0, 1] | None")

0 comments on commit c9d4475

Please sign in to comment.