From c2132f0892770ee96e678670b5b92c908602742b Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 5 May 2021 22:57:46 +0900 Subject: [PATCH 1/6] Support GenericModel --- src/com/koxudaxi/pydantic/Pydantic.kt | 29 ++-- .../pydantic/PydanticCompletionContributor.kt | 44 +++++- .../koxudaxi/pydantic/PydanticTypeProvider.kt | 149 ++++++++++++++++-- 3 files changed, 190 insertions(+), 32 deletions(-) diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index bf93b2fc..5fb2de7f 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -25,6 +25,7 @@ import com.jetbrains.python.statistics.modules import java.util.regex.Pattern const val BASE_MODEL_Q_NAME = "pydantic.main.BaseModel" +const val GENERIC_MODEL_Q_NAME = "pydantic.generics.GenericModel" const val DATA_CLASS_Q_NAME = "pydantic.dataclasses.dataclass" const val DATA_CLASS_SHORT_Q_NAME = "pydantic.dataclass" const val VALIDATOR_Q_NAME = "pydantic.class_validators.validator" @@ -52,6 +53,8 @@ const val OPTIONAL_Q_NAME = "typing.Optional" const val UNION_Q_NAME = "typing.Union" const val ANNOTATED_Q_NAME = "typing.Annotated" const val CLASSVAR_Q_NAME = "typing.ClassVar" +const val GENERIC_Q_NAME = "typing.Generic" +const val TYPE_Q_NAME = "typing.Type" val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME) @@ -133,13 +136,21 @@ fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext? = null): Boolean { return (isSubClassOfPydanticBaseModel(pyClass, - context) || (includeDataclass && isPydanticDataclass(pyClass))) && !isPydanticBaseModel(pyClass) + context) || (includeDataclass && isPydanticDataclass(pyClass))) && !isPydanticBaseModel(pyClass) && !isPydanticGenericModel(pyClass) } fun isPydanticBaseModel(pyClass: PyClass): Boolean { return pyClass.qualifiedName == BASE_MODEL_Q_NAME } +fun isPydanticGenericModel(pyClass: PyClass): Boolean { + return pyClass.qualifiedName == GENERIC_MODEL_Q_NAME +} + +internal fun isSubClassOfPydanticGenericModel(pyClass: PyClass, context: TypeEvalContext?): Boolean { + return pyClass.isSubclass(GENERIC_MODEL_Q_NAME, context) +} + internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalContext?): Boolean { return pyClass.isSubclass(BASE_MODEL_Q_NAME, context) } @@ -500,6 +511,7 @@ fun getPyTypeFromPyExpression(pyExpression: PyExpression, context: TypeEvalConte is PyType -> pyExpression is PyReferenceExpression -> { getResolvedPsiElements(pyExpression, context) + .asSequence() .filterIsInstance() .map { pyClass -> pyClass.getType(context)?.getReturnType(context) } .firstOrNull() @@ -576,14 +588,11 @@ internal fun getDefaultFactoryFromField(field: PyCallExpression): PyExpression? internal fun getQualifiedName(pyExpression: PyExpression, context: TypeEvalContext): String? { return when (pyExpression) { is PySubscriptionExpression -> pyExpression.qualifier?.let { getQualifiedName(it, context) } - is PyReferenceExpression -> { - return getResolvedPsiElements(pyExpression, context) - .filterIsInstance() - .mapNotNull { it.qualifiedName } - .firstOrNull() - } - else -> { - return null - } + is PyReferenceExpression -> return getResolvedPsiElements(pyExpression, context) + .asSequence() + .filterIsInstance() + .mapNotNull { it.qualifiedName } + .firstOrNull() + else -> return null } } \ No newline at end of file diff --git a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt index 3665c6a2..1d9e1387 100644 --- a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt +++ b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt @@ -12,6 +12,7 @@ import com.jetbrains.python.documentation.PythonDocumentationProvider.getTypeHin import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.PyEvaluator import com.jetbrains.python.psi.types.PyClassType +import com.jetbrains.python.psi.types.PyType import com.jetbrains.python.psi.types.TypeEvalContext import javax.swing.Icon @@ -69,6 +70,7 @@ class PydanticCompletionContributor : CompletionContributor() { pydanticVersion: KotlinVersion?, config: HashMap, isDataclass: Boolean, + genericTypeMap: Map?, ): String { val parameter = typeProvider.fieldToParameter(pyTargetExpression, @@ -77,6 +79,7 @@ class PydanticCompletionContributor : CompletionContributor() { pyClass, pydanticVersion, config, + genericTypeMap, isDataclass = isDataclass) val defaultValue = parameter?.defaultValue?.let { when { @@ -102,6 +105,7 @@ class PydanticCompletionContributor : CompletionContributor() { config: HashMap, excludes: HashSet?, isDataclass: Boolean, + genericTypeMap: Map?, ) { val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, typeEvalContext) getClassVariables(pyClass, typeEvalContext) @@ -121,7 +125,8 @@ class PydanticCompletionContributor : CompletionContributor() { ellipsis, pydanticVersion, config, - isDataclass)) + isDataclass, + genericTypeMap)) .withIcon(icon), 1) results[elementName] = PrioritizedLookupElement.withPriority(element, 100.0) } @@ -133,6 +138,7 @@ class PydanticCompletionContributor : CompletionContributor() { pyClass: PyClass, typeEvalContext: TypeEvalContext, ellipsis: PyNoneLiteralExpression, config: HashMap, + genericTypeMap: Map?, excludes: HashSet? = null, isDataclass: Boolean, ) { @@ -141,9 +147,25 @@ class PydanticCompletionContributor : CompletionContributor() { pyClass.getAncestorClasses(typeEvalContext) .filter { isPydanticModel(it, true) } - .forEach { addFieldElement(it, newElements, typeEvalContext, ellipsis, config, excludes, isDataclass) } + .forEach { + addFieldElement(it, + newElements, + typeEvalContext, + ellipsis, + config, + excludes, + isDataclass, + genericTypeMap) + } - addFieldElement(pyClass, newElements, typeEvalContext, ellipsis, config, excludes, isDataclass) + addFieldElement(pyClass, + newElements, + typeEvalContext, + ellipsis, + config, + excludes, + isDataclass, + genericTypeMap) result.runRemainingContributors(parameters) { completionResult -> @@ -238,14 +260,18 @@ class PydanticCompletionContributor : CompletionContributor() { .toHashSet() val config = getConfig(pyClassType.pyClass, typeEvalContext, true) val ellipsis = PyElementGenerator.getInstance(pyClassType.pyClass.project).createEllipsis() - addAllFieldElement(parameters, + val genericTypeMap = typeProvider.getGenericTypeMap(pyClassType.pyClass, typeEvalContext) + addAllFieldElement( + parameters, result, pyClassType.pyClass, typeEvalContext, ellipsis, config, + genericTypeMap, definedSet, - isPydanticDataclass(pyClassType.pyClass)) + isPydanticDataclass(pyClassType.pyClass), + ) } } @@ -278,14 +304,18 @@ class PydanticCompletionContributor : CompletionContributor() { removeAllFieldElement(parameters, result, pyClassType.pyClass, typeEvalContext, excludeFields, config) return } + val genericTypeMap = typeProvider.getGenericTypeMap(pyClassType.pyClass, typeEvalContext) val ellipsis = PyElementGenerator.getInstance(pyClassType.pyClass.project).createEllipsis() - addAllFieldElement(parameters, + addAllFieldElement( + parameters, result, pyClassType.pyClass, typeEvalContext, ellipsis, config, - isDataclass = isPydanticDataclass(pyClassType.pyClass)) + genericTypeMap, + isDataclass = isPydanticDataclass(pyClassType.pyClass), + ) } } diff --git a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt index 7eb6a47c..58ca5b1c 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt @@ -3,6 +3,8 @@ package com.koxudaxi.pydantic import com.intellij.openapi.util.Ref import com.intellij.psi.PsiElement import com.intellij.psi.util.PsiTreeUtil +import com.jetbrains.python.PyCustomType +import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.* import com.jetbrains.python.psi.types.* @@ -10,6 +12,7 @@ import com.koxudaxi.pydantic.PydanticConfigService.Companion.getInstance import one.util.streamex.StreamEx class PydanticTypeProvider : PyTypeProviderBase() { + private val pyTypingTypeProvider = PyTypingTypeProvider() override fun getReferenceExpressionType( referenceExpression: PyReferenceExpression, context: TypeEvalContext, @@ -71,20 +74,22 @@ class PydanticTypeProvider : PyTypeProviderBase() { context: TypeEvalContext, ellipsis: PyNoneLiteralExpression, pydanticVersion: KotlinVersion?, + genericTypeMap: Map?, ): Ref? { return pyClass.findClassAttribute(name, false, context) - ?.let { return getRefTypeFromField(it, ellipsis, context, pyClass, pydanticVersion) } + ?.let { return getRefTypeFromField(it, ellipsis, context, pyClass, pydanticVersion, genericTypeMap) } } private fun getRefTypeFromFieldName(name: String, context: TypeEvalContext, pyClass: PyClass): Ref? { val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis() val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, context) - return getRefTypeFromFieldNameInPyClass(name, pyClass, context, ellipsis, pydanticVersion) + val genericTypeMap = getGenericTypeMap(pyClass, context) + return getRefTypeFromFieldNameInPyClass(name, pyClass, context, ellipsis, pydanticVersion, genericTypeMap) ?: pyClass.getAncestorClasses(context) .filter { isPydanticModel(it, false, context) } .mapNotNull { ancestor -> - getRefTypeFromFieldNameInPyClass(name, ancestor, context, ellipsis, pydanticVersion) + getRefTypeFromFieldNameInPyClass(name, ancestor, context, ellipsis, pydanticVersion, genericTypeMap) }.firstOrNull() } @@ -92,6 +97,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyTargetExpression: PyTargetExpression, ellipsis: PyNoneLiteralExpression, context: TypeEvalContext, pyClass: PyClass, pydanticVersion: KotlinVersion?, + genericTypeMap: Map?, ): Ref? { return fieldToParameter( pyTargetExpression, @@ -99,12 +105,92 @@ class PydanticTypeProvider : PyTypeProviderBase() { context, pyClass, pydanticVersion, - getConfig(pyClass, context, true) + getConfig(pyClass, context, true), + genericTypeMap ) ?.let { parameter -> Ref.create(parameter.getType(context)) } } + private fun getInjectedGenericType( + pyExpression: PyExpression, + context: TypeEvalContext, + ): PyType? { + return when (pyExpression) { + is PySubscriptionExpression -> { + val typingType = (pyExpression.rootOperand as? PyReferenceExpression) + ?.let { pyReferenceExpression -> + getResolvedPsiElements(pyReferenceExpression, context) + .filterIsInstance() + .any { it.qualifiedName == TYPE_Q_NAME } + } + return if (typingType == true) { + (pyExpression.indexExpression as? PyTypedElement)?.let { context.getType(it) } + } else { + (context.getType(pyExpression) as? PyClassLikeType)?.toInstance() + } + } + else -> (context.getType(pyExpression) as? PyClassLikeType)?.toInstance() + } + } + + private fun collectGenericTypes(pyClass: PyClass, context: TypeEvalContext): List { + return pyClass.superClassExpressions + .mapNotNull { + when (it) { + is PySubscriptionExpression -> it + is PyReferenceExpression -> getResolvedPsiElements(it, context) + .asSequence() + .filterIsInstance() + .firstOrNull() + else -> null + } + }.flatMap { pySubscriptionExpression -> + val referenceExpression = + pySubscriptionExpression.rootOperand as? PyReferenceExpression ?: return@flatMap emptyList() + val rootOperandType = context.getType(referenceExpression) ?: return@flatMap emptyList() + val isGenericModel = + rootOperandType is PyClassType && isSubClassOfPydanticGenericModel(rootOperandType.pyClass, context) + if (!isGenericModel && (rootOperandType as? PyCustomType)?.classQName != GENERIC_Q_NAME) return@flatMap emptyList() + + when (val indexExpression = pySubscriptionExpression.indexExpression) { + is PyTupleExpression -> indexExpression.elements + .map { context.getType(it) }.filterIsInstance().toList() + is PyGenericType -> listOf(context.getType(indexExpression)) + else -> null + } + }.filterNotNull().distinct() + } + + override fun prepareCalleeTypeForCall( + type: PyType?, + call: PyCallExpression, + context: TypeEvalContext, + ): Ref? { + val pyClassType = type as? PyClassType ?: return null + val pyClass = pyClassType.pyClass + if (!isSubClassOfPydanticGenericModel(pyClass, context) || isPydanticGenericModel(pyClass)) return null + val pySubscriptionExpression = call.node.firstChildNode.psi as? PySubscriptionExpression ?: return null + + + val injectedTypes = (pySubscriptionExpression.indexExpression as? PyTupleExpression) + ?.elements + ?.map { getInjectedGenericType(it, context) } + ?: listOf((pySubscriptionExpression.indexExpression?.let { getInjectedGenericType(it, context) })) + + + val genericTypeMap = collectGenericTypes(pyClass, context) + .take(injectedTypes.size) + .mapIndexed { index, genericType -> genericType to injectedTypes[index] } + .filterIsInstance>().toMap() + + return getPydanticTypeForClass( + pyClass, + context, + getInstance(pyClass.project).currentInitTyped, + genericTypeMap)?.let { Ref.create(it) } + } + private fun getPydanticTypeForCallee( referenceExpression: PyReferenceExpression, context: TypeEvalContext, @@ -225,6 +311,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { } ?: return null val modelName = when (modelNameArgument) { is PyReferenceExpression -> getResolvedPsiElements(modelNameArgument, context) + .asSequence() .filterIsInstance() .map { it.findAssignedValue() } .firstOrNull() @@ -238,6 +325,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { when (val baseArgument = (keywordArguments["__base__"] as? PyKeywordArgument)?.valueExpression) { is PyReferenceExpression -> { getResolvedPsiElements(baseArgument, context) + .asSequence() .map { when (it) { is PyTargetExpression -> getPydanticDynamicModelPyClass(it, context) @@ -324,7 +412,18 @@ class PydanticTypeProvider : PyTypeProviderBase() { return null } - fun getPydanticTypeForClass(pyClass: PyClass, context: TypeEvalContext, init: Boolean = false): PyCallableType? { + fun getGenericTypeMap(pyClass: PyClass, context: TypeEvalContext): Map { + if (!PyTypingTypeProvider.isGeneric(pyClass, context)) return emptyMap() + return pyTypingTypeProvider.getGenericSubstitutions(pyClass, context).filterValues { it is PyType } + } + + + fun getPydanticTypeForClass( + pyClass: PyClass, + context: TypeEvalContext, + init: Boolean = false, + genericTypeMap: Map? = null, + ): PyCallableType? { if (!isPydanticModel(pyClass, false, context)) return null val clsType = (context.getType(pyClass) as? PyClassLikeType) ?: return null val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis() @@ -335,12 +434,17 @@ class PydanticTypeProvider : PyTypeProviderBase() { if (isSubClassOfBaseSetting(pyClass, context)) { getBaseSetting(pyClass, context)?.let { baseSetting -> getBaseSettingInitParameters(baseSetting, context, typed) - ?.map { parameter -> Pair(parameter.name, parameter) } + ?.map { parameter -> parameter.name to parameter } ?.filterIsInstance>() ?.let { collected.putAll(it) } } } - + val pyClassGenericTypeMap = getGenericTypeMap(pyClass, context) + val mergedGenericTypeMap = if (genericTypeMap is Map) { + pyClassGenericTypeMap.toMutableMap().apply { this.putAll(genericTypeMap) } + } else { + pyClassGenericTypeMap + } val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, context) val config = getConfig(pyClass, context, true) for (currentType in StreamEx.of(clsType).append(pyClass.getAncestorTypes(context))) { @@ -351,7 +455,18 @@ class PydanticTypeProvider : PyTypeProviderBase() { getClassVariables(current, context) .filterNot { isUntouchedClass(it.findAssignedValue(), config, context) } - .mapNotNull { fieldToParameter(it, ellipsis, context, current, pydanticVersion, config, typed) } + .mapNotNull { + fieldToParameter( + it, + ellipsis, + context, + current, + pydanticVersion, + config, + mergedGenericTypeMap, + typed, + ) + } .filter { parameter -> parameter.name?.let { !collected.containsKey(it) } ?: false } .forEach { parameter -> collected[parameter.name!!] = parameter } } @@ -371,6 +486,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyClass: PyClass, pydanticVersion: KotlinVersion?, config: HashMap, + genericTypeMap: Map?, typed: Boolean = true, isDataclass: Boolean = false, ): PyCallableParameter? { @@ -385,15 +501,18 @@ class PydanticTypeProvider : PyTypeProviderBase() { val typeForParameter = when { !typed -> null - !hasAnnotationValue(field) && defaultValueFromField is PyTypedElement -> { - // get type from default value - context.getType(defaultValueFromField) - } - else -> { - // get type from annotation - getTypeForParameter(field, context) + !hasAnnotationValue(field) && defaultValueFromField is PyTypedElement -> context.getType( + defaultValueFromField) + else -> getTypeForParameter(field, context) + }?.let { + if (genericTypeMap == null) { + it + } else { + genericTypeMap[it] ?: it } } + // get type from default value + // get type from annotation return PyCallableParameterImpl.nonPsi( getFieldName(field, context, config, pydanticVersion), From a3bbd2688ac8041ddb7104c7e8840fe5aba61f98 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 5 May 2021 23:07:40 +0900 Subject: [PATCH 2/6] Fix comments --- src/com/koxudaxi/pydantic/PydanticTypeProvider.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt index 58ca5b1c..2424876a 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt @@ -501,8 +501,10 @@ class PydanticTypeProvider : PyTypeProviderBase() { val typeForParameter = when { !typed -> null + // get type from default value !hasAnnotationValue(field) && defaultValueFromField is PyTypedElement -> context.getType( defaultValueFromField) + // get type from annotation else -> getTypeForParameter(field, context) }?.let { if (genericTypeMap == null) { @@ -511,8 +513,6 @@ class PydanticTypeProvider : PyTypeProviderBase() { genericTypeMap[it] ?: it } } - // get type from default value - // get type from annotation return PyCallableParameterImpl.nonPsi( getFieldName(field, context, config, pydanticVersion), From 0bbe7d4b205236ba99a02950cce71c17d11db018 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Thu, 6 May 2021 01:42:38 +0900 Subject: [PATCH 3/6] Support GenericModel on auto-completion --- src/com/koxudaxi/pydantic/Pydantic.kt | 6 +- .../pydantic/PydanticCompletionContributor.kt | 46 ++++---- .../koxudaxi/pydantic/PydanticTypeProvider.kt | 102 ++++++++++-------- 3 files changed, 83 insertions(+), 71 deletions(-) diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index 5fb2de7f..a713eed9 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -136,7 +136,8 @@ fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext? = null): Boolean { return (isSubClassOfPydanticBaseModel(pyClass, - context) || (includeDataclass && isPydanticDataclass(pyClass))) && !isPydanticBaseModel(pyClass) && !isPydanticGenericModel(pyClass) + context) || (includeDataclass && isPydanticDataclass(pyClass))) && !isPydanticBaseModel(pyClass) && !isPydanticGenericModel( + pyClass) } fun isPydanticBaseModel(pyClass: PyClass): Boolean { @@ -495,7 +496,8 @@ fun getPydanticUnFilledArguments( context: TypeEvalContext, ): List { val pydanticClass = pyClass ?: getPydanticPyClass(pyCallExpression, context) ?: return emptyList() - val pydanticType = pydanticTypeProvider.getPydanticTypeForClass(pydanticClass, context, true) ?: return emptyList() + val pydanticType = pydanticTypeProvider.getPydanticTypeForClass(pydanticClass, context, true, pyCallExpression) + ?: return emptyList() val currentArguments = pyCallExpression.arguments.filter { it is PyKeywordArgument || (it as? PyStarArgumentImpl)?.isKeyword == true } .mapNotNull { it.name }.toSet() diff --git a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt index 1d9e1387..694a3ec8 100644 --- a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt +++ b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt @@ -70,7 +70,7 @@ class PydanticCompletionContributor : CompletionContributor() { pydanticVersion: KotlinVersion?, config: HashMap, isDataclass: Boolean, - genericTypeMap: Map?, + genericTypeMap: Map, ): String { val parameter = typeProvider.fieldToParameter(pyTargetExpression, @@ -105,7 +105,7 @@ class PydanticCompletionContributor : CompletionContributor() { config: HashMap, excludes: HashSet?, isDataclass: Boolean, - genericTypeMap: Map?, + genericTypeMap: Map, ) { val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, typeEvalContext) getClassVariables(pyClass, typeEvalContext) @@ -138,7 +138,7 @@ class PydanticCompletionContributor : CompletionContributor() { pyClass: PyClass, typeEvalContext: TypeEvalContext, ellipsis: PyNoneLiteralExpression, config: HashMap, - genericTypeMap: Map?, + genericTypeMap: Map, excludes: HashSet? = null, isDataclass: Boolean, ) { @@ -248,29 +248,28 @@ class PydanticCompletionContributor : CompletionContributor() { val pyArgumentList = parameters.position.parent?.parent as? PyArgumentList ?: return val typeEvalContext = parameters.getTypeEvalContext() - val pyClassType = - (pyArgumentList.parent as? PyCallExpression)?.let { typeEvalContext.getType(it) } as? PyClassType + val pyCallExpression = pyArgumentList.parent as? PyCallExpression + val pyClass = + pyCallExpression?.let { (typeEvalContext.getType(it) as? PyClassType)?.pyClass } ?: return - if (!isPydanticModel(pyClassType.pyClass, true, typeEvalContext)) return + if (!isPydanticModel(pyClass, true, typeEvalContext)) return val definedSet = pyArgumentList.children .mapNotNull { (it as? PyKeywordArgument)?.name } .map { "${it}=" } .toHashSet() - val config = getConfig(pyClassType.pyClass, typeEvalContext, true) - val ellipsis = PyElementGenerator.getInstance(pyClassType.pyClass.project).createEllipsis() - val genericTypeMap = typeProvider.getGenericTypeMap(pyClassType.pyClass, typeEvalContext) + addAllFieldElement( parameters, result, - pyClassType.pyClass, + pyClass, typeEvalContext, - ellipsis, - config, - genericTypeMap, + PyElementGenerator.getInstance(pyClass.project).createEllipsis(), + getConfig(pyClass, typeEvalContext, true), + typeProvider.getGenericTypeMap(pyClass, typeEvalContext, pyCallExpression), definedSet, - isPydanticDataclass(pyClassType.pyClass), + isPydanticDataclass(pyClass), ) } } @@ -293,28 +292,27 @@ class PydanticCompletionContributor : CompletionContributor() { result: CompletionResultSet, ) { val typeEvalContext = parameters.getTypeEvalContext() - val pyType = - (parameters.position.parent?.firstChild as? PyTypedElement)?.let { typeEvalContext.getType(it) } - ?: return + val pyTypedElement = parameters.position.parent?.firstChild as? PyTypedElement ?: return + val pyType = typeEvalContext.getType(pyTypedElement) ?: return val pyClassType = getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, true) } ?: return - val config = getConfig(pyClassType.pyClass, typeEvalContext, true) + val pyClass = pyClassType.pyClass + val config = getConfig(pyClass, typeEvalContext, true) if (pyClassType.isDefinition) { // class - removeAllFieldElement(parameters, result, pyClassType.pyClass, typeEvalContext, excludeFields, config) + removeAllFieldElement(parameters, result, pyClass, typeEvalContext, excludeFields, config) return } - val genericTypeMap = typeProvider.getGenericTypeMap(pyClassType.pyClass, typeEvalContext) - val ellipsis = PyElementGenerator.getInstance(pyClassType.pyClass.project).createEllipsis() + val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis() addAllFieldElement( parameters, result, - pyClassType.pyClass, + pyClass, typeEvalContext, ellipsis, config, - genericTypeMap, - isDataclass = isPydanticDataclass(pyClassType.pyClass), + typeProvider.getGenericTypeMap(pyClass, typeEvalContext, pyTypedElement as? PyCallExpression), + isDataclass = isPydanticDataclass(pyClass), ) } } diff --git a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt index 2424876a..24d18a05 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt @@ -74,22 +74,20 @@ class PydanticTypeProvider : PyTypeProviderBase() { context: TypeEvalContext, ellipsis: PyNoneLiteralExpression, pydanticVersion: KotlinVersion?, - genericTypeMap: Map?, ): Ref? { return pyClass.findClassAttribute(name, false, context) - ?.let { return getRefTypeFromField(it, ellipsis, context, pyClass, pydanticVersion, genericTypeMap) } + ?.let { return getRefTypeFromField(it, ellipsis, context, pyClass, pydanticVersion) } } private fun getRefTypeFromFieldName(name: String, context: TypeEvalContext, pyClass: PyClass): Ref? { val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis() val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, context) - val genericTypeMap = getGenericTypeMap(pyClass, context) - return getRefTypeFromFieldNameInPyClass(name, pyClass, context, ellipsis, pydanticVersion, genericTypeMap) + return getRefTypeFromFieldNameInPyClass(name, pyClass, context, ellipsis, pydanticVersion) ?: pyClass.getAncestorClasses(context) .filter { isPydanticModel(it, false, context) } .mapNotNull { ancestor -> - getRefTypeFromFieldNameInPyClass(name, ancestor, context, ellipsis, pydanticVersion, genericTypeMap) + getRefTypeFromFieldNameInPyClass(name, ancestor, context, ellipsis, pydanticVersion) }.firstOrNull() } @@ -97,7 +95,6 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyTargetExpression: PyTargetExpression, ellipsis: PyNoneLiteralExpression, context: TypeEvalContext, pyClass: PyClass, pydanticVersion: KotlinVersion?, - genericTypeMap: Map?, ): Ref? { return fieldToParameter( pyTargetExpression, @@ -106,7 +103,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyClass, pydanticVersion, getConfig(pyClass, context, true), - genericTypeMap + getGenericTypeMap(pyClass, context) ) ?.let { parameter -> Ref.create(parameter.getType(context)) } @@ -167,45 +164,38 @@ class PydanticTypeProvider : PyTypeProviderBase() { call: PyCallExpression, context: TypeEvalContext, ): Ref? { - val pyClassType = type as? PyClassType ?: return null - val pyClass = pyClassType.pyClass + val pyClass = (type as? PyClassType)?.pyClass ?: return null if (!isSubClassOfPydanticGenericModel(pyClass, context) || isPydanticGenericModel(pyClass)) return null - val pySubscriptionExpression = call.node.firstChildNode.psi as? PySubscriptionExpression ?: return null - - - val injectedTypes = (pySubscriptionExpression.indexExpression as? PyTupleExpression) - ?.elements - ?.map { getInjectedGenericType(it, context) } - ?: listOf((pySubscriptionExpression.indexExpression?.let { getInjectedGenericType(it, context) })) - - - val genericTypeMap = collectGenericTypes(pyClass, context) - .take(injectedTypes.size) - .mapIndexed { index, genericType -> genericType to injectedTypes[index] } - .filterIsInstance>().toMap() return getPydanticTypeForClass( pyClass, context, getInstance(pyClass.project).currentInitTyped, - genericTypeMap)?.let { Ref.create(it) } + call + )?.let { Ref.create(it) } } private fun getPydanticTypeForCallee( referenceExpression: PyReferenceExpression, context: TypeEvalContext, ): PyType? { - if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null + val pyCallExpression = PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) ?: return null return getResolvedPsiElements(referenceExpression, context) .asSequence() .map { when { - it is PyClass -> getPydanticTypeForClass(it, context, true) + it is PyClass -> getPydanticTypeForClass(it, context, true, pyCallExpression) it is PyParameter && it.isSelf -> { PsiTreeUtil.getParentOfType(it, PyFunction::class.java) ?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD } - ?.let { it.containingClass?.let { getPydanticTypeForClass(it, context) } } + ?.let { + it.containingClass?.let { + getPydanticTypeForClass(it, + context, + pyCallExpression = pyCallExpression) + } + } } it is PyNamedParameter -> it.getArgumentType(context)?.let { pyType -> getPyClassTypeByPyTypes(pyType).filter { pyClassType -> @@ -214,7 +204,8 @@ class PydanticTypeProvider : PyTypeProviderBase() { getPydanticTypeForClass( filteredPyClassType.pyClass, context, - true + true, + pyCallExpression ) }.firstOrNull() } @@ -225,7 +216,10 @@ class PydanticTypeProvider : PyTypeProviderBase() { ?.filter { pyClassType -> pyClassType.isDefinition } ?.filterNot { pyClassType -> pyClassType is PydanticDynamicModelClassType } ?.map { filteredPyClassType -> - getPydanticTypeForClass(filteredPyClassType.pyClass, context, true) + getPydanticTypeForClass(filteredPyClassType.pyClass, + context, + true, + pyCallExpression) }?.firstOrNull() } ?: getPydanticDynamicModelTypeForTargetExpression(it, context)?.pyCallableType else -> null @@ -412,17 +406,44 @@ class PydanticTypeProvider : PyTypeProviderBase() { return null } - fun getGenericTypeMap(pyClass: PyClass, context: TypeEvalContext): Map { + fun getGenericTypeMap( + pyClass: PyClass, + context: TypeEvalContext, + pyCallExpression: PyCallExpression? = null, + ): Map { if (!PyTypingTypeProvider.isGeneric(pyClass, context)) return emptyMap() - return pyTypingTypeProvider.getGenericSubstitutions(pyClass, context).filterValues { it is PyType } - } + if (!(isSubClassOfPydanticGenericModel(pyClass, context) && !isPydanticGenericModel(pyClass))) return emptyMap() + val pyClassGenericTypeMap = + pyTypingTypeProvider.getGenericSubstitutions(pyClass, context).filterValues { it is PyType } + val pySubscriptionExpression = when (val firstChild = pyCallExpression?.firstChild) { + is PySubscriptionExpression -> firstChild + is PyReferenceExpression -> getResolvedPsiElements(firstChild, context) + .firstOrNull() + ?.let { it as? PyTargetExpression } + ?.findAssignedValue() as? PySubscriptionExpression + else -> null + } ?: return pyClassGenericTypeMap + + val injectedTypes = (pySubscriptionExpression.indexExpression as? PyTupleExpression) + ?.elements + ?.map { getInjectedGenericType(it, context) } + ?: listOf((pySubscriptionExpression.indexExpression?.let { getInjectedGenericType(it, context) })) + return pyClassGenericTypeMap.toMutableMap().apply { + this.putAll(collectGenericTypes(pyClass, context) + .take(injectedTypes.size) + .mapIndexed { index, genericType -> genericType to injectedTypes[index] } + .filterIsInstance>().toMap() + ) + } + } + fun getPydanticTypeForClass( pyClass: PyClass, context: TypeEvalContext, init: Boolean = false, - genericTypeMap: Map? = null, + pyCallExpression: PyCallExpression, ): PyCallableType? { if (!isPydanticModel(pyClass, false, context)) return null val clsType = (context.getType(pyClass) as? PyClassLikeType) ?: return null @@ -439,12 +460,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { ?.let { collected.putAll(it) } } } - val pyClassGenericTypeMap = getGenericTypeMap(pyClass, context) - val mergedGenericTypeMap = if (genericTypeMap is Map) { - pyClassGenericTypeMap.toMutableMap().apply { this.putAll(genericTypeMap) } - } else { - pyClassGenericTypeMap - } + val genericTypeMap = getGenericTypeMap(pyClass, context, pyCallExpression) val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, context) val config = getConfig(pyClass, context, true) for (currentType in StreamEx.of(clsType).append(pyClass.getAncestorTypes(context))) { @@ -463,7 +479,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { current, pydanticVersion, config, - mergedGenericTypeMap, + genericTypeMap, typed, ) } @@ -486,7 +502,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyClass: PyClass, pydanticVersion: KotlinVersion?, config: HashMap, - genericTypeMap: Map?, + genericTypeMap: Map, typed: Boolean = true, isDataclass: Boolean = false, ): PyCallableParameter? { @@ -507,11 +523,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { // get type from annotation else -> getTypeForParameter(field, context) }?.let { - if (genericTypeMap == null) { - it - } else { - genericTypeMap[it] ?: it - } + genericTypeMap[it] ?: it } return PyCallableParameterImpl.nonPsi( From 1b986a2f6b0c2955b5fa954676d7bd4dc197b128 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Fri, 7 May 2021 22:12:19 +0900 Subject: [PATCH 4/6] Add unittest --- src/com/koxudaxi/pydantic/Pydantic.kt | 2 +- .../pydantic/PydanticCompletionContributor.kt | 7 +- .../pydantic/PydanticDataclassTypeProvider.kt | 2 +- .../koxudaxi/pydantic/PydanticTypeProvider.kt | 105 +++++++--- testData/completionv18/genericField.py | 21 ++ .../completionv18/genericKeywordArgument.py | 21 ++ testData/mock/pydanticv18/generics.py | 38 ++++ testData/typeinspectionv18/genericModel.py | 191 ++++++++++++++++++ .../pydantic/PydanticCompletionV18Test.kt | 80 ++++++++ .../pydantic/PydanticTypeInspectionV18Test.kt | 4 + 10 files changed, 436 insertions(+), 35 deletions(-) create mode 100644 testData/completionv18/genericField.py create mode 100644 testData/completionv18/genericKeywordArgument.py create mode 100644 testData/mock/pydanticv18/generics.py create mode 100644 testData/typeinspectionv18/genericModel.py create mode 100644 testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index a713eed9..1f90c497 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -55,7 +55,7 @@ const val ANNOTATED_Q_NAME = "typing.Annotated" const val CLASSVAR_Q_NAME = "typing.ClassVar" const val GENERIC_Q_NAME = "typing.Generic" const val TYPE_Q_NAME = "typing.Type" - +const val TUPLE_Q_NAME = "typing.Tuple" val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME) diff --git a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt index 694a3ec8..96d1e59b 100644 --- a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt +++ b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt @@ -12,6 +12,7 @@ import com.jetbrains.python.documentation.PythonDocumentationProvider.getTypeHin import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.PyEvaluator import com.jetbrains.python.psi.types.PyClassType +import com.jetbrains.python.psi.types.PyGenericType import com.jetbrains.python.psi.types.PyType import com.jetbrains.python.psi.types.TypeEvalContext import javax.swing.Icon @@ -70,7 +71,7 @@ class PydanticCompletionContributor : CompletionContributor() { pydanticVersion: KotlinVersion?, config: HashMap, isDataclass: Boolean, - genericTypeMap: Map, + genericTypeMap: Map?, ): String { val parameter = typeProvider.fieldToParameter(pyTargetExpression, @@ -105,7 +106,7 @@ class PydanticCompletionContributor : CompletionContributor() { config: HashMap, excludes: HashSet?, isDataclass: Boolean, - genericTypeMap: Map, + genericTypeMap: Map?, ) { val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, typeEvalContext) getClassVariables(pyClass, typeEvalContext) @@ -138,7 +139,7 @@ class PydanticCompletionContributor : CompletionContributor() { pyClass: PyClass, typeEvalContext: TypeEvalContext, ellipsis: PyNoneLiteralExpression, config: HashMap, - genericTypeMap: Map, + genericTypeMap: Map?, excludes: HashSet? = null, isDataclass: Boolean, ) { diff --git a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt index 10dd9fe4..7094a41d 100644 --- a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt @@ -53,7 +53,7 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() { return when { callSite is PyCallExpression && definition -> dataclassCallableType - definition -> (dataclassType.declarationElement as? PyTypedElement)?.let { context.getType(it) } + definition -> dataclassType.toClass() else -> dataclassType } } diff --git a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt index 24d18a05..9b9d07ca 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt @@ -109,29 +109,51 @@ class PydanticTypeProvider : PyTypeProviderBase() { } + + private fun getPyType(pyExpression: PyExpression, context: TypeEvalContext): PyType? { + return when (val type = context.getType(pyExpression)) { + is PyClassLikeType -> type.toInstance() + else -> type + } + } + private fun getInjectedGenericType( pyExpression: PyExpression, context: TypeEvalContext, ): PyType? { - return when (pyExpression) { - is PySubscriptionExpression -> { - val typingType = (pyExpression.rootOperand as? PyReferenceExpression) - ?.let { pyReferenceExpression -> - getResolvedPsiElements(pyReferenceExpression, context) - .filterIsInstance() - .any { it.qualifiedName == TYPE_Q_NAME } + if (pyExpression is PySubscriptionExpression) { + val rootOperand = (pyExpression.rootOperand as? PyReferenceExpression) + ?.let { pyReferenceExpression -> + getResolvedPsiElements(pyReferenceExpression, context) + .asSequence() + .filterIsInstance() + .firstOrNull() + } + when (val qualifiedName = rootOperand?.qualifiedName) { + TYPE_Q_NAME -> return (pyExpression.indexExpression as? PyTypedElement)?.let { context.getType(it) } + in listOf(TUPLE_Q_NAME, UNION_Q_NAME, OPTIONAL_Q_NAME) -> { + val indexExpression = pyExpression.indexExpression + when (indexExpression) { + is PyTupleExpression -> indexExpression.elements + .map { element -> getInjectedGenericType(element, context) } + is PySubscriptionExpression -> listOf(getInjectedGenericType(indexExpression, context)) + is PyTypedElement -> listOf(getPyType(indexExpression, context)) + else -> null + }?.let { + return when (qualifiedName) { + UNION_Q_NAME -> PyUnionType.union(it) + OPTIONAL_Q_NAME -> PyUnionType.union(it + PyNoneType.INSTANCE) + else -> PyTupleType.create(indexExpression as PsiElement, it) + } } - return if (typingType == true) { - (pyExpression.indexExpression as? PyTypedElement)?.let { context.getType(it) } - } else { - (context.getType(pyExpression) as? PyClassLikeType)?.toInstance() } } - else -> (context.getType(pyExpression) as? PyClassLikeType)?.toInstance() } + return getPyType(pyExpression, context) } - private fun collectGenericTypes(pyClass: PyClass, context: TypeEvalContext): List { + + private fun collectGenericTypes(pyClass: PyClass, context: TypeEvalContext): List { return pyClass.superClassExpressions .mapNotNull { when (it) { @@ -154,9 +176,10 @@ class PydanticTypeProvider : PyTypeProviderBase() { is PyTupleExpression -> indexExpression.elements .map { context.getType(it) }.filterIsInstance().toList() is PyGenericType -> listOf(context.getType(indexExpression)) + is PyTypedElement -> (context.getType(indexExpression) as? PyGenericType)?.let { listOf(it) } else -> null - } - }.filterNotNull().distinct() + } ?: emptyList() + }.filterIsInstance().distinct() } override fun prepareCalleeTypeForCall( @@ -193,7 +216,9 @@ class PydanticTypeProvider : PyTypeProviderBase() { it.containingClass?.let { getPydanticTypeForClass(it, context, - pyCallExpression = pyCallExpression) + true, + pyCallExpression + ) } } } @@ -216,10 +241,12 @@ class PydanticTypeProvider : PyTypeProviderBase() { ?.filter { pyClassType -> pyClassType.isDefinition } ?.filterNot { pyClassType -> pyClassType is PydanticDynamicModelClassType } ?.map { filteredPyClassType -> - getPydanticTypeForClass(filteredPyClassType.pyClass, + getPydanticTypeForClass( + filteredPyClassType.pyClass, context, true, - pyCallExpression) + pyCallExpression + ) }?.firstOrNull() } ?: getPydanticDynamicModelTypeForTargetExpression(it, context)?.pyCallableType else -> null @@ -410,11 +437,19 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyClass: PyClass, context: TypeEvalContext, pyCallExpression: PyCallExpression? = null, - ): Map { - if (!PyTypingTypeProvider.isGeneric(pyClass, context)) return emptyMap() - if (!(isSubClassOfPydanticGenericModel(pyClass, context) && !isPydanticGenericModel(pyClass))) return emptyMap() - val pyClassGenericTypeMap = - pyTypingTypeProvider.getGenericSubstitutions(pyClass, context).filterValues { it is PyType } + ): Map? { + if (!PyTypingTypeProvider.isGeneric(pyClass, context)) return null + if (!(isSubClassOfPydanticGenericModel(pyClass, context) && !isPydanticGenericModel(pyClass))) return null + + // class Response(GenericModel, Generic[TypeA, TypeB]): pass + val pyClassGenericTypeMap = pyTypingTypeProvider.getGenericSubstitutions(pyClass, context) + .mapNotNull { (key, value) -> + if (key is PyGenericType && value is PyType) { + Pair(key, value) + } else null + }.toMap() + + // Response[TypeA] val pySubscriptionExpression = when (val firstChild = pyCallExpression?.firstChild) { is PySubscriptionExpression -> firstChild is PyReferenceExpression -> getResolvedPsiElements(firstChild, context) @@ -422,11 +457,13 @@ class PydanticTypeProvider : PyTypeProviderBase() { ?.let { it as? PyTargetExpression } ?.findAssignedValue() as? PySubscriptionExpression else -> null - } ?: return pyClassGenericTypeMap + } ?: return pyClassGenericTypeMap.takeIf { it.isNotEmpty() } + // Response[TypeA, TypeB]() val injectedTypes = (pySubscriptionExpression.indexExpression as? PyTupleExpression) ?.elements ?.map { getInjectedGenericType(it, context) } + // Response[TypeA]() ?: listOf((pySubscriptionExpression.indexExpression?.let { getInjectedGenericType(it, context) })) @@ -434,9 +471,9 @@ class PydanticTypeProvider : PyTypeProviderBase() { this.putAll(collectGenericTypes(pyClass, context) .take(injectedTypes.size) .mapIndexed { index, genericType -> genericType to injectedTypes[index] } - .filterIsInstance>().toMap() + .filterIsInstance>().toMap() ) - } + }.takeIf { it.isNotEmpty() } } fun getPydanticTypeForClass( @@ -502,14 +539,15 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyClass: PyClass, pydanticVersion: KotlinVersion?, config: HashMap, - genericTypeMap: Map, + genericTypeMap: Map?, typed: Boolean = true, isDataclass: Boolean = false, ): PyCallableParameter? { if (!isValidField(field, context)) return null if (!hasAnnotationValue(field) && !field.hasAssignedValue()) return null // skip fields that are invalid syntax - val defaultValueFromField = getDefaultValueForParameter(field, ellipsis, context, pydanticVersion, isDataclass) + val defaultValueFromField = + getDefaultValueForParameter(field, ellipsis, context, pydanticVersion, isDataclass) val defaultValue = when { isSubClassOfBaseSetting(pyClass, context) -> ellipsis else -> defaultValueFromField @@ -523,7 +561,11 @@ class PydanticTypeProvider : PyTypeProviderBase() { // get type from annotation else -> getTypeForParameter(field, context) }?.let { - genericTypeMap[it] ?: it + if (genericTypeMap == null) { + it + } else { + PyTypeChecker.substitute(it, genericTypeMap, context) + } } return PyCallableParameterImpl.nonPsi( @@ -712,7 +754,10 @@ class PydanticTypeProvider : PyTypeProviderBase() { } } - private fun getDefaultValueForDataclass(assignedValue: PyCallExpression, context: TypeEvalContext): PyExpression? { + private fun getDefaultValueForDataclass( + assignedValue: PyCallExpression, + context: TypeEvalContext, + ): PyExpression? { val defaultValue = getDefaultValueForDataclass(assignedValue, context, "default") val defaultFactoryValue = getDefaultValueForDataclass(assignedValue, context, "default_factory") return when { diff --git a/testData/completionv18/genericField.py b/testData/completionv18/genericField.py new file mode 100644 index 00000000..89d1183e --- /dev/null +++ b/testData/completionv18/genericField.py @@ -0,0 +1,21 @@ +from typing import TypeVar, Type, List, Dict, Generic, Optional +from pydantic.generics import GenericModel + + +AT = TypeVar('AT') +BT = TypeVar('BT') +CT = TypeVar('CT') +DT = TypeVar('DT') +ET = TypeVar('ET') + + +class A(GenericModel, Generic[AT, BT, CT, DT]): + a: Type[AT] + b: List[BT] + c: Dict[CT, DT] + +class B(A[int, BT, CT, DT], Generic[BT, CT, DT, ET]): + hij: Optional[ET] + + +B[str, float, bytes, bool](). \ No newline at end of file diff --git a/testData/completionv18/genericKeywordArgument.py b/testData/completionv18/genericKeywordArgument.py new file mode 100644 index 00000000..615863fa --- /dev/null +++ b/testData/completionv18/genericKeywordArgument.py @@ -0,0 +1,21 @@ +from typing import TypeVar, Type, List, Dict, Generic, Optional +from pydantic.generics import GenericModel + + +AT = TypeVar('AT') +BT = TypeVar('BT') +CT = TypeVar('CT') +DT = TypeVar('DT') +ET = TypeVar('ET') + + +class A(GenericModel, Generic[AT, BT, CT, DT]): + a: Type[AT] + b: List[BT] + c: Dict[CT, DT] + +class B(A[int, BT, CT, DT], Generic[BT, CT, DT, ET]): + hij: Optional[ET] + + +B[str, float, bytes, bool]() \ No newline at end of file diff --git a/testData/mock/pydanticv18/generics.py b/testData/mock/pydanticv18/generics.py new file mode 100644 index 00000000..1e5eac45 --- /dev/null +++ b/testData/mock/pydanticv18/generics.py @@ -0,0 +1,38 @@ +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Generic, + Iterator, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + get_type_hints, +) + +from .main import BaseModel + +_generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {} +GenericModelT = TypeVar('GenericModelT', bound='GenericModel') +TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type + + +class GenericModel(BaseModel): + __slots__ = () + __concrete__: ClassVar[bool] = False + + if TYPE_CHECKING: + # Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with + # `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of + # `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below. + __parameters__: ClassVar[Tuple[TypeVarType, ...]] + + # Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings + def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]: + pass \ No newline at end of file diff --git a/testData/typeinspectionv18/genericModel.py b/testData/typeinspectionv18/genericModel.py new file mode 100644 index 00000000..fdd07866 --- /dev/null +++ b/testData/typeinspectionv18/genericModel.py @@ -0,0 +1,191 @@ +from typing import Generic, TypeVar, Optional, List, Tuple, Any, Dict, Type, Union + +from pydantic import BaseModel +from pydantic.generics import GenericModel + +DataT = TypeVar('DataT') + + +class Error(BaseModel): + code: int + message: str + + +class DataModel(BaseModel): + numbers: List[int] + people: List[str] + + +class Response(GenericModel, Generic[DataT]): + data: Optional[DataT] + error: Optional[Error] + + +Response[int](data=1, error=None) + +Response[str](data=1, error=None) + +TypeX = TypeVar('TypeX') + +class BaseClass(GenericModel, Generic[TypeX]): + X: TypeX + + +class ChildClass(BaseClass[TypeX], Generic[TypeX]): + # Inherit from Generic[TypeX] + pass + + +ChildClass[int](X=1) + +ChildClass[str](X=1) + + + +TypeX = TypeVar('TypeX') +TypeY = TypeVar('TypeY') +TypeZ = TypeVar('TypeZ') + + +class BaseClass(GenericModel, Generic[TypeX, TypeY]): + x: TypeX + y: TypeY + + +class ChildClass(BaseClass[int, TypeY], Generic[TypeY, TypeZ]): + z: TypeZ + + +# Replace TypeY by str +ChildClass[str, float](x=1, y='y', z=3.1) + +ChildClass[float, bytes](x=b'1', y='y', z=1_3) + +DataT = TypeVar('DataT') + + +class Response(GenericModel, Generic[DataT]): + data: DataT + + @classmethod + def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str: + return f'{params[0].__name__.title()}Response' + + +Response[int](data=1) + +Response[str](data='a') + + +Response[str](data=1) +Response[float](data='a') + + +T = TypeVar('T') + + +class InnerT(GenericModel, Generic[T]): + inner: T + + +class OuterT(GenericModel, Generic[T]): + outer: T + nested: InnerT[T] + + +nested = InnerT[int](inner=1) +OuterT[int](outer=1, nested=nested) + +nested = InnerT[str](inner='a') +OuterT[int](outer='a', nested=nested) + +AT = TypeVar('AT') +BT = TypeVar('BT') + + +class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + +Model(a='a', b='a') +#> a='a' b='a' + +IntT = TypeVar('IntT', bound=int) +typevar_model = Model[int, IntT] +typevar_model(a=1, b=1) + + +typevar_model(a='a', b='a') + +concrete_model = typevar_model[int] +concrete_model(a=1, b=1) + + +CT = TypeVar('CT') +DT = TypeVar('DT') +ET = TypeVar('ET') +FT = TypeVar('FT') + +class Model(GenericModel, Generic[CT, DT, ET, FT]): + a: Type[CT] + b: List[DT] + c: Dict[ET, FT] + +Model[int, int, str, int](a=int, b=[2], c={'c': 3}) + +Model[int, int, int, int](a=1, b=2, c=3) + +class Model(GenericModel, Generic[CT, DT]): + a: List[Type[CT]] + b: Union[List[DT], float] + +Model[int, int](a=[int], b=[2]) + +Model[int, int](a=1, b='2') + +class Model(GenericModel, Generic[CT, DT, ET, FT]): + a: CT + b: DT + c: ET + d: FT + +Model[Type[int], List[Type[int]], Optional[Type[int]], Tuple[Type[int]]](a=int, b=[int], c=int, d=(int,)) + +Model[Type[int], List[Type[int]], Optional[Type[int]], Tuple[Type[int]]](a=1, b=[2], c=[3], d=(4, )) + +class Model(GenericModel, Generic[CT, Broken]): + a: CT + b: Broken + +Model[int, int, int](a=1, b=2) + +Model[str, str, str](a=1, b=2) + +class Model(GenericModel, Generic[CT]): + a: CT + +Model[Union[int, float]](a=1) + +Model[Union[int, float]](a='1') + +class Model(GenericModel, Generic[CT, DT]): + a: CT + b: DT + +def x(b: ET) -> ET: + pass + +Model[x(int), Optional[x(int)]](a=1, b=2) + +Model[x(int), Optional[x(int)]](a='1', b='2') + +class Model(GenericModel, Generic[CT, DT]): + a: CT + b: DT + +y = int + +Model[y, Optional[y]](a=1, b=2) + +Model[y, Optional[y]](a='1', b='2') diff --git a/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt b/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt new file mode 100644 index 00000000..dce728cf --- /dev/null +++ b/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt @@ -0,0 +1,80 @@ +package com.koxudaxi.pydantic + +import com.intellij.codeInsight.lookup.LookupElementPresentation +import com.jetbrains.python.psi.PyTargetExpression + + +open class PydanticCompletionV18Test : PydanticTestCase(version = "v18") { + + + private fun doFieldTest(fieldNames: List>, additionalModules: List? = null) { + configureByFile(additionalModules) + val excludes = listOf( + "__annotations__", + "__base__", + "__bases__", + "__basicsize__", + "__dict__", + "__dictoffset__", + "__flags__", + "__itemsize__", + "__mro__", + "__name__", + "__qualname__", + "__slots__", + "__text_signature__", + "__weakrefoffset__", + "Ellipsis", + "EnvironmentError", + "IOError", + "NotImplemented", + "List", + "Type", + "Annotated", + "MISSING", + "WindowsError", + "__concrete__", + "__parameters__", + "___slots__", + "Generic", + "Dict", + "Optional", + ) + val actual = myFixture!!.completeBasic().filter { + it!!.psiElement is PyTargetExpression + }.filterNot { + excludes.contains(it!!.lookupString) + }.mapNotNull { + Pair(it!!.lookupString, LookupElementPresentation.renderElement(it).typeText ?: "null") + } + assertEquals(fieldNames, actual) + } + + fun testGenericField() { + doFieldTest( + listOf( + Pair("a", "Type[int] A"), + Pair("b", "List[str] A"), + Pair("c", "Dict[float, bytes] A"), + Pair("hij", "Optional[bool]=None B"), + ) + ) + } + + + fun testGenericKeywordArgument() { + doFieldTest( + listOf( + Pair("a=", "Type[int] A"), + Pair("b=", "List[str] A"), + Pair("c=", "Dict[float, bytes] A"), + Pair("hij=", "Optional[bool]=None B"), + Pair("AT", "null"), + Pair("BT", "null"), + Pair("CT", "null"), + Pair("DT", "null"), + Pair("ET", "null")) + ) + } + +} diff --git a/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt b/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt index e4e9467b..16e28232 100644 --- a/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt +++ b/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt @@ -20,4 +20,8 @@ open class PydanticTypeInspectionV18Test : PydanticInspectionBase("v18") { fun testBaseSetting() { doTest() } + + fun testGenericModel() { + doTest() + } } \ No newline at end of file From cc5bec9f91c4e1ba500f01e0cf2949614ad4abee Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Fri, 7 May 2021 23:28:16 +0900 Subject: [PATCH 5/6] Improve detecting pydnatic models --- resources/META-INF/plugin.xml | 1 + src/com/koxudaxi/pydantic/Pydantic.kt | 17 ++++-- .../pydantic/PydanticCompletionContributor.kt | 18 +++--- .../pydantic/PydanticFieldRenameFactory.kt | 31 +++++++---- .../pydantic/PydanticFieldSearchExecutor.kt | 55 +++++++++++++------ .../koxudaxi/pydantic/PydanticInspection.kt | 4 +- .../pydantic/PydanticTypeCheckerInspection.kt | 2 +- testData/typeinspectionv18/genericModel.py | 8 +++ 8 files changed, 92 insertions(+), 44 deletions(-) diff --git a/resources/META-INF/plugin.xml b/resources/META-INF/plugin.xml index 2adb21fa..32a053bb 100644 --- a/resources/META-INF/plugin.xml +++ b/resources/META-INF/plugin.xml @@ -7,6 +7,7 @@

version 0.3.1

Features

    +
  • Support GenericModel [#289]
  • Support frozen on config [#288]
  • Fix format [#287]
  • Improve handling pydantic version [#286]
  • diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index 1f90c497..ed2066a4 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -126,7 +126,11 @@ fun getPyClassByPyCallExpression( is PyClassType -> type else -> (callee.reference?.resolve() as? PyTypedElement)?.let { context.getType(it) } ?: return null } - return getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, includeDataclass) }?.pyClass + return getPyClassTypeByPyTypes(pyType).firstOrNull { + isPydanticModel(it.pyClass, + includeDataclass, + context) + }?.pyClass } fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? { @@ -134,9 +138,10 @@ fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: return getPyClassByPyCallExpression(pyCallExpression, true, context) } -fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext? = null): Boolean { +fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): Boolean { return (isSubClassOfPydanticBaseModel(pyClass, - context) || (includeDataclass && isPydanticDataclass(pyClass))) && !isPydanticBaseModel(pyClass) && !isPydanticGenericModel( + context) || isSubClassOfPydanticGenericModel(pyClass, context) || (includeDataclass && isPydanticDataclass( + pyClass))) && !isPydanticBaseModel(pyClass) && !isPydanticGenericModel( pyClass) } @@ -148,11 +153,11 @@ fun isPydanticGenericModel(pyClass: PyClass): Boolean { return pyClass.qualifiedName == GENERIC_MODEL_Q_NAME } -internal fun isSubClassOfPydanticGenericModel(pyClass: PyClass, context: TypeEvalContext?): Boolean { +internal fun isSubClassOfPydanticGenericModel(pyClass: PyClass, context: TypeEvalContext): Boolean { return pyClass.isSubclass(GENERIC_MODEL_Q_NAME, context) } -internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalContext?): Boolean { +internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalContext): Boolean { return pyClass.isSubclass(BASE_MODEL_Q_NAME, context) } @@ -375,7 +380,7 @@ fun getConfig( val version = pydanticVersion ?: PydanticVersionService.getVersion(pyClass.project, context) pyClass.getAncestorClasses(context) .reversed() - .filter { isPydanticModel(it, false) } + .filter { isPydanticModel(it, false, context) } .map { getConfig(it, context, false, version) } .forEach { it.entries.forEach { entry -> diff --git a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt index 96d1e59b..f6abfde2 100644 --- a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt +++ b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt @@ -147,7 +147,7 @@ class PydanticCompletionContributor : CompletionContributor() { val newElements: LinkedHashMap = LinkedHashMap() pyClass.getAncestorClasses(typeEvalContext) - .filter { isPydanticModel(it, true) } + .filter { isPydanticModel(it, true, typeEvalContext) } .forEach { addFieldElement(it, newElements, @@ -183,12 +183,12 @@ class PydanticCompletionContributor : CompletionContributor() { excludes: HashSet, config: HashMap, ) { - if (!isPydanticModel(pyClass, true)) return + if (!isPydanticModel(pyClass, true, typeEvalContext)) return val fieldElements: HashSet = HashSet() pyClass.getAncestorClasses(typeEvalContext) - .filter { isPydanticModel(it, true) } + .filter { isPydanticModel(it, true, typeEvalContext) } .forEach { fieldElements.addAll(it.classAttributes .filterNot { attribute -> @@ -294,10 +294,12 @@ class PydanticCompletionContributor : CompletionContributor() { ) { val typeEvalContext = parameters.getTypeEvalContext() val pyTypedElement = parameters.position.parent?.firstChild as? PyTypedElement ?: return + val pyType = typeEvalContext.getType(pyTypedElement) ?: return - val pyClassType = getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, true) } - ?: return + val pyClassType = + getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, true, typeEvalContext) } + ?: return val pyClass = pyClassType.pyClass val config = getConfig(pyClass, typeEvalContext, true) if (pyClassType.isDefinition) { // class @@ -357,8 +359,9 @@ class PydanticCompletionContributor : CompletionContributor() { val configClass = getPyClassByAttribute(parameters.position.parent?.parent) ?: return if (!isConfigClass(configClass)) return val pydanticModel = getPyClassByAttribute(configClass) ?: return - if (!isPydanticModel(pydanticModel, true)) return val typeEvalContext = parameters.getTypeEvalContext() + if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return + val definedSet = configClass.classAttributes .mapNotNull { it.name } @@ -383,7 +386,8 @@ class PydanticCompletionContributor : CompletionContributor() { result: CompletionResultSet, ) { val pydanticModel = getPyClassByAttribute(parameters.position.parent?.parent) ?: return - if (!isPydanticModel(pydanticModel, true)) return + val typeEvalContext = parameters.getTypeEvalContext() + if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return if (pydanticModel.findNestedClass("Config", false) != null) return val element = PrioritizedLookupElement.withGrouping( LookupElementBuilder diff --git a/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt b/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt index 4fdc9f00..e10534ae 100644 --- a/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt +++ b/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt @@ -20,12 +20,13 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory { when (element) { is PyTargetExpression -> { val pyClass = element.containingClass ?: return false - if (isPydanticModel(pyClass, true)) return true + val content = TypeEvalContext.codeAnalysis(element.project, element.containingFile) + if (isPydanticModel(pyClass, true, content)) return true } is PyKeywordArgument -> { - val pyClass = getPyClassByPyKeywordArgument(element, - TypeEvalContext.codeAnalysis(element.project, element.containingFile)) ?: return false - if (isPydanticModel(pyClass, true)) return true + val context = TypeEvalContext.codeAnalysis(element.project, element.containingFile) + val pyClass = getPyClassByPyKeywordArgument(element, context) ?: return false + if (isPydanticModel(pyClass, true, context)) return true } } return false @@ -55,33 +56,39 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory { element.name?.let { name -> element.containingClass ?.let { pyClass -> - addAllElement(pyClass, name, added) + val content = TypeEvalContext.codeAnalysis(element.project, element.containingFile) + addAllElement(pyClass, name, added, content) } suggestAllNames(name, newName) } is PyKeywordArgument -> element.name?.let { name -> - getPyClassByPyKeywordArgument(element, - TypeEvalContext.userInitiated(element.project, element.containingFile)) + val context = TypeEvalContext.userInitiated(element.project, element.containingFile) + getPyClassByPyKeywordArgument(element, context) ?.let { pyClass -> - addAllElement(pyClass, name, added) + addAllElement(pyClass, name, added, context) } suggestAllNames(name, newName) } } } - private fun addAllElement(pyClass: PyClass, elementName: String, added: MutableSet) { + private fun addAllElement( + pyClass: PyClass, + elementName: String, + added: MutableSet, + context: TypeEvalContext, + ) { added.add(pyClass) addClassAttributes(pyClass, elementName) addKeywordArguments(pyClass, elementName) pyClass.getAncestorClasses(null) - .filter { isPydanticModel(it, true) && !added.contains(it) } - .forEach { addAllElement(it, elementName, added) } + .filter { isPydanticModel(it, true, context) && !added.contains(it) } + .forEach { addAllElement(it, elementName, added, context) } PyClassInheritorsSearch.search(pyClass, true) .filterNot { added.contains(it) } - .forEach { addAllElement(it, elementName, added) } + .forEach { addAllElement(it, elementName, added, context) } } private fun addClassAttributes(pyClass: PyClass, elementName: String) { diff --git a/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt b/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt index 0f4118f7..5aea4706 100644 --- a/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt +++ b/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt @@ -21,30 +21,37 @@ class PydanticFieldSearchExecutor : QueryExecutorBase run { element.name ?.let { elementName -> - getPyClassByPyKeywordArgument(element, - TypeEvalContext.userInitiated(element.project, element.containingFile)) - ?.takeIf { pyClass -> isPydanticModel(pyClass, true) } - ?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer) } + val context = TypeEvalContext.userInitiated(element.project, element.containingFile) + getPyClassByPyKeywordArgument(element, context) + ?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) } + ?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer, context) } } } is PyTargetExpression -> run { element.name ?.let { elementName -> + val context = TypeEvalContext.userInitiated(element.project, element.containingFile) element.containingClass - ?.takeIf { pyClass -> isPydanticModel(pyClass, true) } + ?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) } ?.let { pyClass -> searchAllElementReference(pyClass, elementName, mutableSetOf(), - consumer) + consumer, + context) } } } } } - private fun searchField(pyClass: PyClass, elementName: String, consumer: Processor): Boolean { - if (!isPydanticModel(pyClass, true)) return false + private fun searchField( + pyClass: PyClass, + elementName: String, + consumer: Processor, + context: TypeEvalContext, + ): Boolean { + if (!isPydanticModel(pyClass, true, context)) return false val pyTargetExpression = pyClass.findClassAttribute(elementName, false, null) ?: return false consumer.process(pyTargetExpression.reference) return true @@ -61,7 +68,12 @@ class PydanticFieldSearchExecutor : QueryExecutorBase) { + private fun searchKeywordArgument( + pyClass: PyClass, + elementName: String, + consumer: Processor, + typeEvalContext: TypeEvalContext, + ) { ReferencesSearch.search(pyClass as PsiElement).forEach { psiReference -> searchKeywordArgumentByPsiReference(psiReference, elementName, consumer) @@ -72,7 +84,11 @@ class PydanticFieldSearchExecutor : QueryExecutorBase getPyClassTypeByPyTypes(pyType) - .firstOrNull { pyClassType -> isPydanticModel(pyClassType.pyClass, true) } + .firstOrNull { pyClassType -> + isPydanticModel(pyClassType.pyClass, + true, + typeEvalContext) + } ?.let { ReferencesSearch.search(param as PsiElement).forEach { searchKeywordArgumentByPsiReference(it, elementName, consumer) @@ -89,11 +105,17 @@ class PydanticFieldSearchExecutor : QueryExecutorBase, + context: TypeEvalContext, ): Boolean { - if (searchField(pyClass, elementName, consumer)) return true + if (searchField(pyClass, elementName, consumer, context)) return true return pyClass.getAncestorClasses(null) - .firstOrNull { isPydanticModel(it, true) && searchDirectReferenceField(it, elementName, consumer) } != null + .firstOrNull { + isPydanticModel(it, true, context) && searchDirectReferenceField(it, + elementName, + consumer, + context) + } != null } private fun searchAllElementReference( @@ -101,16 +123,17 @@ class PydanticFieldSearchExecutor : QueryExecutorBase, consumer: Processor, + context: TypeEvalContext, ) { added.add(pyClass) - searchField(pyClass, elementName, consumer) - searchKeywordArgument(pyClass, elementName, consumer) + searchField(pyClass, elementName, consumer, context) + searchKeywordArgument(pyClass, elementName, consumer, context) pyClass.getAncestorClasses(null) .filter { !isPydanticBaseModel(it) && !added.contains(it) } - .forEach { searchField(it, elementName, consumer) } + .forEach { searchField(it, elementName, consumer, context) } PyClassInheritorsSearch.search(pyClass, true) .filterNot { added.contains(it) } - .forEach { searchAllElementReference(it, elementName, added, consumer) } + .forEach { searchAllElementReference(it, elementName, added, consumer, context) } } } diff --git a/src/com/koxudaxi/pydantic/PydanticInspection.kt b/src/com/koxudaxi/pydantic/PydanticInspection.kt index 3c7aea97..fc5d826e 100644 --- a/src/com/koxudaxi/pydantic/PydanticInspection.kt +++ b/src/com/koxudaxi/pydantic/PydanticInspection.kt @@ -99,11 +99,11 @@ class PydanticInspection : PyInspection() { is PyClass -> type is PyClassType -> getPyClassTypeByPyTypes(type).firstOrNull { isPydanticModel(it.pyClass, - false) + false, myTypeEvalContext) }?.pyClass else -> null } ?: return - if (!isPydanticModel(pyClass, false)) return + if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return val config = getConfig(pyClass, myTypeEvalContext, true) if (config["orm_mode"] != true) { registerProblem(pyCallExpression, diff --git a/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt b/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt index c287cd17..4864f505 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt @@ -36,7 +36,7 @@ class PydanticTypeCheckerInspection : PyTypeCheckerInspection() { override fun visitPyCallExpression(node: PyCallExpression) { val pyClass = getPyClassByPyCallExpression(node, true, myTypeEvalContext) getPyClassByPyCallExpression(node, true, myTypeEvalContext) - if (pyClass is PyClass && isPydanticModel(pyClass, true)) { + if (pyClass is PyClass && isPydanticModel(pyClass, true, myTypeEvalContext)) { checkCallSiteForPydantic(node) return } diff --git a/testData/typeinspectionv18/genericModel.py b/testData/typeinspectionv18/genericModel.py index fdd07866..c7a4ddb8 100644 --- a/testData/typeinspectionv18/genericModel.py +++ b/testData/typeinspectionv18/genericModel.py @@ -189,3 +189,11 @@ class Model(GenericModel, Generic[CT, DT]): Model[y, Optional[y]](a=1, b=2) Model[y, Optional[y]](a='1', b='2') + + +class Model(GenericModel, Generic[CT, DT, ET, FT, aaaaaaaaaa]): + a: Type[CT] + b: List[aaaaa] + c: Dict[ET, aaaaaaaa] + +Model[aaaaaaaaaa, List[aaaaaa], Tuple[aaaaaaaaaa], Type[aaaaaaaaaaa]](a=int, b=[2], c={'c': 3}) \ No newline at end of file From 3d6c2c9d04bbd0ccbe3be92b400d6ca7f7e48e84 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Sat, 8 May 2021 00:16:46 +0900 Subject: [PATCH 6/6] Improve code style --- .../koxudaxi/pydantic/PydanticTypeProvider.kt | 62 +++++++------------ 1 file changed, 22 insertions(+), 40 deletions(-) diff --git a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt index 9b9d07ca..7ed64241 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt @@ -42,13 +42,11 @@ class PydanticTypeProvider : PyTypeProviderBase() { context: TypeEvalContext, anchor: PsiElement?, ): Ref? { - if (referenceTarget is PyTargetExpression) { - val pyClass = getPyClassByAttribute(referenceTarget.parent) ?: return null - if (!isPydanticModel(pyClass, false, context)) return null - val name = referenceTarget.name ?: return null - getRefTypeFromFieldName(name, context, pyClass)?.let { return it } - } - return null + if (referenceTarget !is PyTargetExpression) return null + val pyClass = getPyClassByAttribute(referenceTarget.parent) ?: return null + if (!isPydanticModel(pyClass, false, context)) return null + val name = referenceTarget.name ?: return null + return getRefTypeFromFieldName(name, context, pyClass) } override fun getParameterType(param: PyNamedParameter, func: PyFunction, context: TypeEvalContext): Ref? { @@ -106,7 +104,6 @@ class PydanticTypeProvider : PyTypeProviderBase() { getGenericTypeMap(pyClass, context) ) ?.let { parameter -> Ref.create(parameter.getType(context)) } - } @@ -173,10 +170,8 @@ class PydanticTypeProvider : PyTypeProviderBase() { if (!isGenericModel && (rootOperandType as? PyCustomType)?.classQName != GENERIC_Q_NAME) return@flatMap emptyList() when (val indexExpression = pySubscriptionExpression.indexExpression) { - is PyTupleExpression -> indexExpression.elements - .map { context.getType(it) }.filterIsInstance().toList() - is PyGenericType -> listOf(context.getType(indexExpression)) - is PyTypedElement -> (context.getType(indexExpression) as? PyGenericType)?.let { listOf(it) } + is PyTupleExpression -> indexExpression.elements.map { context.getType(it) }.toList() + is PyTypedElement -> listOf(context.getType(indexExpression)) else -> null } ?: emptyList() }.filterIsInstance().distinct() @@ -209,19 +204,16 @@ class PydanticTypeProvider : PyTypeProviderBase() { .map { when { it is PyClass -> getPydanticTypeForClass(it, context, true, pyCallExpression) - it is PyParameter && it.isSelf -> { + it is PyParameter && it.isSelf -> PsiTreeUtil.getParentOfType(it, PyFunction::class.java) ?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD } - ?.let { - it.containingClass?.let { - getPydanticTypeForClass(it, - context, - true, - pyCallExpression - ) - } + ?.containingClass?.let { + getPydanticTypeForClass(it, + context, + true, + pyCallExpression + ) } - } it is PyNamedParameter -> it.getArgumentType(context)?.let { pyType -> getPyClassTypeByPyTypes(pyType).filter { pyClassType -> pyClassType.isDefinition @@ -278,9 +270,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyTargetExpression: PyTargetExpression, context: TypeEvalContext, ): PyClass? { - val pyCallableType = getPydanticDynamicModelTypeForTargetExpression(pyTargetExpression, context) - ?: return null - return pyCallableType.pyClass + return getPydanticDynamicModelTypeForTargetExpression(pyTargetExpression, context)?.pyClass } private fun getPydanticDynamicModelTypeForTargetExpression( @@ -303,8 +293,6 @@ class PydanticTypeProvider : PyTypeProviderBase() { .filterIsInstance() .map { it.takeIf { pyFunction -> isPydanticCreateModel(pyFunction) } }.firstOrNull() ?: return null - - return getPydanticDynamicModelTypeForFunction(pyFunction, arguments, context) } @@ -423,14 +411,13 @@ class PydanticTypeProvider : PyTypeProviderBase() { } private fun getBaseSetting(pyClass: PyClass, context: TypeEvalContext): PyClass? { - pyClass.getSuperClasses(context).forEach { - return if (isBaseSetting(it)) { + return pyClass.getSuperClasses(context).mapNotNull { + if (isBaseSetting(it)) { it } else { getBaseSetting(it, context) } - } - return null + }.firstOrNull() } fun getGenericTypeMap( @@ -443,11 +430,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { // class Response(GenericModel, Generic[TypeA, TypeB]): pass val pyClassGenericTypeMap = pyTypingTypeProvider.getGenericSubstitutions(pyClass, context) - .mapNotNull { (key, value) -> - if (key is PyGenericType && value is PyType) { - Pair(key, value) - } else null - }.toMap() + .mapNotNull { (key, value) -> key to value }.filterIsInstance>().toMap() // Response[TypeA] val pySubscriptionExpression = when (val firstChild = pyCallExpression?.firstChild) { @@ -561,10 +544,9 @@ class PydanticTypeProvider : PyTypeProviderBase() { // get type from annotation else -> getTypeForParameter(field, context) }?.let { - if (genericTypeMap == null) { - it - } else { - PyTypeChecker.substitute(it, genericTypeMap, context) + when (genericTypeMap) { + null -> it + else -> PyTypeChecker.substitute(it, genericTypeMap, context) } }