Skip to content

Commit

Permalink
Merge pull request #658 from koxudaxi/optimize_resolve
Browse files Browse the repository at this point in the history
Optimize resolving pydantic class
  • Loading branch information
koxudaxi committed Mar 1, 2023
2 parents 03174d7 + 12bea4c commit 04f5218
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 148 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Unreleased
- Fix wrong inspections when a model has a __call__ method [[#655](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/655)]
- Reduce unnecessary resolve in type providers [[#656](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/656)]
- Optimize resolving pydantic class [[#658](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/658)]

## 0.3.17 - 2022-12-16
- Support Union operator [[#602](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/602)]
Expand Down
73 changes: 43 additions & 30 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ val CUSTOM_BASE_MODEL_Q_NAMES = listOf(
val CUSTOM_MODEL_FIELD_Q_NAMES = listOf(
SQL_MODEL_FIELD_Q_NAME
)

val DATA_CLASS_Q_NAMES = listOf(DATA_CLASS_Q_NAME, DATA_CLASS_SHORT_Q_NAME)

val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME)

val BASE_CONFIG_QUALIFIED_NAME = QualifiedName.fromDottedString(BASE_CONFIG_Q_NAME)
Expand Down Expand Up @@ -135,27 +138,14 @@ const val CUSTOM_ROOT_FIELD = "__root__"

fun PyTypedElement.getType(context: TypeEvalContext): PyType? = context.getType(this)

fun getPyClassByPyCallExpression(
pyCallExpression: PyCallExpression,

fun getPydanticModelByPyKeywordArgument(
pyKeywordArgument: PyKeywordArgument,
includeDataclass: Boolean,
context: TypeEvalContext,
): PyClass? {
val callee = pyCallExpression.callee ?: return null
val pyType = when (val type = callee.getType(context)) {
is PyClass -> return type
is PyClassType -> type
else -> (callee.reference?.resolve() as? PyTypedElement)?.getType(context) ?: return null
}
return pyType.pyClassTypes.firstOrNull {
isPydanticModel(it.pyClass,
includeDataclass,
context)
}?.pyClass
}

fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? {
val pyCallExpression = PsiTreeUtil.getParentOfType(pyKeywordArgument, PyCallExpression::class.java) ?: return null
return getPyClassByPyCallExpression(pyCallExpression, true, context)
return getPydanticPyClass(pyCallExpression, context, includeDataclass)
}

fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): Boolean {
Expand Down Expand Up @@ -228,6 +218,7 @@ internal val PyClass.isConfigClass: Boolean get() = name == "Config"

internal val PyFunction.isConStr: Boolean get() = qualifiedName == CON_STR_Q_NAME

internal val PyFunction.isPydanticDataclass: Boolean get() = qualifiedName in DATA_CLASS_Q_NAMES
internal fun isPydanticRegex(stringLiteralExpression: StringLiteralExpression): Boolean {
val pyKeywordArgument = stringLiteralExpression.parent as? PyKeywordArgument ?: return false
if (pyKeywordArgument.keyword != "regex") return false
Expand Down Expand Up @@ -270,14 +261,14 @@ private fun getAliasedFieldName(

fun getResolvedPsiElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): List<PsiElement> {
return RecursionManager.doPreventingRecursion(
Pair.create<PsiElement, TypeEvalContext>(
Pair.create(
referenceExpression,
context
), false
) {
PyUtil.multiResolveTopPriority(
referenceExpression,
PyResolveContext.defaultContext(context)
val resolveContext = PyResolveContext.defaultContext(context)
PyUtil.filterTopPriorityResults(
referenceExpression.getReference(resolveContext).multiResolve(false)
)
} ?: emptyList()
}
Expand Down Expand Up @@ -494,6 +485,9 @@ fun getPyClassByAttribute(pyPsiElement: PsiElement?): PyClass? {
return pyPsiElement?.parent?.parent as? PyClass
}

fun getPydanticModelByAttribute(pyPsiElement: PsiElement?, includeDataclass: Boolean, context: TypeEvalContext): PyClass? =
getPyClassByAttribute(pyPsiElement)?.takeIf { isPydanticModel(it, includeDataclass, context) }

fun createPyClassTypeImpl(qualifiedName: String, project: Project, context: TypeEvalContext): PyClassTypeImpl? {
var psiElement = getPsiElementByQualifiedName(QualifiedName.fromDottedString(qualifiedName), project, context)
if (psiElement == null) {
Expand All @@ -504,11 +498,13 @@ fun createPyClassTypeImpl(qualifiedName: String, project: Project, context: Type
return PyClassTypeImpl.createTypeByQName(psiElement, qualifiedName, false)
}

fun getPydanticPyClass(pyCallExpression: PyCallExpression, context: TypeEvalContext, includeDataclass: Boolean = false): PyClass? {
val pyClass = getPyClassByPyCallExpression(pyCallExpression, includeDataclass, context) ?: return null
if (!isPydanticModel(pyClass, includeDataclass, context)) return null
return pyClass
}
fun getPydanticPyClass(pyTypedElement: PyTypedElement, context: TypeEvalContext, includeDataclass: Boolean = false): PyClass? =
getPydanticPyClassType(pyTypedElement, context, includeDataclass)?.pyClass

fun getPydanticPyClassType(pyTypedElement: PyTypedElement, context: TypeEvalContext, includeDataclass: Boolean = false): PyClassType? =
context.getType(pyTypedElement)?.pyClassTypes?.firstOrNull {
isPydanticModel(it.pyClass, includeDataclass, context)
}

fun getAncestorPydanticModels(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): List<PyClass> {
return pyClass.getAncestorClasses(context).filter { isPydanticModel(it, includeDataclass, context) }
Expand All @@ -535,15 +531,27 @@ fun addKeywordArgument(pyCallExpression: PyCallExpression, pyKeywordArgument: Py
}
}

val PyExpression.isKeywordArgument: Boolean get() =
this is PyKeywordArgument || (this as? PyStarArgument)?.isKeyword == true

fun getPydanticUnFilledArguments(
pydanticType: PyCallableType,
pyCallExpression: PyCallExpression,
context: TypeEvalContext,
isDataClass: Boolean
): List<PyCallableParameter> {
val currentArguments =
pyCallExpression.arguments.filter { it is PyKeywordArgument || (it as? PyStarArgument)?.isKeyword == true }
.mapNotNull { it.name }.toSet()
return pydanticType.getParameters(context)?.filterNot { currentArguments.contains(it.name) } ?: emptyList()
val parameters = pydanticType.getParameters(context)?.let { allParameters ->
if (isDataClass) {
pyCallExpression.arguments
.filterNot { it.isKeywordArgument }
.let { allParameters.drop(it.size) }
} else {
allParameters
}
} ?: listOf()

val currentArguments = pyCallExpression.arguments.filter { it.isKeywordArgument }.mapNotNull { it.name }.toSet()
return parameters.filterNot { currentArguments.contains(it.name) }
}

val PyCallableParameter.required: Boolean
Expand Down Expand Up @@ -659,3 +667,8 @@ fun getPydanticModelInit(pyClass: PyClass, context: TypeEvalContext): PyFunction

fun PyCallExpression.isDefinitionCallExpression(context: TypeEvalContext): Boolean =
this.callee?.reference?.resolve()?.let { it as? PyClass }?.getType(context)?.isDefinition == true

fun PyCallExpression.getPyCallableType(context: TypeEvalContext): PyCallableType? =
this.callee?.getType(context) as? PyCallableType
fun PyCallableType.getPydanticModel(includeDataclass: Boolean, context: TypeEvalContext): PyClass? =
this.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass?.takeIf { isPydanticModel(it,includeDataclass, context) }
12 changes: 6 additions & 6 deletions src/com/koxudaxi/pydantic/PydanticAnnotator.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@ import com.intellij.openapi.util.TextRange
import com.intellij.util.containers.nullize
import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyStarArgument
import com.jetbrains.python.psi.types.PyCallableType
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.validation.PyAnnotator


class PydanticAnnotator : PyAnnotator() {
private val pydanticTypeProvider = PydanticTypeProvider()
override fun visitPyCallExpression(node: PyCallExpression) {
super.visitPyCallExpression(node)
annotatePydanticModelCallableExpression(node)
}

private fun annotatePydanticModelCallableExpression(pyCallExpression: PyCallExpression) {
val context = TypeEvalContext.codeAnalysis(pyCallExpression.project, pyCallExpression.containingFile)
if (!pyCallExpression.isDefinitionCallExpression(context)) return

val pyClass = getPydanticPyClass(pyCallExpression, context) ?: return
val pyClassType = pyCallExpression.getPyCallableType(context) ?: return
val pyClass = pyClassType.getPydanticModel(true, context) ?: return
if (!isPydanticModel(pyClass, true, context)) return
if (getPydanticModelInit(pyClass, context) != null) return
val pydanticType = pydanticTypeProvider.getPydanticTypeForClass(pyClass, context, true, pyCallExpression) ?: return
if (!pyCallExpression.isDefinitionCallExpression(context)) return

val unFilledArguments =
getPydanticUnFilledArguments(pydanticType, pyCallExpression, context).nullize()
getPydanticUnFilledArguments(pyClassType, pyCallExpression, context, pyClass.isPydanticDataclass).nullize()
?: return
holder.newSilentAnnotation(HighlightSeverity.INFORMATION).withFix(PydanticInsertArgumentsQuickFix(false))
.create()
Expand Down
13 changes: 4 additions & 9 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,9 @@ class PydanticCompletionContributor : CompletionContributor() {
val typeEvalContext = parameters.getTypeEvalContext()
val pyTypedElement = parameters.position.parent?.firstChild as? PyTypedElement ?: return

val pyType = typeEvalContext.getType(pyTypedElement) ?: return

val pyClassType =
pyType.pyClassTypes.firstOrNull { isPydanticModel(it.pyClass, true, typeEvalContext) }
?: return
val pyClassType = getPydanticPyClassType(pyTypedElement, typeEvalContext, true) ?: return

val pyClass = pyClassType.pyClass
val config = getConfig(pyClass, typeEvalContext, true)
if (pyClassType.isDefinition) { // class
Expand Down Expand Up @@ -377,9 +375,8 @@ class PydanticCompletionContributor : CompletionContributor() {
) {
val configClass = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
if (!configClass.isConfigClass) return
val pydanticModel = getPyClassByAttribute(configClass) ?: return
val typeEvalContext = parameters.getTypeEvalContext()
if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return
if (getPydanticModelByAttribute(configClass,true, parameters.getTypeEvalContext()) == null) return


val definedSet = configClass.classAttributes
Expand All @@ -404,9 +401,7 @@ class PydanticCompletionContributor : CompletionContributor() {
context: ProcessingContext,
result: CompletionResultSet,
) {
val pydanticModel = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
val typeEvalContext = parameters.getTypeEvalContext()
if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return
val pydanticModel = getPydanticModelByAttribute(parameters.position.parent?.parent, true, parameters.getTypeEvalContext()) ?: return
if (pydanticModel.findNestedClass("Config", false) != null) return
val element = PrioritizedLookupElement.withGrouping(
LookupElementBuilder
Expand Down
49 changes: 6 additions & 43 deletions src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package com.koxudaxi.pydantic

import com.intellij.openapi.util.Ref
import com.intellij.psi.PsiElement
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyCallExpressionImpl
import com.jetbrains.python.psi.impl.PyCallExpressionNavigator
import com.jetbrains.python.psi.types.*

/**
Expand All @@ -20,18 +18,13 @@ import com.jetbrains.python.psi.types.*
*/
class PydanticDataclassTypeProvider : PyTypeProviderBase() {
private val pyDataclassTypeProvider = PyDataclassTypeProvider()
private val pydanticTypeProvider = PydanticTypeProvider()

override fun getReferenceType(
referenceTarget: PsiElement,
context: TypeEvalContext,
anchor: PsiElement?
): Ref<PyType>? {
return when {
referenceTarget is PyClass && referenceTarget.isPydanticDataclass ->
getPydanticDataclassType(referenceTarget, context, anchor as? PyCallExpression, true)
else ->null
}?.let { Ref.create(it) }
override fun getCallableType(callable: PyCallable, context: TypeEvalContext): PyType? {
if (callable is PyFunction && callable.isPydanticDataclass) {
// Drop fake dataclass return type
return PyCallableTypeImpl(callable.getParameters(context), null)
}
return super.getCallableType(callable, context)
}

internal fun getDataclassCallableType(
Expand All @@ -45,34 +38,4 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() {
callSite ?: PyCallExpressionImpl(referenceTarget.node)
)?.get() as? PyCallableType
}

private fun getPydanticDataclassType(
referenceTarget: PsiElement,
context: TypeEvalContext,
callSite: PyCallExpression?,
definition: Boolean,
): PyType? {
val dataclassCallableType = getDataclassCallableType(referenceTarget, context, callSite) ?: return null

val dataclassType = (dataclassCallableType).getReturnType(context) as? PyClassType ?: return null
if (!dataclassType.pyClass.isPydanticDataclass) return null
val ellipsis = PyElementGenerator.getInstance(referenceTarget.project).createEllipsis()
val injectedPyCallableType = PyCallableTypeImpl(
dataclassCallableType.getParameters(context)?.map {
when {
it.defaultValueText == "..." && it.defaultValue is PyNoneLiteralExpression ->
pydanticTypeProvider.injectDefaultValue(dataclassType.pyClass, it, ellipsis, null, context)
?: it

else -> it
}
}, dataclassType
)
val injectedDataclassType = (injectedPyCallableType).getReturnType(context) as? PyClassType ?: return null
return when {
callSite is PyCallExpression && definition -> injectedPyCallableType
definition -> injectedDataclassType.toClass()
else -> injectedDataclassType
}
}
}
6 changes: 3 additions & 3 deletions src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
}
is PyKeywordArgument -> {
val context = TypeEvalContext.codeAnalysis(element.project, element.containingFile)
val pyClass = getPyClassByPyKeywordArgument(element, context) ?: return false
if (isPydanticModel(pyClass, true, context)) return true
return getPydanticModelByPyKeywordArgument(element, true,context) is PyClass
//
}
}
return false
Expand Down Expand Up @@ -64,7 +64,7 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
is PyKeywordArgument ->
element.name?.let { name ->
val context = TypeEvalContext.userInitiated(element.project, element.containingFile)
getPyClassByPyKeywordArgument(element, context)
getPydanticModelByPyKeywordArgument(element, true,context)
?.let { pyClass ->
addAllElement(pyClass, name, added, context)
}
Expand Down
3 changes: 1 addition & 2 deletions src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
element.name
?.let { elementName ->
val context = TypeEvalContext.userInitiated(element.project, element.containingFile)
getPyClassByPyKeywordArgument(element, context)
?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) }
getPydanticModelByPyKeywordArgument(element, true,context)
?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer, context) }
}
}
Expand Down
12 changes: 4 additions & 8 deletions src/com/koxudaxi/pydantic/PydanticInsertArgumentsQuickFix.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import com.intellij.psi.PsiFile
import com.intellij.util.IncorrectOperationException
import com.intellij.util.containers.nullize
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.types.PyCallableParameter
import com.jetbrains.python.psi.types.TypeEvalContext

class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : LocalQuickFix, IntentionAction,
Expand Down Expand Up @@ -48,14 +47,11 @@ class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : Local
if (originalElement !is PyCallExpression) return null
if (file !is PyFile) return null
val newEl = originalElement.copy() as PyCallExpression
val pyClass = getPydanticPyClass(originalElement, context, true) ?: return null
val pydanticType = if (pyClass.isPydanticDataclass) {
pydanticDataclassTypeProvider.getDataclassCallableType(pyClass, context, originalElement)
} else {
pydanticTypeProvider.getPydanticTypeForClass(pyClass, context, true, originalElement) ?: return null
} ?: return null
val pyCallableType = originalElement.getPyCallableType(context) ?: return null
val pyClass = pyCallableType.getReturnType(context)?.pyClassTypes?.firstOrNull()?.pyClass ?: return null
if (!isPydanticModel(pyClass, true, context)) return null
val unFilledArguments =
getPydanticUnFilledArguments(pydanticType, originalElement, context).let {
getPydanticUnFilledArguments(pyCallableType, originalElement, context, pyClass.isPydanticDataclass).let {
when {
onlyRequired -> it.filter { arguments -> arguments.required }
else -> it
Expand Down
Loading

0 comments on commit 04f5218

Please sign in to comment.