Skip to content

Commit

Permalink
Improve detecting pydnatic models
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi committed May 7, 2021
1 parent 1b986a2 commit cc5bec9
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 44 deletions.
1 change: 1 addition & 0 deletions resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<h2>version 0.3.1</h2>
<p>Features</p>
<ul>
<li>Support GenericModel [#289]</li>
<li>Support frozen on config [#288]</li>
<li>Fix format [#287]</li>
<li>Improve handling pydantic version [#286]</li>
Expand Down
17 changes: 11 additions & 6 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,22 @@ fun getPyClassByPyCallExpression(
is PyClassType -> type
else -> (callee.reference?.resolve() as? PyTypedElement)?.let { context.getType(it) } ?: return null
}
return getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, includeDataclass) }?.pyClass
return getPyClassTypeByPyTypes(pyType).firstOrNull {
isPydanticModel(it.pyClass,
includeDataclass,
context)
}?.pyClass
}

fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? {
val pyCallExpression = PsiTreeUtil.getParentOfType(pyKeywordArgument, PyCallExpression::class.java) ?: return null
return getPyClassByPyCallExpression(pyCallExpression, true, context)
}

fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext? = null): Boolean {
fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): Boolean {
return (isSubClassOfPydanticBaseModel(pyClass,
context) || (includeDataclass && isPydanticDataclass(pyClass))) && !isPydanticBaseModel(pyClass) && !isPydanticGenericModel(
context) || isSubClassOfPydanticGenericModel(pyClass, context) || (includeDataclass && isPydanticDataclass(
pyClass))) && !isPydanticBaseModel(pyClass) && !isPydanticGenericModel(
pyClass)
}

Expand All @@ -148,11 +153,11 @@ fun isPydanticGenericModel(pyClass: PyClass): Boolean {
return pyClass.qualifiedName == GENERIC_MODEL_Q_NAME
}

internal fun isSubClassOfPydanticGenericModel(pyClass: PyClass, context: TypeEvalContext?): Boolean {
internal fun isSubClassOfPydanticGenericModel(pyClass: PyClass, context: TypeEvalContext): Boolean {
return pyClass.isSubclass(GENERIC_MODEL_Q_NAME, context)
}

internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalContext?): Boolean {
internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalContext): Boolean {
return pyClass.isSubclass(BASE_MODEL_Q_NAME, context)
}

Expand Down Expand Up @@ -375,7 +380,7 @@ fun getConfig(
val version = pydanticVersion ?: PydanticVersionService.getVersion(pyClass.project, context)
pyClass.getAncestorClasses(context)
.reversed()
.filter { isPydanticModel(it, false) }
.filter { isPydanticModel(it, false, context) }
.map { getConfig(it, context, false, version) }
.forEach {
it.entries.forEach { entry ->
Expand Down
18 changes: 11 additions & 7 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class PydanticCompletionContributor : CompletionContributor() {
val newElements: LinkedHashMap<String, LookupElement> = LinkedHashMap()

pyClass.getAncestorClasses(typeEvalContext)
.filter { isPydanticModel(it, true) }
.filter { isPydanticModel(it, true, typeEvalContext) }
.forEach {
addFieldElement(it,
newElements,
Expand Down Expand Up @@ -183,12 +183,12 @@ class PydanticCompletionContributor : CompletionContributor() {
excludes: HashSet<String>, config: HashMap<String, Any?>,
) {

if (!isPydanticModel(pyClass, true)) return
if (!isPydanticModel(pyClass, true, typeEvalContext)) return

val fieldElements: HashSet<String> = HashSet()

pyClass.getAncestorClasses(typeEvalContext)
.filter { isPydanticModel(it, true) }
.filter { isPydanticModel(it, true, typeEvalContext) }
.forEach {
fieldElements.addAll(it.classAttributes
.filterNot { attribute ->
Expand Down Expand Up @@ -294,10 +294,12 @@ class PydanticCompletionContributor : CompletionContributor() {
) {
val typeEvalContext = parameters.getTypeEvalContext()
val pyTypedElement = parameters.position.parent?.firstChild as? PyTypedElement ?: return

val pyType = typeEvalContext.getType(pyTypedElement) ?: return

val pyClassType = getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, true) }
?: return
val pyClassType =
getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, true, typeEvalContext) }
?: return
val pyClass = pyClassType.pyClass
val config = getConfig(pyClass, typeEvalContext, true)
if (pyClassType.isDefinition) { // class
Expand Down Expand Up @@ -357,8 +359,9 @@ class PydanticCompletionContributor : CompletionContributor() {
val configClass = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
if (!isConfigClass(configClass)) return
val pydanticModel = getPyClassByAttribute(configClass) ?: return
if (!isPydanticModel(pydanticModel, true)) return
val typeEvalContext = parameters.getTypeEvalContext()
if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return


val definedSet = configClass.classAttributes
.mapNotNull { it.name }
Expand All @@ -383,7 +386,8 @@ class PydanticCompletionContributor : CompletionContributor() {
result: CompletionResultSet,
) {
val pydanticModel = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
if (!isPydanticModel(pydanticModel, true)) return
val typeEvalContext = parameters.getTypeEvalContext()
if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return
if (pydanticModel.findNestedClass("Config", false) != null) return
val element = PrioritizedLookupElement.withGrouping(
LookupElementBuilder
Expand Down
31 changes: 19 additions & 12 deletions src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
when (element) {
is PyTargetExpression -> {
val pyClass = element.containingClass ?: return false
if (isPydanticModel(pyClass, true)) return true
val content = TypeEvalContext.codeAnalysis(element.project, element.containingFile)
if (isPydanticModel(pyClass, true, content)) return true
}
is PyKeywordArgument -> {
val pyClass = getPyClassByPyKeywordArgument(element,
TypeEvalContext.codeAnalysis(element.project, element.containingFile)) ?: return false
if (isPydanticModel(pyClass, true)) return true
val context = TypeEvalContext.codeAnalysis(element.project, element.containingFile)
val pyClass = getPyClassByPyKeywordArgument(element, context) ?: return false
if (isPydanticModel(pyClass, true, context)) return true
}
}
return false
Expand Down Expand Up @@ -55,33 +56,39 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
element.name?.let { name ->
element.containingClass
?.let { pyClass ->
addAllElement(pyClass, name, added)
val content = TypeEvalContext.codeAnalysis(element.project, element.containingFile)
addAllElement(pyClass, name, added, content)
}
suggestAllNames(name, newName)
}
is PyKeywordArgument ->
element.name?.let { name ->
getPyClassByPyKeywordArgument(element,
TypeEvalContext.userInitiated(element.project, element.containingFile))
val context = TypeEvalContext.userInitiated(element.project, element.containingFile)
getPyClassByPyKeywordArgument(element, context)
?.let { pyClass ->
addAllElement(pyClass, name, added)
addAllElement(pyClass, name, added, context)
}
suggestAllNames(name, newName)
}
}
}

private fun addAllElement(pyClass: PyClass, elementName: String, added: MutableSet<PyClass>) {
private fun addAllElement(
pyClass: PyClass,
elementName: String,
added: MutableSet<PyClass>,
context: TypeEvalContext,
) {
added.add(pyClass)
addClassAttributes(pyClass, elementName)
addKeywordArguments(pyClass, elementName)
pyClass.getAncestorClasses(null)
.filter { isPydanticModel(it, true) && !added.contains(it) }
.forEach { addAllElement(it, elementName, added) }
.filter { isPydanticModel(it, true, context) && !added.contains(it) }
.forEach { addAllElement(it, elementName, added, context) }

PyClassInheritorsSearch.search(pyClass, true)
.filterNot { added.contains(it) }
.forEach { addAllElement(it, elementName, added) }
.forEach { addAllElement(it, elementName, added, context) }
}

private fun addClassAttributes(pyClass: PyClass, elementName: String) {
Expand Down
55 changes: 39 additions & 16 deletions src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,37 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
is PyKeywordArgument -> run<RuntimeException> {
element.name
?.let { elementName ->
getPyClassByPyKeywordArgument(element,
TypeEvalContext.userInitiated(element.project, element.containingFile))
?.takeIf { pyClass -> isPydanticModel(pyClass, true) }
?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer) }
val context = TypeEvalContext.userInitiated(element.project, element.containingFile)
getPyClassByPyKeywordArgument(element, context)
?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) }
?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer, context) }
}
}
is PyTargetExpression -> run<RuntimeException> {
element.name
?.let { elementName ->
val context = TypeEvalContext.userInitiated(element.project, element.containingFile)
element.containingClass
?.takeIf { pyClass -> isPydanticModel(pyClass, true) }
?.takeIf { pyClass -> isPydanticModel(pyClass, true, context) }
?.let { pyClass ->
searchAllElementReference(pyClass,
elementName,
mutableSetOf(),
consumer)
consumer,
context)
}
}
}
}
}

private fun searchField(pyClass: PyClass, elementName: String, consumer: Processor<in PsiReference>): Boolean {
if (!isPydanticModel(pyClass, true)) return false
private fun searchField(
pyClass: PyClass,
elementName: String,
consumer: Processor<in PsiReference>,
context: TypeEvalContext,
): Boolean {
if (!isPydanticModel(pyClass, true, context)) return false
val pyTargetExpression = pyClass.findClassAttribute(elementName, false, null) ?: return false
consumer.process(pyTargetExpression.reference)
return true
Expand All @@ -61,7 +68,12 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
}
}

private fun searchKeywordArgument(pyClass: PyClass, elementName: String, consumer: Processor<in PsiReference>) {
private fun searchKeywordArgument(
pyClass: PyClass,
elementName: String,
consumer: Processor<in PsiReference>,
typeEvalContext: TypeEvalContext,
) {
ReferencesSearch.search(pyClass as PsiElement).forEach { psiReference ->
searchKeywordArgumentByPsiReference(psiReference, elementName, consumer)

Expand All @@ -72,7 +84,11 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
psiReference.element.containingFile))
?.let { pyType ->
getPyClassTypeByPyTypes(pyType)
.firstOrNull { pyClassType -> isPydanticModel(pyClassType.pyClass, true) }
.firstOrNull { pyClassType ->
isPydanticModel(pyClassType.pyClass,
true,
typeEvalContext)
}
?.let {
ReferencesSearch.search(param as PsiElement).forEach {
searchKeywordArgumentByPsiReference(it, elementName, consumer)
Expand All @@ -89,28 +105,35 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
pyClass: PyClass,
elementName: String,
consumer: Processor<in PsiReference>,
context: TypeEvalContext,
): Boolean {
if (searchField(pyClass, elementName, consumer)) return true
if (searchField(pyClass, elementName, consumer, context)) return true

return pyClass.getAncestorClasses(null)
.firstOrNull { isPydanticModel(it, true) && searchDirectReferenceField(it, elementName, consumer) } != null
.firstOrNull {
isPydanticModel(it, true, context) && searchDirectReferenceField(it,
elementName,
consumer,
context)
} != null
}

private fun searchAllElementReference(
pyClass: PyClass,
elementName: String,
added: MutableSet<PyClass>,
consumer: Processor<in PsiReference>,
context: TypeEvalContext,
) {
added.add(pyClass)
searchField(pyClass, elementName, consumer)
searchKeywordArgument(pyClass, elementName, consumer)
searchField(pyClass, elementName, consumer, context)
searchKeywordArgument(pyClass, elementName, consumer, context)
pyClass.getAncestorClasses(null)
.filter { !isPydanticBaseModel(it) && !added.contains(it) }
.forEach { searchField(it, elementName, consumer) }
.forEach { searchField(it, elementName, consumer, context) }

PyClassInheritorsSearch.search(pyClass, true)
.filterNot { added.contains(it) }
.forEach { searchAllElementReference(it, elementName, added, consumer) }
.forEach { searchAllElementReference(it, elementName, added, consumer, context) }
}
}
4 changes: 2 additions & 2 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ class PydanticInspection : PyInspection() {
is PyClass -> type
is PyClassType -> getPyClassTypeByPyTypes(type).firstOrNull {
isPydanticModel(it.pyClass,
false)
false, myTypeEvalContext)
}?.pyClass
else -> null
} ?: return
if (!isPydanticModel(pyClass, false)) return
if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return
val config = getConfig(pyClass, myTypeEvalContext, true)
if (config["orm_mode"] != true) {
registerProblem(pyCallExpression,
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticTypeCheckerInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class PydanticTypeCheckerInspection : PyTypeCheckerInspection() {
override fun visitPyCallExpression(node: PyCallExpression) {
val pyClass = getPyClassByPyCallExpression(node, true, myTypeEvalContext)
getPyClassByPyCallExpression(node, true, myTypeEvalContext)
if (pyClass is PyClass && isPydanticModel(pyClass, true)) {
if (pyClass is PyClass && isPydanticModel(pyClass, true, myTypeEvalContext)) {
checkCallSiteForPydantic(node)
return
}
Expand Down
8 changes: 8 additions & 0 deletions testData/typeinspectionv18/genericModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,11 @@ class Model(GenericModel, Generic[CT, DT]):
Model[y, Optional[y]](a=1, b=2)

Model[y, Optional[y]](<warning descr="Expected type 'int', got 'str' instead">a='1'</warning>, <warning descr="Expected type 'Optional[int]', got 'str' instead">b='2'</warning>)


class Model(GenericModel, Generic[CT, DT, ET, FT, aaaaaaaaaa]):
a: Type[CT]
b: List[aaaaa]
c: Dict[ET, aaaaaaaa]

Model[aaaaaaaaaa, List[aaaaaa], Tuple[aaaaaaaaaa], Type[aaaaaaaaaaa]](a=int, b=[2], <warning descr="Expected type 'Dict[Tuple[Any], Any]', got 'Dict[str, int]' instead">c={'c': 3}</warning>)

0 comments on commit cc5bec9

Please sign in to comment.