diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index 2854b235..af97f06a 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -27,6 +27,7 @@ import com.jetbrains.python.statistics.modules import java.util.regex.Pattern const val BASE_MODEL_Q_NAME = "pydantic.main.BaseModel" +const val ROOT_MODEL_Q_NAME = "pydantic.root_model.RootModel" const val GENERIC_MODEL_Q_NAME = "pydantic.generics.GenericModel" const val DATA_CLASS_Q_NAME = "pydantic.dataclasses.dataclass" const val DATA_CLASS_SHORT_Q_NAME = "pydantic.dataclass" @@ -234,6 +235,10 @@ internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalCo return pyClass.isSubclass(BASE_MODEL_Q_NAME, context) } +internal fun isSubClassOfPydanticRootModel(pyClass: PyClass, context: TypeEvalContext): Boolean { + return pyClass.isSubclass(ROOT_MODEL_Q_NAME, context) +} + internal fun isSubClassOfBaseSetting(pyClass: PyClass, context: TypeEvalContext): Boolean { return pyClass.isSubclass(BASE_SETTINGS_Q_NAME, context) } diff --git a/src/com/koxudaxi/pydantic/PydanticInspection.kt b/src/com/koxudaxi/pydantic/PydanticInspection.kt index 22840cbd..752593f4 100644 --- a/src/com/koxudaxi/pydantic/PydanticInspection.kt +++ b/src/com/koxudaxi/pydantic/PydanticInspection.kt @@ -85,14 +85,13 @@ class PydanticInspection : PyInspection() { super.visitPyTypeDeclarationStatement(node) inspectAnnotatedField(node) + inspectCustomRootField(node) } override fun visitPyClass(node: PyClass) { super.visitPyClass(node) - if(pydanticCacheService.isV2) { - inspectCustomRootFieldV2(node) - } + inspectConfig(node) inspectDefaultFactory(node) } @@ -144,16 +143,6 @@ class PydanticInspection : PyInspection() { .mapNotNull { (it.expression as? PyCallExpression)?.getArgument(1, PyReferenceExpression::class.java) } .any { (it.reference.resolve() as? PyTargetExpression)?.findAssignedValue()?.name == "PydanticDeprecatedSince20" } - - private fun inspectCustomRootFieldV2(pyClass: PyClass) { - if (getRootField(pyClass) == null) return - if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return - registerProblem( - pyClass.nameNode?.psi, - "__root__ models are no longer supported in v2; a migration guide will be added in the near future", ProblemHighlightType.GENERIC_ERROR - ) - } - private fun inspectDefaultFactory(pyClass: PyClass) { if (!isPydanticModel(pyClass, true, myTypeEvalContext)) return val defaultFactories = (pyClass.classAttributes + getAncestorPydanticModels( @@ -310,18 +299,41 @@ class PydanticInspection : PyInspection() { } private fun inspectCustomRootField(node: PyAssignmentStatement) { - val pyClass = getPydanticModelByAttribute(node, false, myTypeEvalContext) ?: return - val field = node.leftHandSideExpression as? PyTargetExpression ?: return + inspectCustomRootField(field) + } + private fun inspectCustomRootField(node: PyTypeDeclarationStatement) { + val field = node.target as? PyTargetExpression ?: return + inspectCustomRootField(field) + } + private fun inspectCustomRootField(field: PyTargetExpression) { + val pyClass = getPydanticModelByAttribute(field.parent, false, myTypeEvalContext) ?: return + if (PyTypingTypeProvider.isClassVar(field, myTypeEvalContext)) return val fieldName = field.text ?: return + val isV2 = pydanticCacheService.isV2 + if (isV2 && fieldName == "__root__") { + registerProblem( + pyClass.nameNode?.psi, + "__root__ models are no longer supported in v2; a migration guide will be added in the near future", ProblemHighlightType.GENERIC_ERROR + ) + registerProblem(field, "To define root models, use `pydantic.RootModel` rather than a field called '__root__'", ProblemHighlightType.WARNING) + return + } if (fieldName.startsWith('_')) return - val rootModel = getRootField(pyClass)?.containingClass ?: return - if (!isPydanticModel(rootModel, false, myTypeEvalContext)) return - registerProblem( - node, - "__root__ cannot be mixed with other fields", ProblemHighlightType.WARNING - ) + val message = when { + isV2 -> { + if (fieldName == "root") return + if (!isSubClassOfPydanticRootModel(pyClass, myTypeEvalContext)) return + if (pyClass.findClassAttribute("root", true, myTypeEvalContext) == null) return + "Unexpected field with name ${fieldName}; only 'root' is allowed as a field of a `RootModel`" + } + else -> { + if (pyClass.findClassAttribute("__root__", true, myTypeEvalContext) == null) return + "__root__ cannot be mixed with other fields" + } + } + registerProblem(field, message, ProblemHighlightType.WARNING) } private fun validateDefaultAndDefaultFactory(default: PyExpression?, defaultFactory: PyExpression?): Boolean { @@ -393,9 +405,6 @@ class PydanticInspection : PyInspection() { } } - private fun getRootField(pyClass: PyClass): PyTargetExpression? { - return pyClass.findClassAttribute("__root__", true, myTypeEvalContext) - } } // override fun createOptionsPanel(): JComponent? { diff --git a/testData/inspection/customRoot.py b/testData/inspection/customRoot.py index b07195aa..44d1053a 100644 --- a/testData/inspection/customRoot.py +++ b/testData/inspection/customRoot.py @@ -13,7 +13,7 @@ class B(BaseModel): class C(BaseModel): __root__ = 'xyz' - b = 'xyz' + b = 'xyz' class D(BaseModel): @@ -33,3 +33,7 @@ class G(BaseModel): ATTRIBUTE_NAME: ClassVar[str] = "testing" __root__ = 'xyz' +class H(BaseModel): + __root__ = 'xyz' + b: str + diff --git a/testData/inspectionv2/customRoot.py b/testData/inspectionv2/customRoot.py index d69c4540..0a4d9612 100644 --- a/testData/inspectionv2/customRoot.py +++ b/testData/inspectionv2/customRoot.py @@ -1,8 +1,9 @@ -from pydantic import BaseModel +from typing import ClassVar +from pydantic import BaseModel, RootModel class A(BaseModel): - __root__ = 'xyz' + __root__ = 'xyz' class B(BaseModel): @@ -16,3 +17,41 @@ def d(): __root__ = 'xyz' g = 'xyz' + +class A(RootModel): + root = 'xyz' + + +class B(RootModel): + a = 'xyz' + + +class C(RootModel): + root = 'xyz' + b = 'xyz' + + +class D(RootModel): + root = 'xyz' + _c = 'xyz' + __c = 'xyz' + +class E: + root = 'xyz' + e = 'xyz' + +def f(): + root = 'xyz' + g = 'xyz' + +class G(RootModel): + ATTRIBUTE_NAME: ClassVar[str] = "testing" + root = 'xyz' + +class H(RootModel): + root = 'xyz' + b: str + +class I(RootModel): + b: str + diff --git a/testData/mock/pydanticv2/__init__.py b/testData/mock/pydanticv2/__init__.py index 05113ac6..d6409e19 100644 --- a/testData/mock/pydanticv2/__init__.py +++ b/testData/mock/pydanticv2/__init__.py @@ -22,5 +22,6 @@ from .config import ConfigDict from .version import VERSION from .deprecated import validator, root_validator +from .root_model import RootModel __version__ = VERSION diff --git a/testData/mock/pydanticv2/root_model.py b/testData/mock/pydanticv2/root_model.py new file mode 100644 index 00000000..c415b5c5 --- /dev/null +++ b/testData/mock/pydanticv2/root_model.py @@ -0,0 +1,10 @@ +import typing + +from .main import BaseModel + +RootModelRootType = typing.TypeVar('RootModelRootType') + + + +class RootModel(BaseModel, typing.Generic[RootModelRootType]): + root: RootModelRootType \ No newline at end of file