From 241a9abcdd46e2c0c258f586fa98dc9983a71b9b Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Tue, 28 Feb 2023 00:43:05 +0900 Subject: [PATCH 01/12] Optimize resolve --- src/com/koxudaxi/pydantic/Pydantic.kt | 27 +++++-------------- .../koxudaxi/pydantic/PydanticAnnotator.kt | 6 ++--- .../pydantic/PydanticFieldRenameFactory.kt | 6 ++--- .../pydantic/PydanticFieldSearchExecutor.kt | 4 +-- .../pydantic/PydanticTypeCheckerInspection.kt | 5 ++-- .../inspection/acceptsOnlyKeywordArguments.py | 2 +- ...sOnlyKeywordArgumentsSingleStarArgument.py | 2 +- .../typeinspection/ignoreInitArguments.py | 2 +- testData/typeinspectionv18/sqlModel.py | 2 +- 9 files changed, 21 insertions(+), 35 deletions(-) diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index f450af00..84b602cb 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -135,27 +135,14 @@ const val CUSTOM_ROOT_FIELD = "__root__" fun PyTypedElement.getType(context: TypeEvalContext): PyType? = context.getType(this) -fun getPyClassByPyCallExpression( - pyCallExpression: PyCallExpression, + +fun getPydanticModelByPyKeywordArgument( + pyKeywordArgument: PyKeywordArgument, includeDataclass: Boolean, context: TypeEvalContext, ): PyClass? { - val callee = pyCallExpression.callee ?: return null - val pyType = when (val type = callee.getType(context)) { - is PyClass -> return type - is PyClassType -> type - else -> (callee.reference?.resolve() as? PyTypedElement)?.getType(context) ?: return null - } - return pyType.pyClassTypes.firstOrNull { - isPydanticModel(it.pyClass, - includeDataclass, - context) - }?.pyClass -} - -fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? { val pyCallExpression = PsiTreeUtil.getParentOfType(pyKeywordArgument, PyCallExpression::class.java) ?: return null - return getPyClassByPyCallExpression(pyCallExpression, true, context) + return getPydanticPyClass(pyCallExpression, context, includeDataclass) } fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): Boolean { @@ -505,9 +492,9 @@ fun createPyClassTypeImpl(qualifiedName: String, project: Project, context: Type } fun getPydanticPyClass(pyCallExpression: PyCallExpression, context: TypeEvalContext, includeDataclass: Boolean = false): PyClass? { - val pyClass = getPyClassByPyCallExpression(pyCallExpression, includeDataclass, context) ?: return null - if (!isPydanticModel(pyClass, includeDataclass, context)) return null - return pyClass + return context.getType(pyCallExpression)?.pyClassTypes?.firstOrNull { + isPydanticModel(it.pyClass, includeDataclass, context) + }?.pyClass } fun getAncestorPydanticModels(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): List { diff --git a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt index cf80b4c0..10f0a57d 100644 --- a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt +++ b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt @@ -20,9 +20,9 @@ class PydanticAnnotator : PyAnnotator() { private fun annotatePydanticModelCallableExpression(pyCallExpression: PyCallExpression) { val context = TypeEvalContext.codeAnalysis(pyCallExpression.project, pyCallExpression.containingFile) if (!pyCallExpression.isDefinitionCallExpression(context)) return - - val pyClass = getPydanticPyClass(pyCallExpression, context) ?: return - if (getPydanticModelInit(pyClass, context) != null) return + val pyClassType = pyCallExpression.callee?.getType(context)?.pyClassTypes?.firstOrNull() + val pyClass = pyClassType?.pyClass ?: return + if (getPydanticModelInit(pyClassType.pyClass, context) != null) return val pydanticType = pydanticTypeProvider.getPydanticTypeForClass(pyClass, context, true, pyCallExpression) ?: return val unFilledArguments = diff --git a/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt b/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt index 947d327e..83a50fe6 100644 --- a/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt +++ b/src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt @@ -25,8 +25,8 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory { } is PyKeywordArgument -> { val context = TypeEvalContext.codeAnalysis(element.project, element.containingFile) - val pyClass = getPyClassByPyKeywordArgument(element, context) ?: return false - if (isPydanticModel(pyClass, true, context)) return true + return getPydanticModelByPyKeywordArgument(element, true,context) is PyClass +// } } return false @@ -64,7 +64,7 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory { is PyKeywordArgument -> element.name?.let { name -> val context = TypeEvalContext.userInitiated(element.project, element.containingFile) - getPyClassByPyKeywordArgument(element, context) + getPydanticModelByPyKeywordArgument(element, true,context) ?.let { pyClass -> addAllElement(pyClass, name, added, context) } diff --git a/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt b/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt index 1cca22ae..43b2a4b7 100644 --- a/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt +++ b/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt @@ -22,8 +22,8 @@ class PydanticFieldSearchExecutor : QueryExecutorBase val context = TypeEvalContext.userInitiated(element.project, element.containingFile) - getPyClassByPyKeywordArgument(element, context) - ?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) } + getPydanticModelByPyKeywordArgument(element, true,context) +// ?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) } ?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer, context) } } } diff --git a/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt b/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt index 81f285bc..a38c6291 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt @@ -37,9 +37,8 @@ class PydanticTypeCheckerInspection : PyTypeCheckerInspection() { private val pydanticConfigService = PydanticConfigService.getInstance(holder!!.project) override fun visitPyCallExpression(node: PyCallExpression) { - val pyClass = getPyClassByPyCallExpression(node, true, myTypeEvalContext) - getPyClassByPyCallExpression(node, true, myTypeEvalContext) - if (pyClass is PyClass && isPydanticModel(pyClass, true, myTypeEvalContext)) { + val pyClass = getPydanticPyClass(node, myTypeEvalContext, true) + if (pyClass is PyClass) { checkCallSiteForPydantic(node) return } diff --git a/testData/inspection/acceptsOnlyKeywordArguments.py b/testData/inspection/acceptsOnlyKeywordArguments.py index a284f1ba..98e6962c 100644 --- a/testData/inspection/acceptsOnlyKeywordArguments.py +++ b/testData/inspection/acceptsOnlyKeywordArguments.py @@ -5,7 +5,7 @@ class A(BaseModel): a: str -A('a') +A('a') @dataclass class B(): diff --git a/testData/inspection/acceptsOnlyKeywordArgumentsSingleStarArgument.py b/testData/inspection/acceptsOnlyKeywordArgumentsSingleStarArgument.py index a46bde11..eb4bd293 100644 --- a/testData/inspection/acceptsOnlyKeywordArgumentsSingleStarArgument.py +++ b/testData/inspection/acceptsOnlyKeywordArgumentsSingleStarArgument.py @@ -6,7 +6,7 @@ class A(BaseModel): a: str -A(*['a']) +A(*['a']) @dataclass diff --git a/testData/typeinspection/ignoreInitArguments.py b/testData/typeinspection/ignoreInitArguments.py index 5a166478..aa4d778b 100644 --- a/testData/typeinspection/ignoreInitArguments.py +++ b/testData/typeinspection/ignoreInitArguments.py @@ -9,6 +9,6 @@ class A(BaseModel): def __init__(self, xyz: str): super().__init__(a=xyz) -A(xyz=123) +A(xyz=123) A(a=123) \ No newline at end of file diff --git a/testData/typeinspectionv18/sqlModel.py b/testData/typeinspectionv18/sqlModel.py index b4147894..ac3867f5 100644 --- a/testData/typeinspectionv18/sqlModel.py +++ b/testData/typeinspectionv18/sqlModel.py @@ -13,6 +13,6 @@ class Hero(SQLModel, table=True): hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador") hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48) -hero_4 = Hero(secret_name="test", ) +hero_4 = Hero(secret_name="test") hero_5 = Hero(name=123, secret_name=456, age="abc") \ No newline at end of file From 97f9c73f7fa6fd0d9d73c480cc514b13fcfb12ff Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Tue, 28 Feb 2023 01:38:21 +0900 Subject: [PATCH 02/12] Refactor methods --- src/com/koxudaxi/pydantic/Pydantic.kt | 13 ++++-- .../pydantic/PydanticCompletionContributor.kt | 13 ++---- .../koxudaxi/pydantic/PydanticInspection.kt | 43 ++++++------------- 3 files changed, 27 insertions(+), 42 deletions(-) diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index 84b602cb..14adc82a 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -481,6 +481,9 @@ fun getPyClassByAttribute(pyPsiElement: PsiElement?): PyClass? { return pyPsiElement?.parent?.parent as? PyClass } +fun getPydanticModelByAttribute(pyPsiElement: PsiElement?, includeDataclass: Boolean, context: TypeEvalContext): PyClass? = + getPyClassByAttribute(pyPsiElement)?.takeIf { isPydanticModel(it, includeDataclass, context) } + fun createPyClassTypeImpl(qualifiedName: String, project: Project, context: TypeEvalContext): PyClassTypeImpl? { var psiElement = getPsiElementByQualifiedName(QualifiedName.fromDottedString(qualifiedName), project, context) if (psiElement == null) { @@ -491,11 +494,13 @@ fun createPyClassTypeImpl(qualifiedName: String, project: Project, context: Type return PyClassTypeImpl.createTypeByQName(psiElement, qualifiedName, false) } -fun getPydanticPyClass(pyCallExpression: PyCallExpression, context: TypeEvalContext, includeDataclass: Boolean = false): PyClass? { - return context.getType(pyCallExpression)?.pyClassTypes?.firstOrNull { +fun getPydanticPyClass(pyTypedElement: PyTypedElement, context: TypeEvalContext, includeDataclass: Boolean = false): PyClass? = + getPydanticPyClassType(pyTypedElement, context, includeDataclass)?.pyClass + +fun getPydanticPyClassType(pyTypedElement: PyTypedElement, context: TypeEvalContext, includeDataclass: Boolean = false): PyClassType? = + context.getType(pyTypedElement)?.pyClassTypes?.firstOrNull { isPydanticModel(it.pyClass, includeDataclass, context) - }?.pyClass -} + } fun getAncestorPydanticModels(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): List { return pyClass.getAncestorClasses(context).filter { isPydanticModel(it, includeDataclass, context) } diff --git a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt index 72d15897..7e860959 100644 --- a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt +++ b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt @@ -313,11 +313,9 @@ class PydanticCompletionContributor : CompletionContributor() { val typeEvalContext = parameters.getTypeEvalContext() val pyTypedElement = parameters.position.parent?.firstChild as? PyTypedElement ?: return - val pyType = typeEvalContext.getType(pyTypedElement) ?: return - val pyClassType = - pyType.pyClassTypes.firstOrNull { isPydanticModel(it.pyClass, true, typeEvalContext) } - ?: return + val pyClassType = getPydanticPyClassType(pyTypedElement, typeEvalContext, true) ?: return + val pyClass = pyClassType.pyClass val config = getConfig(pyClass, typeEvalContext, true) if (pyClassType.isDefinition) { // class @@ -377,9 +375,8 @@ class PydanticCompletionContributor : CompletionContributor() { ) { val configClass = getPyClassByAttribute(parameters.position.parent?.parent) ?: return if (!configClass.isConfigClass) return - val pydanticModel = getPyClassByAttribute(configClass) ?: return val typeEvalContext = parameters.getTypeEvalContext() - if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return + if (getPydanticModelByAttribute(configClass,true, parameters.getTypeEvalContext()) == null) return val definedSet = configClass.classAttributes @@ -404,9 +401,7 @@ class PydanticCompletionContributor : CompletionContributor() { context: ProcessingContext, result: CompletionResultSet, ) { - val pydanticModel = getPyClassByAttribute(parameters.position.parent?.parent) ?: return - val typeEvalContext = parameters.getTypeEvalContext() - if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return + val pydanticModel = getPydanticModelByAttribute(parameters.position.parent?.parent, true, parameters.getTypeEvalContext()) ?: return if (pydanticModel.findNestedClass("Config", false) != null) return val element = PrioritizedLookupElement.withGrouping( LookupElementBuilder diff --git a/src/com/koxudaxi/pydantic/PydanticInspection.kt b/src/com/koxudaxi/pydantic/PydanticInspection.kt index 71956a80..4d75cdfb 100644 --- a/src/com/koxudaxi/pydantic/PydanticInspection.kt +++ b/src/com/koxudaxi/pydantic/PydanticInspection.kt @@ -35,8 +35,8 @@ class PydanticInspection : PyInspection() { override fun visitPyFunction(node: PyFunction) { super.visitPyFunction(node) - val pyClass = getPyClassByAttribute(node) ?: return - if (!isPydanticModel(pyClass, true, myTypeEvalContext) || !node.isValidatorMethod) return + if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return + if (!node.isValidatorMethod) return val paramList = node.parameterList val params = paramList.parameters val firstParam = params.firstOrNull() @@ -188,23 +188,9 @@ class PydanticInspection : PyInspection() { val resolveContext = PyResolveContext.defaultContext(myTypeEvalContext) val pyCallable = pyCallExpression.multiResolveCalleeFunction(resolveContext).firstOrNull() ?: return if (pyCallable.asMethod()?.qualifiedName != "pydantic.main.BaseModel.from_orm") return - val type = - (pyCallExpression.node?.firstChildNode?.firstChildNode?.psi as? PyTypedElement)?.getType( - myTypeEvalContext - ) - ?: return - val pyClass = when (type) { - is PyClass -> type - is PyClassType -> type.pyClassTypes.firstOrNull { - isPydanticModel( - it.pyClass, - false, myTypeEvalContext - ) - }?.pyClass + val typedElement = pyCallExpression.node?.firstChildNode?.firstChildNode?.psi as? PyTypedElement ?: return + val pyClass = getPydanticPyClass(typedElement, myTypeEvalContext, false) ?: return - else -> null - } ?: return - if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return val config = getConfig(pyClass, myTypeEvalContext, true) if (config["orm_mode"] != true) { registerProblem( @@ -229,11 +215,11 @@ class PydanticInspection : PyInspection() { } private fun inspectReadOnlyProperty(node: PyAssignmentStatement) { - val pyType = - (node.leftHandSideExpression?.firstChild as? PyTypedElement)?.getType(myTypeEvalContext) ?: return - if ((pyType as? PyClassTypeImpl)?.isDefinition == true) return - val pyClass = pyType.pyClassTypes.firstOrNull()?.pyClass ?: return - if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return + val pyTypedElement = + node.leftHandSideExpression?.firstChild as? PyTypedElement ?: return + val pyClassType = getPydanticPyClassType(pyTypedElement, myTypeEvalContext, false) ?: return + if (pyClassType.isDefinition) return + val pyClass = pyClassType.pyClass val attributeName = (node.leftHandSideExpression as? PyTargetExpressionImpl)?.name ?: return val config = getConfig(pyClass, myTypeEvalContext, true) val version = PydanticCacheService.getVersion(pyClass.project, myTypeEvalContext) @@ -247,8 +233,7 @@ class PydanticInspection : PyInspection() { } private fun inspectWarnUntypedFields(node: PyAssignmentStatement) { - val pyClass = getPyClassByAttribute(node) ?: return - if (!isPydanticModel(pyClass, true, myTypeEvalContext)) return + if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return if (node.annotation != null) return if ((node.leftHandSideExpression as? PyTargetExpressionImpl)?.text?.isValidFieldName != true) return registerProblem( @@ -259,8 +244,8 @@ class PydanticInspection : PyInspection() { } private fun inspectCustomRootField(node: PyAssignmentStatement) { - val pyClass = getPyClassByAttribute(node) ?: return - if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return + val pyClass = getPydanticModelByAttribute(node, false, myTypeEvalContext) ?: return + val fieldName = (node.leftHandSideExpression as? PyTargetExpressionImpl)?.text ?: return if (fieldName.startsWith('_')) return val rootModel = pyClass.findClassAttribute("__root__", true, myTypeEvalContext)?.containingClass ?: return @@ -302,8 +287,8 @@ class PydanticInspection : PyInspection() { } private fun inspectAnnotatedAssignedField(node: PyAssignmentStatement) { - val pyClass = getPyClassByAttribute(node) ?: return - if (!isPydanticModel(pyClass, true, myTypeEvalContext)) return + if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return + val fieldName = (node.leftHandSideExpression as? PyTargetExpressionImpl)?.text ?: return val assignedValue = node.assignedValue From 7e86c75ad97ff2cc5cdc6f755e54c7ad99211455 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 02:37:16 +0900 Subject: [PATCH 03/12] Drop fake dataclass return type --- src/com/koxudaxi/pydantic/Pydantic.kt | 4 ++ .../pydantic/PydanticDataclassTypeProvider.kt | 61 ++----------------- 2 files changed, 10 insertions(+), 55 deletions(-) diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index 14adc82a..6e1d6624 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -68,6 +68,9 @@ val CUSTOM_BASE_MODEL_Q_NAMES = listOf( val CUSTOM_MODEL_FIELD_Q_NAMES = listOf( SQL_MODEL_FIELD_Q_NAME ) + +val DATA_CLASS_Q_NAMES = listOf(DATA_CLASS_Q_NAME, DATA_CLASS_SHORT_Q_NAME) + val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME) val BASE_CONFIG_QUALIFIED_NAME = QualifiedName.fromDottedString(BASE_CONFIG_Q_NAME) @@ -215,6 +218,7 @@ internal val PyClass.isConfigClass: Boolean get() = name == "Config" internal val PyFunction.isConStr: Boolean get() = qualifiedName == CON_STR_Q_NAME +internal val PyFunction.isPydanticDataclass: Boolean get() = qualifiedName in DATA_CLASS_Q_NAMES internal fun isPydanticRegex(stringLiteralExpression: StringLiteralExpression): Boolean { val pyKeywordArgument = stringLiteralExpression.parent as? PyKeywordArgument ?: return false if (pyKeywordArgument.keyword != "regex") return false diff --git a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt index 90391ef2..fedf8e38 100644 --- a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt @@ -1,11 +1,9 @@ package com.koxudaxi.pydantic -import com.intellij.openapi.util.Ref import com.intellij.psi.PsiElement import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.PyCallExpressionImpl -import com.jetbrains.python.psi.impl.PyCallExpressionNavigator import com.jetbrains.python.psi.types.* /** @@ -20,30 +18,13 @@ import com.jetbrains.python.psi.types.* */ class PydanticDataclassTypeProvider : PyTypeProviderBase() { private val pyDataclassTypeProvider = PyDataclassTypeProvider() - private val pydanticTypeProvider = PydanticTypeProvider() - override fun getReferenceType( - referenceTarget: PsiElement, - context: TypeEvalContext, - anchor: PsiElement? - ): Ref? { - return when { - referenceTarget is PyClass && referenceTarget.isPydanticDataclass -> - getPydanticDataclassType(referenceTarget, context, anchor as? PyCallExpression, true) - - referenceTarget is PyTargetExpression -> (referenceTarget as? PyTypedElement) - ?.getType(context)?.pyClassTypes - ?.filter { pyClassType -> pyClassType.pyClass.isPydanticDataclass } - ?.firstNotNullOfOrNull { pyClassType -> - getPydanticDataclassType( - pyClassType.pyClass, - context, - anchor as? PyCallExpression, - pyClassType.isDefinition - ) - } - else ->null - }?.let { Ref.create(it) } + override fun getCallableType(callable: PyCallable, context: TypeEvalContext): PyType? { + if (callable is PyFunction && callable.isPydanticDataclass) { + // Drop fake dataclass return type + return PyCallableTypeImpl(callable.getParameters(context), null) + } + return super.getCallableType(callable, context) } internal fun getDataclassCallableType( @@ -57,34 +38,4 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() { callSite ?: PyCallExpressionImpl(referenceTarget.node) )?.get() as? PyCallableType } - - private fun getPydanticDataclassType( - referenceTarget: PsiElement, - context: TypeEvalContext, - callSite: PyCallExpression?, - definition: Boolean, - ): PyType? { - val dataclassCallableType = getDataclassCallableType(referenceTarget, context, callSite) ?: return null - - val dataclassType = (dataclassCallableType).getReturnType(context) as? PyClassType ?: return null - if (!dataclassType.pyClass.isPydanticDataclass) return null - val ellipsis = PyElementGenerator.getInstance(referenceTarget.project).createEllipsis() - val injectedPyCallableType = PyCallableTypeImpl( - dataclassCallableType.getParameters(context)?.map { - when { - it.defaultValueText == "..." && it.defaultValue is PyNoneLiteralExpression -> - pydanticTypeProvider.injectDefaultValue(dataclassType.pyClass, it, ellipsis, null, context) - ?: it - - else -> it - } - }, dataclassType - ) - val injectedDataclassType = (injectedPyCallableType).getReturnType(context) as? PyClassType ?: return null - return when { - callSite is PyCallExpression && definition -> injectedPyCallableType - definition -> injectedDataclassType.toClass() - else -> injectedDataclassType - } - } } From 7c991ea73c14e39deef98ae5c37f8538535c5672 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 02:51:31 +0900 Subject: [PATCH 04/12] Fix insert args --- src/com/koxudaxi/pydantic/PydanticAnnotator.kt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt index 10f0a57d..2a519595 100644 --- a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt +++ b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt @@ -6,12 +6,12 @@ import com.intellij.openapi.util.TextRange import com.intellij.util.containers.nullize import com.jetbrains.python.psi.PyCallExpression import com.jetbrains.python.psi.PyStarArgument +import com.jetbrains.python.psi.types.PyCallableType import com.jetbrains.python.psi.types.TypeEvalContext import com.jetbrains.python.validation.PyAnnotator class PydanticAnnotator : PyAnnotator() { - private val pydanticTypeProvider = PydanticTypeProvider() override fun visitPyCallExpression(node: PyCallExpression) { super.visitPyCallExpression(node) annotatePydanticModelCallableExpression(node) @@ -19,14 +19,14 @@ class PydanticAnnotator : PyAnnotator() { private fun annotatePydanticModelCallableExpression(pyCallExpression: PyCallExpression) { val context = TypeEvalContext.codeAnalysis(pyCallExpression.project, pyCallExpression.containingFile) + val pyClassType = pyCallExpression.callee?.getType(context) as? PyCallableType ?: return + val pyClass = pyClassType.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass ?: return + if (!isPydanticModel(pyClass, true, context)) return + if (getPydanticModelInit(pyClass, context) != null) return if (!pyCallExpression.isDefinitionCallExpression(context)) return - val pyClassType = pyCallExpression.callee?.getType(context)?.pyClassTypes?.firstOrNull() - val pyClass = pyClassType?.pyClass ?: return - if (getPydanticModelInit(pyClassType.pyClass, context) != null) return - val pydanticType = pydanticTypeProvider.getPydanticTypeForClass(pyClass, context, true, pyCallExpression) ?: return val unFilledArguments = - getPydanticUnFilledArguments(pydanticType, pyCallExpression, context).nullize() + getPydanticUnFilledArguments(pyClassType, pyCallExpression, context).nullize() ?: return holder.newSilentAnnotation(HighlightSeverity.INFORMATION).withFix(PydanticInsertArgumentsQuickFix(false)) .create() From 6b77de43a41c2b13d0ee74c1bb80d66bcf35fe38 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 03:39:23 +0900 Subject: [PATCH 05/12] Fix insert args --- src/com/koxudaxi/pydantic/Pydantic.kt | 5 +++++ src/com/koxudaxi/pydantic/PydanticAnnotator.kt | 4 ++-- .../pydantic/PydanticInsertArgumentsQuickFix.kt | 12 ++++-------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index 6e1d6624..d4e1bf1b 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -655,3 +655,8 @@ fun getPydanticModelInit(pyClass: PyClass, context: TypeEvalContext): PyFunction fun PyCallExpression.isDefinitionCallExpression(context: TypeEvalContext): Boolean = this.callee?.reference?.resolve()?.let { it as? PyClass }?.getType(context)?.isDefinition == true + +fun PyCallExpression.getPyCallableType(context: TypeEvalContext): PyCallableType? = + this.callee?.getType(context) as? PyCallableType +fun PyCallableType.getPydanticModel(includeDataclass: Boolean, context: TypeEvalContext): PyClass? = + this.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass?.takeIf { isPydanticModel(it,includeDataclass, context) } diff --git a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt index 2a519595..0499895f 100644 --- a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt +++ b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt @@ -19,8 +19,8 @@ class PydanticAnnotator : PyAnnotator() { private fun annotatePydanticModelCallableExpression(pyCallExpression: PyCallExpression) { val context = TypeEvalContext.codeAnalysis(pyCallExpression.project, pyCallExpression.containingFile) - val pyClassType = pyCallExpression.callee?.getType(context) as? PyCallableType ?: return - val pyClass = pyClassType.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass ?: return + val pyClassType = pyCallExpression.getPyCallableType(context) ?: return + val pyClass = pyClassType.getPydanticModel(true, context) ?: return if (!isPydanticModel(pyClass, true, context)) return if (getPydanticModelInit(pyClass, context) != null) return if (!pyCallExpression.isDefinitionCallExpression(context)) return diff --git a/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt b/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt index ae173d42..a38800b2 100644 --- a/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt +++ b/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt @@ -12,7 +12,6 @@ import com.intellij.psi.PsiFile import com.intellij.util.IncorrectOperationException import com.intellij.util.containers.nullize import com.jetbrains.python.psi.* -import com.jetbrains.python.psi.types.PyCallableParameter import com.jetbrains.python.psi.types.TypeEvalContext class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : LocalQuickFix, IntentionAction, @@ -48,14 +47,11 @@ class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : Local if (originalElement !is PyCallExpression) return null if (file !is PyFile) return null val newEl = originalElement.copy() as PyCallExpression - val pyClass = getPydanticPyClass(originalElement, context, true) ?: return null - val pydanticType = if (pyClass.isPydanticDataclass) { - pydanticDataclassTypeProvider.getDataclassCallableType(pyClass, context, originalElement) - } else { - pydanticTypeProvider.getPydanticTypeForClass(pyClass, context, true, originalElement) ?: return null - } ?: return null + val pyCallableType = originalElement.getPyCallableType(context) ?: return null + val pyClass = pyCallableType.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass ?: return null + if (!isPydanticModel(pyClass, true, context)) return null val unFilledArguments = - getPydanticUnFilledArguments(pydanticType, originalElement, context).let { + getPydanticUnFilledArguments(pyCallableType, originalElement, context).let { when { onlyRequired -> it.filter { arguments -> arguments.required } else -> it From 5c1f5639bbc84cddd5f5fae8e8822016bac1b609 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 14:51:26 +0900 Subject: [PATCH 06/12] Fix insert args for dataclass --- src/com/koxudaxi/pydantic/Pydantic.kt | 22 +++++++++++++++---- .../koxudaxi/pydantic/PydanticAnnotator.kt | 2 +- .../PydanticInsertArgumentsQuickFix.kt | 2 +- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index d4e1bf1b..f4c04608 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -531,15 +531,29 @@ fun addKeywordArgument(pyCallExpression: PyCallExpression, pyKeywordArgument: Py } } +val PyExpression.isKeywordArgument: Boolean get() = + this is PyKeywordArgument || (this as? PyStarArgument)?.isKeyword == true + fun getPydanticUnFilledArguments( pydanticType: PyCallableType, pyCallExpression: PyCallExpression, context: TypeEvalContext, + isDataClass: Boolean ): List { - val currentArguments = - pyCallExpression.arguments.filter { it is PyKeywordArgument || (it as? PyStarArgument)?.isKeyword == true } - .mapNotNull { it.name }.toSet() - return pydanticType.getParameters(context)?.filterNot { currentArguments.contains(it.name) } ?: emptyList() + val parameters = pydanticType.getParameters(context)?.let { allParameters -> + if (isDataClass) { + pyCallExpression.arguments + .filterNot { it.isKeywordArgument } + .takeIf { it.isNotEmpty() } + ?.let { allParameters.drop(it.size - 1) } + ?: allParameters + } else { + allParameters + } + } ?: listOf() + + val currentArguments = pyCallExpression.arguments.filter { it.isKeywordArgument }.mapNotNull { it.name }.toSet() + return parameters.filterNot { currentArguments.contains(it.name) } } val PyCallableParameter.required: Boolean diff --git a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt index 0499895f..11d22421 100644 --- a/src/com/koxudaxi/pydantic/PydanticAnnotator.kt +++ b/src/com/koxudaxi/pydantic/PydanticAnnotator.kt @@ -26,7 +26,7 @@ class PydanticAnnotator : PyAnnotator() { if (!pyCallExpression.isDefinitionCallExpression(context)) return val unFilledArguments = - getPydanticUnFilledArguments(pyClassType, pyCallExpression, context).nullize() + getPydanticUnFilledArguments(pyClassType, pyCallExpression, context, pyClass.isPydanticDataclass).nullize() ?: return holder.newSilentAnnotation(HighlightSeverity.INFORMATION).withFix(PydanticInsertArgumentsQuickFix(false)) .create() diff --git a/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt b/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt index a38800b2..95c1270a 100644 --- a/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt +++ b/src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt @@ -51,7 +51,7 @@ class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : Local val pyClass = pyCallableType.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass ?: return null if (!isPydanticModel(pyClass, true, context)) return null val unFilledArguments = - getPydanticUnFilledArguments(pyCallableType, originalElement, context).let { + getPydanticUnFilledArguments(pyCallableType, originalElement, context, pyClass.isPydanticDataclass).let { when { onlyRequired -> it.filter { arguments -> arguments.required } else -> it From 6b56aeba86e691b20cdf5e3b5817315060a4728b Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 15:03:46 +0900 Subject: [PATCH 07/12] Fix insert args for dataclass --- src/com/koxudaxi/pydantic/Pydantic.kt | 4 +--- testData/inspection/acceptsOnlyKeywordArguments.py | 2 +- .../acceptsOnlyKeywordArgumentsSingleStarArgument.py | 2 +- testData/typeinspection/ignoreInitArguments.py | 2 +- testData/typeinspectionv18/sqlModel.py | 2 +- 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index f4c04608..2669b360 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -544,9 +544,7 @@ fun getPydanticUnFilledArguments( if (isDataClass) { pyCallExpression.arguments .filterNot { it.isKeywordArgument } - .takeIf { it.isNotEmpty() } - ?.let { allParameters.drop(it.size - 1) } - ?: allParameters + .let { allParameters.drop(it.size) } } else { allParameters } diff --git a/testData/inspection/acceptsOnlyKeywordArguments.py b/testData/inspection/acceptsOnlyKeywordArguments.py index 98e6962c..a284f1ba 100644 --- a/testData/inspection/acceptsOnlyKeywordArguments.py +++ b/testData/inspection/acceptsOnlyKeywordArguments.py @@ -5,7 +5,7 @@ class A(BaseModel): a: str -A('a') +A('a') @dataclass class B(): diff --git a/testData/inspection/acceptsOnlyKeywordArgumentsSingleStarArgument.py b/testData/inspection/acceptsOnlyKeywordArgumentsSingleStarArgument.py index eb4bd293..a46bde11 100644 --- a/testData/inspection/acceptsOnlyKeywordArgumentsSingleStarArgument.py +++ b/testData/inspection/acceptsOnlyKeywordArgumentsSingleStarArgument.py @@ -6,7 +6,7 @@ class A(BaseModel): a: str -A(*['a']) +A(*['a']) @dataclass diff --git a/testData/typeinspection/ignoreInitArguments.py b/testData/typeinspection/ignoreInitArguments.py index aa4d778b..5a166478 100644 --- a/testData/typeinspection/ignoreInitArguments.py +++ b/testData/typeinspection/ignoreInitArguments.py @@ -9,6 +9,6 @@ class A(BaseModel): def __init__(self, xyz: str): super().__init__(a=xyz) -A(xyz=123) +A(xyz=123) A(a=123) \ No newline at end of file diff --git a/testData/typeinspectionv18/sqlModel.py b/testData/typeinspectionv18/sqlModel.py index ac3867f5..46b3b0b2 100644 --- a/testData/typeinspectionv18/sqlModel.py +++ b/testData/typeinspectionv18/sqlModel.py @@ -13,6 +13,6 @@ class Hero(SQLModel, table=True): hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador") hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48) -hero_4 = Hero(secret_name="test") +hero_4 = Hero(secret_name="test") hero_5 = Hero(name=123, secret_name=456, age="abc") \ No newline at end of file From 9cdc7c1b0a7e50fa0a4b9315f301a7143b689582 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 15:11:56 +0900 Subject: [PATCH 08/12] Fix unittest --- testData/typeinspectionv18/dataclass.py | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/testData/typeinspectionv18/dataclass.py b/testData/typeinspectionv18/dataclass.py index 994f6648..4dbfe83d 100644 --- a/testData/typeinspectionv18/dataclass.py +++ b/testData/typeinspectionv18/dataclass.py @@ -31,23 +31,23 @@ class ChildDataclass(MyDataclass): ChildDataclass(a=2, b='orange', c=4, d='cherry') -a: MyDataclass = MyDataclass() +a: MyDataclass = MyDataclass() b: Type[MyDataclass] = MyDataclass c: MyDataclass = MyDataclass -d: Type[MyDataclass] = MyDataclass() +d: Type[MyDataclass] = MyDataclass() -aa: Union[str, MyDataclass] = MyDataclass() +aa: Union[str, MyDataclass] = MyDataclass() bb: Union[str, Type[MyDataclass]] = MyDataclass cc: Union[str, MyDataclass] = MyDataclass -dd: Union[str, Type[MyDataclass]] = MyDataclass() +dd: Union[str, Type[MyDataclass]] = MyDataclass() -aaa: ChildDataclass = ChildDataclass() +aaa: ChildDataclass = ChildDataclass() bbb: Type[ChildDataclass] = ChildDataclass ccc: ChildDataclass = ChildDataclass -ddd: Type[ChildDataclass] = ChildDataclass() +ddd: Type[ChildDataclass] = ChildDataclass() e: str = MyDataclass(a='apple', b=1).a @@ -79,7 +79,7 @@ class ChildDataclass(MyDataclass): mm: str = ii.d def my_fn_1() -> MyDataclass: - return MyDataclass() + return MyDataclass() def my_fn_2() -> Type[MyDataclass]: return MyDataclass @@ -88,10 +88,10 @@ def my_fn_3() -> MyDataclass: return MyDataclass def my_fn_4() -> Type[MyDataclass]: - return MyDataclass() + return MyDataclass() def my_fn_5() -> Union[str, MyDataclass]: - return MyDataclass() + return MyDataclass() def my_fn_6() -> Type[str, MyDataclass]: return MyDataclass @@ -100,10 +100,10 @@ def my_fn_7() -> Union[str, MyDataclass]: return MyDataclass def my_fn_8() -> Union[str, Type[MyDataclass]]: - return MyDataclass() + return MyDataclass() def my_fn_9() -> ChildDataclass: - return ChildDataclass() + return ChildDataclass() def my_fn_10() -> Type[ChildDataclass]: return ChildDataclass @@ -112,10 +112,10 @@ def my_fn_11() -> ChildDataclass: return ChildDataclass def my_fn_12() -> Type[ChildDataclass]: - return ChildDataclass() + return ChildDataclass() def my_fn_13() -> Union[str, ChildDataclass]: - return ChildDataclass() + return ChildDataclass() def my_fn_14() -> Type[str, ChildDataclass]: return ChildDataclass @@ -124,4 +124,4 @@ def my_fn_7() -> Union[str, ChildDataclass]: return ChildDataclass def my_fn_8() -> Union[str, Type[ChildDataclass]]: - return ChildDataclass() + return ChildDataclass() From 35507af9d9d020fdaf3aa2615a50d347ca5fc55e Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 15:13:05 +0900 Subject: [PATCH 09/12] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7ebee91..a5eef50a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Unreleased - Fix wrong inspections when a model has a __call__ method [[#655](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/655)] - Reduce unnecessary resolve in type providers [[#656](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/656)] +- Optimize resolving pydantic class [[#658](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/658)] ## 0.3.17 - 2022-12-16 - Support Union operator [[#602](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/602)] From 5dfe37bfde59be6a906e1371238db52dc89f5b46 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 15:40:24 +0900 Subject: [PATCH 10/12] Improve getResolvedPsiElements --- src/com/koxudaxi/pydantic/Pydantic.kt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index 2669b360..9038b6d3 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -261,14 +261,14 @@ private fun getAliasedFieldName( fun getResolvedPsiElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): List { return RecursionManager.doPreventingRecursion( - Pair.create( + Pair.create( referenceExpression, context ), false ) { - PyUtil.multiResolveTopPriority( - referenceExpression, - PyResolveContext.defaultContext(context) + val resolveContext = PyResolveContext.defaultContext(context) + PyUtil.filterTopPriorityResults( + referenceExpression.getReference(resolveContext).multiResolve(false) ) } ?: emptyList() } From 82d1cd050ea1339a28281caabe429ff299470805 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 17:18:04 +0900 Subject: [PATCH 11/12] Add unittest --- testData/inspection/acceptsOnlyKeywordArguments.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/testData/inspection/acceptsOnlyKeywordArguments.py b/testData/inspection/acceptsOnlyKeywordArguments.py index a284f1ba..3642feb7 100644 --- a/testData/inspection/acceptsOnlyKeywordArguments.py +++ b/testData/inspection/acceptsOnlyKeywordArguments.py @@ -22,3 +22,17 @@ def __call__(self, *args, **kwargs): c = C(a='abc') c('a') + +@dataclass +class D(): + a: str + b: str + + +D('a') + + +class E(BaseModel): + pass + +E() From 12bea4c7da353f99a68b33e5b327813a05554ce9 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 1 Mar 2023 17:43:58 +0900 Subject: [PATCH 12/12] Remove comment --- src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt | 1 - 1 file changed, 1 deletion(-) diff --git a/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt b/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt index 43b2a4b7..770590a7 100644 --- a/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt +++ b/src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt @@ -23,7 +23,6 @@ class PydanticFieldSearchExecutor : QueryExecutorBase val context = TypeEvalContext.userInitiated(element.project, element.containingFile) getPydanticModelByPyKeywordArgument(element, true,context) -// ?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) } ?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer, context) } } }