Skip to content

Commit

Permalink
Added support for argument type expansion of bool, enums, and tuples …
Browse files Browse the repository at this point in the history
…of fixed length when evaluating overloads. This behavior is mandated by the new draft typing spec update. This addresses #9706. (#9763)
  • Loading branch information
erictraut authored Jan 26, 2025
1 parent 62fa084 commit 4cae89f
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 60 deletions.
42 changes: 42 additions & 0 deletions packages/pyright-internal/src/analyzer/tuples.ts
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,48 @@ export function getSlicedTupleType(
return ClassType.cloneAsInstance(specializeTupleClass(tupleType, slicedTypeArgs));
}

// If the type is a fixed-length tuple instance and one or more of the element types
// is a union, this function expands the tuple into a union of tuples where each
// element is a union of the corresponding element types. This is done for all
// element combinations until the total number of tuples exceeds maxExpansion,
// at which point the function returns the original tuple type.
export function expandTuple(tupleType: ClassType, maxExpansion: number): Type[] | undefined {
if (
!isTupleClass(tupleType) ||
!tupleType.priv.tupleTypeArgs ||
tupleType.priv.tupleTypeArgs.some((typeArg) => typeArg.isUnbounded || isTypeVarTuple(typeArg.type))
) {
return undefined;
}

let typesToCombine: ClassType[] = [tupleType];
let index = 0;

while (index < tupleType.priv.tupleTypeArgs.length) {
const elemType = tupleType.priv.tupleTypeArgs[index].type;
if (isUnion(elemType)) {
const newTypesToCombine: ClassType[] = [];

for (const typeToCombine of typesToCombine) {
doForEachSubtype(elemType, (subtype) => {
const newTypeArgs = [...typeToCombine.priv.tupleTypeArgs!];
newTypeArgs[index] = { type: subtype, isUnbounded: false };
newTypesToCombine.push(ClassType.cloneAsInstance(specializeTupleClass(typeToCombine, newTypeArgs)));
});
}
typesToCombine = newTypesToCombine;
}

if (typesToCombine.length > maxExpansion) {
return undefined;
}

index++;
}

return typesToCombine.length === 1 ? undefined : typesToCombine;
}

function getTupleSliceParam(
evaluator: TypeEvaluator,
expression: ExpressionNode | undefined,
Expand Down
67 changes: 48 additions & 19 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ import { evaluateStaticBoolExpression } from './staticExpressions';
import { indeterminateSymbolId, Symbol, SymbolFlags, SynthesizedTypeInfo } from './symbol';
import { isConstantName, isPrivateName, isPrivateOrProtectedName } from './symbolNameUtils';
import { getLastTypedDeclarationForSymbol, isEffectivelyClassVar } from './symbolUtils';
import { assignTupleTypeArgs, getSlicedTupleType, getTypeOfTuple, makeTupleObject } from './tuples';
import { assignTupleTypeArgs, expandTuple, getSlicedTupleType, getTypeOfTuple, makeTupleObject } from './tuples';
import { SpeculativeModeOptions, SpeculativeTypeTracker } from './typeCacheUtils';
import {
assignToTypedDict,
Expand Down Expand Up @@ -217,6 +217,7 @@ import {
ValidateArgTypeParams,
ValidateTypeArgsOptions,
} from './typeEvaluatorTypes';
import { enumerateLiteralsForType } from './typeGuards';
import * as TypePrinter from './typePrinter';
import {
AnyType,
Expand Down Expand Up @@ -548,9 +549,9 @@ const maxDeclarationsToUseForInference = 64;
// of a variable that has no type declaration.
const maxEffectiveTypeEvaluationAttempts = 16;

// Maximum number of combinatoric union type expansions allowed
// Maximum number of combinatoric argument type expansions allowed
// when resolving an overload.
const maxOverloadUnionExpansionCount = 64;
const maxOverloadArgTypeExpansionCount = 64;

// Maximum number of recursive function return type inference attempts
// that can be concurrently pending before we give up.
Expand Down Expand Up @@ -9593,10 +9594,10 @@ export function createTypeEvaluator(
});
}

expandedArgTypes = expandArgUnionTypes(contextFreeArgTypes!, expandedArgTypes);
expandedArgTypes = expandArgTypes(contextFreeArgTypes!, expandedArgTypes);

// Check for combinatoric explosion and break out of loop.
if (!expandedArgTypes || expandedArgTypes.length > maxOverloadUnionExpansionCount) {
if (!expandedArgTypes || expandedArgTypes.length > maxOverloadArgTypeExpansionCount) {
break;
}
}
Expand All @@ -9620,12 +9621,13 @@ export function createTypeEvaluator(
}

// Replaces each item in the expandedArgTypes with n items where n is
// the number of subtypes in a union. The contextFreeArgTypes parameter
// represents the types of the arguments evaluated with no bidirectional
// type inference (i.e. without the help of the corresponding parameter's
// expected type). If the function returns undefined, that indicates that
// all unions have been expanded, and no more expansion is possible.
function expandArgUnionTypes(
// the number of subtypes in a union or other expandable type.
// The contextFreeArgTypes parameter represents the types of the arguments
// evaluated with no bidirectional type inference (i.e. without the help of
// the corresponding parameter's expected type). If the function returns
// undefined, that indicates that all types have been expanded, and no
// more expansion is possible.
function expandArgTypes(
contextFreeArgTypes: Type[],
expandedArgTypes: (Type | undefined)[][]
): (Type | undefined)[][] | undefined {
Expand All @@ -9642,30 +9644,28 @@ export function createTypeEvaluator(
return undefined;
}

let unionToExpand: Type | undefined;
let expandedTypes: Type[] | undefined;
while (indexToExpand < contextFreeArgTypes.length) {
// Is this a union type? If so, we can expand it.
const argType = contextFreeArgTypes[indexToExpand];
if (isUnion(argType)) {
unionToExpand = makeTopLevelTypeVarsConcrete(argType);
break;
} else if (isTypeVar(argType) && TypeVarType.hasConstraints(argType)) {
unionToExpand = makeTopLevelTypeVarsConcrete(argType);

expandedTypes = expandArgType(argType);
if (expandedTypes) {
break;
}
indexToExpand++;
}

// We have nothing left to expand.
if (!unionToExpand) {
if (!expandedTypes) {
return undefined;
}

// Expand entry indexToExpand.
const newExpandedArgTypes: (Type | undefined)[][] = [];

expandedArgTypes.forEach((preExpandedTypes) => {
doForEachSubtype(unionToExpand!, (subtype) => {
expandedTypes.forEach((subtype) => {
const expandedTypes = [...preExpandedTypes];
expandedTypes[indexToExpand] = subtype;
newExpandedArgTypes.push(expandedTypes);
Expand All @@ -9675,6 +9675,35 @@ export function createTypeEvaluator(
return newExpandedArgTypes;
}

function expandArgType(type: Type): Type[] | undefined {
const expandedTypes: Type[] = [];

// Expand any top-level type variables with constraints.
type = makeTopLevelTypeVarsConcrete(type);

doForEachSubtype(type, (subtype) => {
if (isClassInstance(subtype)) {
// Expand any bool or Enum literals.
const expandedLiteralTypes = enumerateLiteralsForType(evaluatorInterface, subtype);
if (expandedLiteralTypes) {
appendArray(expandedTypes, expandedLiteralTypes);
return;
}

// Expand any fixed-size tuples.
const expandedTuples = expandTuple(subtype, maxOverloadArgTypeExpansionCount);
if (expandedTuples) {
appendArray(expandedTypes, expandedTuples);
return;
}
}

expandedTypes.push(subtype);
});

return expandedTypes.length > 1 ? expandedTypes : undefined;
}

// Validates that the arguments can be assigned to the call's parameter
// list, specializes the call based on arg types, and returns the
// specialized type of the return value. If it detects an error along
Expand Down
119 changes: 78 additions & 41 deletions packages/pyright-internal/src/tests/samples/overloadCall4.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,31 @@
# This sample tests the expansion of union types during overload matching.
# This sample tests the expansion of argument types during overload matching.


from enum import Enum
from typing import AnyStr, Literal, TypeVar, overload


class A:
...
class A: ...


class B:
...
class B: ...


class C:
...
class C: ...


_T1 = TypeVar("_T1", bound=B)


@overload
def overloaded1(x: A) -> str:
...
def overloaded1(x: A) -> str: ...


@overload
def overloaded1(x: _T1) -> _T1:
...
def overloaded1(x: _T1) -> _T1: ...


def overloaded1(x: A | B) -> str | B:
...
def overloaded1(x: A | B) -> str | B: ...


def func1(a: A | B, b: A | B | C):
Expand All @@ -46,32 +41,26 @@ def func1(a: A | B, b: A | B | C):


@overload
def overloaded2(a: LargeUnion, b: Literal[2]) -> str:
...
def overloaded2(a: LargeUnion, b: Literal[2]) -> str: ...


@overload
def overloaded2(a: LargeUnion, b: Literal[3]) -> str:
...
def overloaded2(a: LargeUnion, b: Literal[3]) -> str: ...


@overload
def overloaded2(a: LargeUnion, b: Literal[4]) -> float:
...
def overloaded2(a: LargeUnion, b: Literal[4]) -> float: ...


@overload
def overloaded2(a: LargeUnion, b: Literal[9]) -> float:
...
def overloaded2(a: LargeUnion, b: Literal[9]) -> float: ...


@overload
def overloaded2(a: LargeUnion, b: Literal[10]) -> float:
...
def overloaded2(a: LargeUnion, b: Literal[10]) -> float: ...


def overloaded2(a: LargeUnion, b: LargeUnion | Literal[9, 10]) -> str | float:
...
def overloaded2(a: LargeUnion, b: LargeUnion | Literal[9, 10]) -> str | float: ...


def func2(a: LargeUnion, b: Literal[2, 3, 4], c: Literal[2, 3, 4, 9, 10]):
Expand All @@ -91,17 +80,14 @@ def func2(a: LargeUnion, b: Literal[2, 3, 4], c: Literal[2, 3, 4, 9, 10]):


@overload
def overloaded3(x: str) -> str:
...
def overloaded3(x: str) -> str: ...


@overload
def overloaded3(x: bytes) -> bytes:
...
def overloaded3(x: bytes) -> bytes: ...


def overloaded3(x: str | bytes) -> str | bytes:
...
def overloaded3(x: str | bytes) -> str | bytes: ...


def func3(y: _T2):
Expand All @@ -116,31 +102,26 @@ def func5(a: _T3) -> _T3:


@overload
def overloaded4(b: str) -> str:
...
def overloaded4(b: str) -> str: ...


@overload
def overloaded4(b: int) -> int:
...
def overloaded4(b: int) -> int: ...


def overloaded4(b: str | int) -> str | int:
...
def overloaded4(b: str | int) -> str | int: ...


def func6(x: str | int) -> None:
y: str | int = overloaded4(func5(x))


@overload
def overloaded5(pattern: AnyStr) -> AnyStr:
...
def overloaded5(pattern: AnyStr) -> AnyStr: ...


@overload
def overloaded5(pattern: int) -> int:
...
def overloaded5(pattern: int) -> int: ...


def overloaded5(pattern: AnyStr | int) -> AnyStr | int:
Expand All @@ -153,3 +134,59 @@ def func7(a: str | bytes) -> str | bytes:

def func8(a: AnyStr | str | bytes) -> str | bytes:
return overloaded5(a)


class E(Enum):
A = "A"
B = "B"


@overload
def func9(v: Literal[E.A]) -> int: ...
@overload
def func9(v: Literal[E.B]) -> str: ...
@overload
def func9(v: bool) -> list[str]: ...


def func9(v: E | bool) -> int | str | list[str]: ...


def test9(a1: E | bool):
reveal_type(func9(a1), expected_text="int | str | list[str]")


@overload
def func10(v: Literal[True]) -> int: ...
@overload
def func10(v: Literal[False]) -> str: ...


def func10(v: bool) -> int | str: ...


def test10(a1: bool):
reveal_type(func10(a1), expected_text="int | str")


@overload
def func11(v: tuple[int, int]) -> int: ...


@overload
def func11(v: tuple[str, int]) -> str: ...


@overload
def func11(v: tuple[int, str]) -> int: ...


@overload
def func11(v: tuple[str, str]) -> str: ...


def func11(v: tuple[int | str, int | str]) -> int | str: ...


def test11(a1: tuple[int | str, int | str]):
reveal_type(func11(a1), expected_text="int | str")

0 comments on commit 4cae89f

Please sign in to comment.