Skip to content

Commit

Permalink
PY-49935 Impl type inference and type checking for PEP 612
Browse files Browse the repository at this point in the history
Support type hints and type checking for typing.ParamSpec and typing.Concatenate

(cherry picked from commit 7854b3386ccdffc0091664e0923622cd8c093fc9)

IJ-MR-12970

GitOrigin-RevId: 4578cb463b6ab8fc244766bfaccb122d0e2b7479
  • Loading branch information
jetmano authored and intellij-monorepo-bot committed Aug 23, 2021
1 parent de7f4d9 commit 997b58d
Show file tree
Hide file tree
Showing 31 changed files with 1,389 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,7 @@ INSP.type.hints.illegal.literal.parameter='Literal' may be parameterized with li
INSP.type.hints.parameters.to.generic.must.all.be.type.variables=Parameters to 'Generic[...]' must all be type variables
INSP.type.hints.parameters.to.generic.must.all.be.unique=Parameters to 'Generic[...]' must all be unique
INSP.type.hints.illegal.callable.format='Callable' must be used as 'Callable[[arg, ...], result]'
INSP.type.hints.illegal.first.parameter='Callable' first parameter must be parameter expression
INSP.type.hints.parameters.to.generic.types.must.be.types=Parameters to generic types must be types
INSP.type.hints.type.comment.cannot.be.matched.with.unpacked.variables=Type comment cannot be matched with unpacked variables
INSP.type.hints.type.signature.has.too.few.arguments=Type signature has too few arguments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ public class PyTypingTypeProvider extends PyTypeProviderBase {
private static final String TUPLE = "typing.Tuple";
public static final String CLASS_VAR = "typing.ClassVar";
public static final String TYPE_VAR = "typing.TypeVar";
public static final String PARAM_SPEC = "typing.ParamSpec";
private static final String CHAIN_MAP = "typing.ChainMap";
public static final String UNION = "typing.Union";
public static final String CONCATENATE = "typing.Concatenate";
public static final String OPTIONAL = "typing.Optional";
public static final String NO_RETURN = "typing.NoReturn";
public static final String FINAL = "typing.Final";
Expand Down Expand Up @@ -152,6 +154,8 @@ public class PyTypingTypeProvider extends PyTypeProviderBase {
.add(ANY)
.add(TYPE_VAR)
.add(GENERIC)
.add(PARAM_SPEC)
.add(CONCATENATE)
.add(TUPLE)
.add(CALLABLE)
.add(TYPE)
Expand Down Expand Up @@ -598,7 +602,8 @@ public static TextRange getTypeCommentValueRange(@NotNull String text) {
@Nullable
@Override
public PyType getGenericType(@NotNull PyClass cls, @NotNull TypeEvalContext context) {
final List<PyType> genericTypes = collectGenericTypes(cls, new Context(context));
final var typingContext = new Context(context);
final var genericTypes = collectGenericTypes(cls, typingContext);
if (genericTypes.isEmpty()) {
return null;
}
Expand Down Expand Up @@ -697,7 +702,7 @@ private static List<PyType> collectGenericTypes(@NotNull PyClass cls, @NotNull C
if (!isGeneric(cls, context.getTypeContext())) {
return Collections.emptyList();
}
final TypeEvalContext typeEvalContext = context.getTypeContext();
final var typeEvalContext = context.getTypeContext();
return StreamEx.of(PyClassElementType.getSubscriptedSuperClassesStubLike(cls))
.map(PySubscriptionExpression::getIndexExpression)
.flatMap(e -> {
Expand All @@ -706,7 +711,11 @@ private static List<PyType> collectGenericTypes(@NotNull PyClass cls, @NotNull C
})
.nonNull()
.flatMap(e -> tryResolving(e, typeEvalContext).stream())
.map(e -> getGenericTypeFromTypeVar(e, context))
.map(e -> {
final var typeVar = getGenericTypeFromTypeVar(e, context);
if (typeVar != null) return typeVar;
return getParamSpecType(e, context);
})
.select(PyType.class)
.distinct()
.toList();
Expand Down Expand Up @@ -803,6 +812,10 @@ private static Ref<PyType> getTypeForResolvedElement(@Nullable PyTargetExpressio
if (unionType != null) {
return Ref.create(unionType);
}
final PyType concatenateType = getConcatenateType(resolved, context);
if (concatenateType != null) {
return Ref.create(concatenateType);
}
final Ref<PyType> optionalType = getOptionalType(resolved, context);
if (optionalType != null) {
return optionalType;
Expand All @@ -813,7 +826,7 @@ private static Ref<PyType> getTypeForResolvedElement(@Nullable PyTargetExpressio
}
final Ref<PyType> classObjType = getClassObjectType(resolved, context);
if (classObjType != null) {
return Ref.create(addTypeVarAlias(classObjType.get(), alias));
return Ref.create(addGenericAlias(classObjType.get(), alias));
}
final Ref<PyType> finalType = getFinalType(resolved, context);
if (finalType != null) {
Expand Down Expand Up @@ -841,7 +854,11 @@ private static Ref<PyType> getTypeForResolvedElement(@Nullable PyTargetExpressio
}
final PyType genericType = getGenericTypeFromTypeVar(resolved, context);
if (genericType != null) {
return Ref.create(addTypeVarAlias(genericType, alias));
return Ref.create(addGenericAlias(genericType, alias));
}
final PyType paramSpecType = getParamSpecType(resolved, context);
if (paramSpecType != null) {
return Ref.create(addGenericAlias(paramSpecType, alias));
}
final PyType stringBasedType = getStringLiteralType(resolved, context);
if (stringBasedType != null) {
Expand Down Expand Up @@ -918,11 +935,15 @@ private static Ref<PyType> getAliasedType(@NotNull PsiElement resolved, @NotNull
}

@Nullable
private static PyType addTypeVarAlias(@Nullable PyType type, @Nullable PyTargetExpression alias) {
private static PyType addGenericAlias(@Nullable PyType type, @Nullable PyTargetExpression alias) {
final PyGenericType typeVar = as(type, PyGenericType.class);
if (typeVar != null) {
return new PyGenericType(typeVar.getName(), typeVar.getBound(), typeVar.isDefinition(), alias);
}
final PyParamSpecType paramSpec = as(type, PyParamSpecType.class);
if (paramSpec != null) {
return paramSpec.withTargetExpression(alias);
}
return type;
}

Expand Down Expand Up @@ -1275,6 +1296,21 @@ private static PyType getCallableType(@NotNull PsiElement resolved, @NotNull Con
if (isEllipsis(parametersExpr)) {
return new PyCallableTypeImpl(null, Ref.deref(getType(returnTypeExpr, context)));
}
if (isParamSpec(parametersExpr, context.myContext)) {
final var name = parametersExpr.getName();
if (name != null) {
final var parameter = PyCallableParameterImpl.nonPsi(parametersExpr.getName(), new PyParamSpecType(name));
return new PyCallableTypeImpl(Collections.singletonList(parameter), Ref.deref(getType(returnTypeExpr, context)));
}
}
if (parametersExpr instanceof PySubscriptionExpression && isConcatenate(parametersExpr, context.myContext)) {
final var concatenateParameters = getConcatenateParametersTypes((PySubscriptionExpression)parametersExpr, context.myContext);
if (concatenateParameters != null) {
final var concatenate = new PyConcatenateType(concatenateParameters.first, concatenateParameters.second);
final var parameter = PyCallableParameterImpl.nonPsi(parametersExpr.getName(), concatenate);
return new PyCallableTypeImpl(Collections.singletonList(parameter), Ref.deref(getType(returnTypeExpr, context)));
}
}
}
}
}
Expand All @@ -1291,6 +1327,22 @@ private static boolean isEllipsis(@NotNull PyExpression parametersExpr) {
return parametersExpr instanceof PyNoneLiteralExpression && ((PyNoneLiteralExpression)parametersExpr).isEllipsis();
}

public static boolean isParamSpec(@NotNull PyExpression parametersExpr, @NotNull TypeEvalContext context) {
final var resolveContext = PyResolveContext.defaultContext(context);
return PyUtil.multiResolveTopPriority(parametersExpr, resolveContext).stream().anyMatch(it -> {
if (!(it instanceof PyTypedElement)) return false;
final var type = context.getType((PyTypedElement)it);
if (!(type instanceof PyClassLikeType)) return false;
return PARAM_SPEC.equals(((PyClassLikeType)type).getClassQName());
});
}

public static boolean isConcatenate(@NotNull PyExpression parametersExpr, @NotNull TypeEvalContext context) {
if (!(parametersExpr instanceof PySubscriptionExpression)) return false;
final var type = Ref.deref(getType(parametersExpr, context));
return type instanceof PyConcatenateType;
}

@Nullable
private static PyType getUnionType(@NotNull PsiElement element, @NotNull Context context) {
if (element instanceof PySubscriptionExpression) {
Expand All @@ -1304,6 +1356,34 @@ private static PyType getUnionType(@NotNull PsiElement element, @NotNull Context
return null;
}

@Nullable
private static PyType getConcatenateType(@NotNull PsiElement element, @NotNull Context context) {
if (!(element instanceof PySubscriptionExpression)) return null;

final var subscriptionExpr = (PySubscriptionExpression)element;
final var operand = subscriptionExpr.getOperand();
final var operandNames = resolveToQualifiedNames(operand, context.myContext);
if (!operandNames.contains(CONCATENATE)) return null;

final var parameters = getConcatenateParametersTypes(subscriptionExpr, context.myContext);
if (parameters == null) return null;

return new PyConcatenateType(parameters.first, parameters.second);
}

@Nullable
private static Pair<List<PyType>, PyParamSpecType> getConcatenateParametersTypes(@NotNull PySubscriptionExpression subscriptionExpression,
@NotNull TypeEvalContext context) {
final var tuple = subscriptionExpression.getIndexExpression();
if (!(tuple instanceof PyTupleExpression)) return null;
final var result = ContainerUtil.mapNotNull(((PyTupleExpression)tuple).getElements(),
it -> Ref.deref(getType(it, context)));
if (result.size() < 2) return null;
PyType lastParameter = result.get(result.size() - 1);
if (!(lastParameter instanceof PyParamSpecType)) return null;
return new Pair<>(result.subList(0, result.size() - 1), (PyParamSpecType)lastParameter);
}

@Nullable
private static PyGenericType getGenericTypeFromTypeVar(@NotNull PsiElement element, @NotNull Context context) {
if (element instanceof PyCallExpression) {
Expand All @@ -1326,6 +1406,27 @@ private static PyGenericType getGenericTypeFromTypeVar(@NotNull PsiElement eleme
return null;
}

@Nullable
private static PyParamSpecType getParamSpecType(@NotNull PsiElement element, @NotNull Context context) {
if (!(element instanceof PyCallExpression)) return null;

final var assignedCall = (PyCallExpression)element;
final var callee = assignedCall.getCallee();
if (callee == null) return null;

final var calleeQNames = resolveToQualifiedNames(callee, context.getTypeContext());
if (!calleeQNames.contains(PARAM_SPEC)) return null;

final var arguments = assignedCall.getArguments();
if (arguments.length == 0) return null;

final var firstArgument = arguments[0];
if (!(firstArgument instanceof PyStringLiteralExpression)) return null;

final var name = ((PyStringLiteralExpression)firstArgument).getStringValue();
return new PyParamSpecType(name);
}

@Nullable
private static PyType getGenericTypeBound(PyExpression @NotNull [] typeVarArguments, @NotNull Context context) {
final List<PyType> types = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.Ref;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
Expand Down Expand Up @@ -383,7 +382,13 @@ private TypeModel buildCallable(@NotNull PyCallableType type) {
if (parameters != null) {
parameterModels = new ArrayList<>();
for (PyCallableParameter parameter : parameters) {
parameterModels.add(new ParamType(parameter.getName(), build(parameter.getType(myContext), true)));
final var paramType = parameter.getType(myContext);
if (paramType instanceof PyParamSpecType || paramType instanceof PyConcatenateType) {
parameterModels.add(new ParamType(null, build(parameter.getType(myContext), true)));
}
else {
parameterModels.add(new ParamType(parameter.getName(), build(parameter.getType(myContext), true)));
}
}
}
final PyType ret = type.getReturnType(myContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,45 @@ private AnalyzeCalleeResults analyzeCallee(@NotNull PyCallSiteExpression callSit

final List<AnalyzeArgumentResult> result = new ArrayList<>();

final PyExpression receiver = callSite.getReceiver(callableType.getCallable());
final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyReceiver(receiver, myTypeEvalContext);
final Map<PyExpression, PyCallableParameter> mappedParameters = mapping.getMappedParameters();
final var receiver = callSite.getReceiver(callableType.getCallable());
final var substitutions = PyTypeChecker.unifyReceiver(receiver, myTypeEvalContext);
final var mappedParameters = mapping.getMappedParameters();
final var regularMappedParameters = getRegularMappedParameters(mappedParameters);

for (Map.Entry<PyExpression, PyCallableParameter> entry : getRegularMappedParameters(mappedParameters).entrySet()) {
for (Map.Entry<PyExpression, PyCallableParameter> entry : regularMappedParameters.entrySet()) {
final PyExpression argument = entry.getKey();
final PyCallableParameter parameter = entry.getValue();
final PyType expected = parameter.getArgumentType(myTypeEvalContext);
final PyType promotedToLiteral = PyLiteralType.Companion.promoteToLiteral(argument, expected, myTypeEvalContext, substitutions);
final var actual = promotedToLiteral != null ? promotedToLiteral : myTypeEvalContext.getType(argument);
final boolean matched = matchParameterAndArgument(expected, actual, substitutions);
result.add(new AnalyzeArgumentResult(argument, expected, substituteGenerics(expected, substitutions), actual, matched));

if (expected instanceof PyParamSpecType) {
final var allArguments = callSite.getArguments(callableType.getCallable());
analyzeParamSpec((PyParamSpecType)expected, allArguments, substitutions, result);
break;
}
else if (expected instanceof PyConcatenateType) {
final var allArguments = callSite.getArguments(callableType.getCallable());
if (allArguments.isEmpty()) break;

final var concatenateType = (PyConcatenateType)expected;
final var firstExpectedTypes = concatenateType.getFirstTypes();
final var argumentRightBound = Math.min(firstExpectedTypes.size(), allArguments.size());
final var firstArguments = allArguments.subList(0, argumentRightBound);
matchArgumentsAndTypes(firstArguments, firstExpectedTypes, substitutions, result);

if (argumentRightBound < allArguments.size()) {
final var paramSpec = concatenateType.getParamSpec();
final var restArguments = allArguments.subList(argumentRightBound, allArguments.size());
analyzeParamSpec(paramSpec, restArguments, substitutions, result);
}

break;
}
else {
final boolean matched = matchParameterAndArgument(expected, actual, substitutions);
result.add(new AnalyzeArgumentResult(argument, expected, substituteGenerics(expected, substitutions), actual, matched));
}
}
final PyCallableParameter positionalContainer = getMappedPositionalContainer(mappedParameters);
if (positionalContainer != null) {
Expand All @@ -243,10 +270,34 @@ private AnalyzeCalleeResults analyzeCallee(@NotNull PyCallSiteExpression callSit
return new AnalyzeCalleeResults(callableType, callableType.getCallable(), result);
}

private void analyzeParamSpec(@NotNull PyParamSpecType paramSpec, @NotNull List<PyExpression> arguments,
@NotNull PyTypeChecker.GenericSubstitutions substitutions,
@NotNull List<AnalyzeArgumentResult> result) {
final var substParamSpec = substitutions.getParamSpecs().get(paramSpec);
paramSpec = substParamSpec == null ? paramSpec : substParamSpec;
final var parameters = paramSpec.getParameters();
if (parameters == null) return;
final var parametersTypes = ContainerUtil.map(parameters, it -> it.getType(myTypeEvalContext));
matchArgumentsAndTypes(arguments, parametersTypes, substitutions, result);
}

private void matchArgumentsAndTypes(@NotNull List<PyExpression> arguments, @NotNull List<PyType> types,
@NotNull PyTypeChecker.GenericSubstitutions substitutions,
@NotNull List<AnalyzeArgumentResult> result) {
final var size = Math.min(arguments.size(), types.size());
for (int i = 0; i < size; ++i) {
final var expected = types.get(i);
final var argument = arguments.get(i);
final var actual = myTypeEvalContext.getType(argument);
final var matched = matchParameterAndArgument(expected, actual, substitutions);
result.add(new AnalyzeArgumentResult(argument, expected, substituteGenerics(expected, substitutions), actual, matched));
}
}

@NotNull
private List<AnalyzeArgumentResult> analyzeContainerMapping(@NotNull PyCallableParameter container,
@NotNull List<PyExpression> arguments,
@NotNull Map<PyGenericType, PyType> substitutions) {
@NotNull PyTypeChecker.GenericSubstitutions substitutions) {
final PyType expected = container.getArgumentType(myTypeEvalContext);
final PyType expectedWithSubstitutions = substituteGenerics(expected, substitutions);
// For an expected type with generics we have to match all the actual types against it in order to do proper generic unification
Expand All @@ -270,13 +321,13 @@ private List<AnalyzeArgumentResult> analyzeContainerMapping(@NotNull PyCallableP

private boolean matchParameterAndArgument(@Nullable PyType parameterType,
@Nullable PyType argumentType,
@NotNull Map<PyGenericType, PyType> substitutions) {
@NotNull PyTypeChecker.GenericSubstitutions substitutions) {
return PyTypeChecker.match(parameterType, argumentType, myTypeEvalContext, substitutions) &&
!PyProtocolsKt.matchingProtocolDefinitions(parameterType, argumentType, myTypeEvalContext);
}

@Nullable
private PyType substituteGenerics(@Nullable PyType expectedArgumentType, @NotNull Map<PyGenericType, PyType> substitutions) {
private PyType substituteGenerics(@Nullable PyType expectedArgumentType, @NotNull PyTypeChecker.GenericSubstitutions substitutions) {
return PyTypeChecker.hasGenerics(expectedArgumentType, myTypeEvalContext)
? PyTypeChecker.substitute(expectedArgumentType, substitutions, myTypeEvalContext)
: null;
Expand Down
Loading

0 comments on commit 997b58d

Please sign in to comment.