From 91c642834543cfa32ffb2adceb6975d24a1b9f0e Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Tue, 13 Dec 2022 03:06:12 +0900 Subject: [PATCH 1/2] Improve dataclass default Value detection --- .../pydantic/PydanticDataclassTypeProvider.kt | 43 ++++++++++++++++--- .../koxudaxi/pydantic/PydanticTypeProvider.kt | 16 ++++--- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt index 8e71077e..90382cee 100644 --- a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt @@ -19,13 +19,15 @@ import com.jetbrains.python.psi.types.* */ class PydanticDataclassTypeProvider : PyTypeProviderBase() { private val pyDataclassTypeProvider = PyDataclassTypeProvider() - + private val pydanticTypeProvider = PydanticTypeProvider() override fun getReferenceExpressionType( referenceExpression: PyReferenceExpression, context: TypeEvalContext, ): PyType? { - return getPydanticDataclass(referenceExpression, - TypeEvalContext.codeInsightFallback(referenceExpression.project)) + return getPydanticDataclass( + referenceExpression, + TypeEvalContext.codeInsightFallback(referenceExpression.project) + ) } @@ -49,16 +51,44 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() { ): PyType? { val callSite = PyCallExpressionNavigator.getPyCallExpressionByCallee(pyReferenceExpression) val dataclassCallableType = getDataclassCallableType(referenceTarget, context, callSite) ?: return null + val dataclassType = (dataclassCallableType).getReturnType(context) as? PyClassType ?: return null if (!dataclassType.pyClass.isPydanticDataclass) return null + val ellipsis = PyElementGenerator.getInstance(referenceTarget.project).createEllipsis() + val injectedPyCallableType = PyCallableTypeImpl( + dataclassCallableType.getParameters(context)?.map { + when { + it.defaultValueText == "..." && it.defaultValue is PyNoneLiteralExpression -> + injectDefaultValue(dataclassType.pyClass, it, ellipsis, context) ?: it + else -> it + } + }, dataclassType + ) + val injectedDataclassType = (injectedPyCallableType).getReturnType(context) as? PyClassType ?: return null return when { - callSite is PyCallExpression && definition -> dataclassCallableType - definition -> dataclassType.toClass() - else -> dataclassType + callSite is PyCallExpression && definition -> injectedPyCallableType + definition -> injectedDataclassType.toClass() + else -> injectedDataclassType } } + private fun injectDefaultValue( + pyClass: PyClass, + pyCallableParameter: PyCallableParameter, + ellipsis: PyNoneLiteralExpression, + context: TypeEvalContext + ): PyCallableParameter? { + val name = pyCallableParameter.name ?: return null + val attribute = pyClass.findClassAttribute(name, true, context) ?: return null + val defaultValue = + pydanticTypeProvider.getDefaultValueByAssignedValue(attribute, ellipsis, context, null, true) + return PyCallableParameterImpl.nonPsi( + name, + pyCallableParameter.getArgumentType(context), + defaultValue + ) + } private fun getPydanticDataclass(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? { return getResolvedPsiElements(referenceExpression, context) @@ -67,6 +97,7 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() { when { it is PyClass && it.isPydanticDataclass -> getPydanticDataclassType(it, context, referenceExpression, true) + it is PyTargetExpression -> (it as? PyTypedElement) ?.getType(context)?.pyClassTypes ?.filter { pyClassType -> pyClassType.pyClass.isPydanticDataclass } diff --git a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt index 82d1f067..e4f5af98 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt @@ -701,7 +701,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { } - private fun getDefaultValueByAssignedValue( + internal fun getDefaultValueByAssignedValue( field: PyTargetExpression, ellipsis: PyNoneLiteralExpression, context: TypeEvalContext, @@ -721,7 +721,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { if (isDataclass) { resolveResults .any { - it.isDataclassField + it.isDataclassField || it.isPydanticField } .let { return when { @@ -760,9 +760,13 @@ class PydanticTypeProvider : PyTypeProviderBase() { private fun getDefaultValueForDataclass( assignedValue: PyCallExpression, context: TypeEvalContext, - argumentName: String, + argumentName: String?, ): PyExpression? { - val defaultValue = assignedValue.getKeywordArgument(argumentName) + val defaultValue = if (argumentName is String) { + assignedValue.getKeywordArgument(argumentName) + } else { + assignedValue.argumentList?.arguments?.firstOrNull().takeIf { it !is PyKeywordArgument } + } return when { defaultValue == null -> null defaultValue.text == "..." -> null @@ -786,9 +790,9 @@ class PydanticTypeProvider : PyTypeProviderBase() { val defaultValue = getDefaultValueForDataclass(assignedValue, context, "default") val defaultFactoryValue = getDefaultValueForDataclass(assignedValue, context, "default_factory") return when { - defaultValue == null && defaultFactoryValue == null -> null defaultValue != null -> defaultValue - else -> defaultFactoryValue + defaultFactoryValue != null -> defaultFactoryValue + else -> getDefaultValueForDataclass(assignedValue, context, null) } } } From e9f3a3666a76a4e22f6602a805e7ec6d22518e16 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Tue, 13 Dec 2022 23:51:19 +0900 Subject: [PATCH 2/2] Add unittest --- testData/completion/dataclassKeywordArgument.py | 5 ++++- testSrc/com/koxudaxi/pydantic/PydanticCompletionTest.kt | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/testData/completion/dataclassKeywordArgument.py b/testData/completion/dataclassKeywordArgument.py index 11d4bf24..a379351e 100644 --- a/testData/completion/dataclassKeywordArgument.py +++ b/testData/completion/dataclassKeywordArgument.py @@ -1,7 +1,8 @@ from dataclasses import field, MISSING -from pydantic.dataclasses import dataclass +from pydantic.dataclasses import dataclass, Field + def dummy(): return '123' @@ -20,6 +21,8 @@ class A: cda: str = field(default=MISSING, default_factory=MISSING) edc: str = dummy() gef: str = field(default=unresolved) + jih: str = field(..., title="empty", ) + mlk: str = field(..., title="empty", ) @dataclass class B(A): diff --git a/testSrc/com/koxudaxi/pydantic/PydanticCompletionTest.kt b/testSrc/com/koxudaxi/pydantic/PydanticCompletionTest.kt index 461703b9..135c12e2 100644 --- a/testSrc/com/koxudaxi/pydantic/PydanticCompletionTest.kt +++ b/testSrc/com/koxudaxi/pydantic/PydanticCompletionTest.kt @@ -721,7 +721,9 @@ open class PydanticCompletionTest : PydanticTestCase() { Pair("efg=", "str='xyz' A"), Pair("gef=", "str=unresolved A"), Pair("hij=", "str=lambda :'asd' A"), + Pair("jih=", "str A"), Pair("klm=", "str='qwe' A"), + Pair("mlk=", "str A"), Pair("qrs=", "str='fgh' A"), Pair("tuw=", "str A"), Pair("xyz=", "str A")