Skip to content

Commit

Permalink
improve RootModel inspection (#783)
Browse files Browse the repository at this point in the history
* improve RootModel inspection

* Fix unittest

* Add testcase
  • Loading branch information
koxudaxi committed Aug 7, 2023
1 parent 415145e commit 61455a7
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 27 deletions.
5 changes: 5 additions & 0 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import com.jetbrains.python.statistics.modules
import java.util.regex.Pattern

const val BASE_MODEL_Q_NAME = "pydantic.main.BaseModel"
const val ROOT_MODEL_Q_NAME = "pydantic.root_model.RootModel"
const val GENERIC_MODEL_Q_NAME = "pydantic.generics.GenericModel"
const val DATA_CLASS_Q_NAME = "pydantic.dataclasses.dataclass"
const val DATA_CLASS_SHORT_Q_NAME = "pydantic.dataclass"
Expand Down Expand Up @@ -234,6 +235,10 @@ internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalCo
return pyClass.isSubclass(BASE_MODEL_Q_NAME, context)
}

internal fun isSubClassOfPydanticRootModel(pyClass: PyClass, context: TypeEvalContext): Boolean {
return pyClass.isSubclass(ROOT_MODEL_Q_NAME, context)
}

internal fun isSubClassOfBaseSetting(pyClass: PyClass, context: TypeEvalContext): Boolean {
return pyClass.isSubclass(BASE_SETTINGS_Q_NAME, context)
}
Expand Down
57 changes: 33 additions & 24 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,13 @@ class PydanticInspection : PyInspection() {
super.visitPyTypeDeclarationStatement(node)

inspectAnnotatedField(node)
inspectCustomRootField(node)
}

override fun visitPyClass(node: PyClass) {
super.visitPyClass(node)

if(pydanticCacheService.isV2) {
inspectCustomRootFieldV2(node)
}

inspectConfig(node)
inspectDefaultFactory(node)
}
Expand Down Expand Up @@ -144,16 +143,6 @@ class PydanticInspection : PyInspection() {
.mapNotNull { (it.expression as? PyCallExpression)?.getArgument(1, PyReferenceExpression::class.java) }
.any { (it.reference.resolve() as? PyTargetExpression)?.findAssignedValue()?.name == "PydanticDeprecatedSince20" }


private fun inspectCustomRootFieldV2(pyClass: PyClass) {
if (getRootField(pyClass) == null) return
if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return
registerProblem(
pyClass.nameNode?.psi,
"__root__ models are no longer supported in v2; a migration guide will be added in the near future", ProblemHighlightType.GENERIC_ERROR
)
}

private fun inspectDefaultFactory(pyClass: PyClass) {
if (!isPydanticModel(pyClass, true, myTypeEvalContext)) return
val defaultFactories = (pyClass.classAttributes + getAncestorPydanticModels(
Expand Down Expand Up @@ -310,18 +299,41 @@ class PydanticInspection : PyInspection() {
}

private fun inspectCustomRootField(node: PyAssignmentStatement) {
val pyClass = getPydanticModelByAttribute(node, false, myTypeEvalContext) ?: return

val field = node.leftHandSideExpression as? PyTargetExpression ?: return
inspectCustomRootField(field)
}
private fun inspectCustomRootField(node: PyTypeDeclarationStatement) {
val field = node.target as? PyTargetExpression ?: return
inspectCustomRootField(field)
}
private fun inspectCustomRootField(field: PyTargetExpression) {
val pyClass = getPydanticModelByAttribute(field.parent, false, myTypeEvalContext) ?: return

if (PyTypingTypeProvider.isClassVar(field, myTypeEvalContext)) return
val fieldName = field.text ?: return
val isV2 = pydanticCacheService.isV2
if (isV2 && fieldName == "__root__") {
registerProblem(
pyClass.nameNode?.psi,
"__root__ models are no longer supported in v2; a migration guide will be added in the near future", ProblemHighlightType.GENERIC_ERROR
)
registerProblem(field, "To define root models, use `pydantic.RootModel` rather than a field called '__root__'", ProblemHighlightType.WARNING)
return
}
if (fieldName.startsWith('_')) return
val rootModel = getRootField(pyClass)?.containingClass ?: return
if (!isPydanticModel(rootModel, false, myTypeEvalContext)) return
registerProblem(
node,
"__root__ cannot be mixed with other fields", ProblemHighlightType.WARNING
)
val message = when {
isV2 -> {
if (fieldName == "root") return
if (!isSubClassOfPydanticRootModel(pyClass, myTypeEvalContext)) return
if (pyClass.findClassAttribute("root", true, myTypeEvalContext) == null) return
"Unexpected field with name ${fieldName}; only 'root' is allowed as a field of a `RootModel`"
}
else -> {
if (pyClass.findClassAttribute("__root__", true, myTypeEvalContext) == null) return
"__root__ cannot be mixed with other fields"
}
}
registerProblem(field, message, ProblemHighlightType.WARNING)
}

private fun validateDefaultAndDefaultFactory(default: PyExpression?, defaultFactory: PyExpression?): Boolean {
Expand Down Expand Up @@ -393,9 +405,6 @@ class PydanticInspection : PyInspection() {
}
}

private fun getRootField(pyClass: PyClass): PyTargetExpression? {
return pyClass.findClassAttribute("__root__", true, myTypeEvalContext)
}
}

// override fun createOptionsPanel(): JComponent? {
Expand Down
6 changes: 5 additions & 1 deletion testData/inspection/customRoot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class B(BaseModel):

class C(BaseModel):
__root__ = 'xyz'
<warning descr="__root__ cannot be mixed with other fields">b = 'xyz'</warning>
<warning descr="__root__ cannot be mixed with other fields">b</warning> = 'xyz'


class D(BaseModel):
Expand All @@ -33,3 +33,7 @@ class G(BaseModel):
ATTRIBUTE_NAME: ClassVar[str] = "testing"
__root__ = 'xyz'

class H(BaseModel):
__root__ = 'xyz'
<warning descr="__root__ cannot be mixed with other fields">b</warning>: str

43 changes: 41 additions & 2 deletions testData/inspectionv2/customRoot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pydantic import BaseModel
from typing import ClassVar
from pydantic import BaseModel, RootModel


class <error descr="__root__ models are no longer supported in v2; a migration guide will be added in the near future">A</error>(BaseModel):
__root__ = 'xyz'
<warning descr="To define root models, use `pydantic.RootModel` rather than a field called '__root__'">__root__</warning> = 'xyz'


class B(BaseModel):
Expand All @@ -16,3 +17,41 @@ def d():
__root__ = 'xyz'
g = 'xyz'


class A(RootModel):
root = 'xyz'


class B(RootModel):
<warning descr="Unexpected field with name a; only 'root' is allowed as a field of a `RootModel`">a</warning> = 'xyz'


class C(RootModel):
root = 'xyz'
<warning descr="Unexpected field with name b; only 'root' is allowed as a field of a `RootModel`">b</warning> = 'xyz'


class D(RootModel):
root = 'xyz'
_c = 'xyz'
__c = 'xyz'

class E:
root = 'xyz'
e = 'xyz'

def f():
root = 'xyz'
g = 'xyz'

class G(RootModel):
ATTRIBUTE_NAME: ClassVar[str] = "testing"
root = 'xyz'

class H(RootModel):
root = 'xyz'
<warning descr="Unexpected field with name b; only 'root' is allowed as a field of a `RootModel`">b</warning>: str

class I(RootModel):
<warning descr="Unexpected field with name b; only 'root' is allowed as a field of a `RootModel`">b</warning>: str

1 change: 1 addition & 0 deletions testData/mock/pydanticv2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
from .config import ConfigDict
from .version import VERSION
from .deprecated import validator, root_validator
from .root_model import RootModel

__version__ = VERSION
10 changes: 10 additions & 0 deletions testData/mock/pydanticv2/root_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import typing

from .main import BaseModel

RootModelRootType = typing.TypeVar('RootModelRootType')



class RootModel(BaseModel, typing.Generic[RootModelRootType]):
root: RootModelRootType

0 comments on commit 61455a7

Please sign in to comment.