diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt
index 2854b235..af97f06a 100644
--- a/src/com/koxudaxi/pydantic/Pydantic.kt
+++ b/src/com/koxudaxi/pydantic/Pydantic.kt
@@ -27,6 +27,7 @@ import com.jetbrains.python.statistics.modules
import java.util.regex.Pattern
const val BASE_MODEL_Q_NAME = "pydantic.main.BaseModel"
+const val ROOT_MODEL_Q_NAME = "pydantic.root_model.RootModel"
const val GENERIC_MODEL_Q_NAME = "pydantic.generics.GenericModel"
const val DATA_CLASS_Q_NAME = "pydantic.dataclasses.dataclass"
const val DATA_CLASS_SHORT_Q_NAME = "pydantic.dataclass"
@@ -234,6 +235,10 @@ internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalCo
return pyClass.isSubclass(BASE_MODEL_Q_NAME, context)
}
+internal fun isSubClassOfPydanticRootModel(pyClass: PyClass, context: TypeEvalContext): Boolean {
+ return pyClass.isSubclass(ROOT_MODEL_Q_NAME, context)
+}
+
internal fun isSubClassOfBaseSetting(pyClass: PyClass, context: TypeEvalContext): Boolean {
return pyClass.isSubclass(BASE_SETTINGS_Q_NAME, context)
}
diff --git a/src/com/koxudaxi/pydantic/PydanticInspection.kt b/src/com/koxudaxi/pydantic/PydanticInspection.kt
index 22840cbd..752593f4 100644
--- a/src/com/koxudaxi/pydantic/PydanticInspection.kt
+++ b/src/com/koxudaxi/pydantic/PydanticInspection.kt
@@ -85,14 +85,13 @@ class PydanticInspection : PyInspection() {
super.visitPyTypeDeclarationStatement(node)
inspectAnnotatedField(node)
+ inspectCustomRootField(node)
}
override fun visitPyClass(node: PyClass) {
super.visitPyClass(node)
- if(pydanticCacheService.isV2) {
- inspectCustomRootFieldV2(node)
- }
+
inspectConfig(node)
inspectDefaultFactory(node)
}
@@ -144,16 +143,6 @@ class PydanticInspection : PyInspection() {
.mapNotNull { (it.expression as? PyCallExpression)?.getArgument(1, PyReferenceExpression::class.java) }
.any { (it.reference.resolve() as? PyTargetExpression)?.findAssignedValue()?.name == "PydanticDeprecatedSince20" }
-
- private fun inspectCustomRootFieldV2(pyClass: PyClass) {
- if (getRootField(pyClass) == null) return
- if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return
- registerProblem(
- pyClass.nameNode?.psi,
- "__root__ models are no longer supported in v2; a migration guide will be added in the near future", ProblemHighlightType.GENERIC_ERROR
- )
- }
-
private fun inspectDefaultFactory(pyClass: PyClass) {
if (!isPydanticModel(pyClass, true, myTypeEvalContext)) return
val defaultFactories = (pyClass.classAttributes + getAncestorPydanticModels(
@@ -310,18 +299,41 @@ class PydanticInspection : PyInspection() {
}
private fun inspectCustomRootField(node: PyAssignmentStatement) {
- val pyClass = getPydanticModelByAttribute(node, false, myTypeEvalContext) ?: return
-
val field = node.leftHandSideExpression as? PyTargetExpression ?: return
+ inspectCustomRootField(field)
+ }
+ private fun inspectCustomRootField(node: PyTypeDeclarationStatement) {
+ val field = node.target as? PyTargetExpression ?: return
+ inspectCustomRootField(field)
+ }
+ private fun inspectCustomRootField(field: PyTargetExpression) {
+ val pyClass = getPydanticModelByAttribute(field.parent, false, myTypeEvalContext) ?: return
+
if (PyTypingTypeProvider.isClassVar(field, myTypeEvalContext)) return
val fieldName = field.text ?: return
+ val isV2 = pydanticCacheService.isV2
+ if (isV2 && fieldName == "__root__") {
+ registerProblem(
+ pyClass.nameNode?.psi,
+ "__root__ models are no longer supported in v2; a migration guide will be added in the near future", ProblemHighlightType.GENERIC_ERROR
+ )
+ registerProblem(field, "To define root models, use `pydantic.RootModel` rather than a field called '__root__'", ProblemHighlightType.WARNING)
+ return
+ }
if (fieldName.startsWith('_')) return
- val rootModel = getRootField(pyClass)?.containingClass ?: return
- if (!isPydanticModel(rootModel, false, myTypeEvalContext)) return
- registerProblem(
- node,
- "__root__ cannot be mixed with other fields", ProblemHighlightType.WARNING
- )
+ val message = when {
+ isV2 -> {
+ if (fieldName == "root") return
+ if (!isSubClassOfPydanticRootModel(pyClass, myTypeEvalContext)) return
+ if (pyClass.findClassAttribute("root", true, myTypeEvalContext) == null) return
+ "Unexpected field with name ${fieldName}; only 'root' is allowed as a field of a `RootModel`"
+ }
+ else -> {
+ if (pyClass.findClassAttribute("__root__", true, myTypeEvalContext) == null) return
+ "__root__ cannot be mixed with other fields"
+ }
+ }
+ registerProblem(field, message, ProblemHighlightType.WARNING)
}
private fun validateDefaultAndDefaultFactory(default: PyExpression?, defaultFactory: PyExpression?): Boolean {
@@ -393,9 +405,6 @@ class PydanticInspection : PyInspection() {
}
}
- private fun getRootField(pyClass: PyClass): PyTargetExpression? {
- return pyClass.findClassAttribute("__root__", true, myTypeEvalContext)
- }
}
// override fun createOptionsPanel(): JComponent? {
diff --git a/testData/inspection/customRoot.py b/testData/inspection/customRoot.py
index b07195aa..44d1053a 100644
--- a/testData/inspection/customRoot.py
+++ b/testData/inspection/customRoot.py
@@ -13,7 +13,7 @@ class B(BaseModel):
class C(BaseModel):
__root__ = 'xyz'
- b = 'xyz'
+ b = 'xyz'
class D(BaseModel):
@@ -33,3 +33,7 @@ class G(BaseModel):
ATTRIBUTE_NAME: ClassVar[str] = "testing"
__root__ = 'xyz'
+class H(BaseModel):
+ __root__ = 'xyz'
+ b: str
+
diff --git a/testData/inspectionv2/customRoot.py b/testData/inspectionv2/customRoot.py
index d69c4540..0a4d9612 100644
--- a/testData/inspectionv2/customRoot.py
+++ b/testData/inspectionv2/customRoot.py
@@ -1,8 +1,9 @@
-from pydantic import BaseModel
+from typing import ClassVar
+from pydantic import BaseModel, RootModel
class A(BaseModel):
- __root__ = 'xyz'
+ __root__ = 'xyz'
class B(BaseModel):
@@ -16,3 +17,41 @@ def d():
__root__ = 'xyz'
g = 'xyz'
+
+class A(RootModel):
+ root = 'xyz'
+
+
+class B(RootModel):
+ a = 'xyz'
+
+
+class C(RootModel):
+ root = 'xyz'
+ b = 'xyz'
+
+
+class D(RootModel):
+ root = 'xyz'
+ _c = 'xyz'
+ __c = 'xyz'
+
+class E:
+ root = 'xyz'
+ e = 'xyz'
+
+def f():
+ root = 'xyz'
+ g = 'xyz'
+
+class G(RootModel):
+ ATTRIBUTE_NAME: ClassVar[str] = "testing"
+ root = 'xyz'
+
+class H(RootModel):
+ root = 'xyz'
+ b: str
+
+class I(RootModel):
+ b: str
+
diff --git a/testData/mock/pydanticv2/__init__.py b/testData/mock/pydanticv2/__init__.py
index 05113ac6..d6409e19 100644
--- a/testData/mock/pydanticv2/__init__.py
+++ b/testData/mock/pydanticv2/__init__.py
@@ -22,5 +22,6 @@
from .config import ConfigDict
from .version import VERSION
from .deprecated import validator, root_validator
+from .root_model import RootModel
__version__ = VERSION
diff --git a/testData/mock/pydanticv2/root_model.py b/testData/mock/pydanticv2/root_model.py
new file mode 100644
index 00000000..c415b5c5
--- /dev/null
+++ b/testData/mock/pydanticv2/root_model.py
@@ -0,0 +1,10 @@
+import typing
+
+from .main import BaseModel
+
+RootModelRootType = typing.TypeVar('RootModelRootType')
+
+
+
+class RootModel(BaseModel, typing.Generic[RootModelRootType]):
+ root: RootModelRootType
\ No newline at end of file