From bd665a197034f5132e782e764a13258bc1391a1e Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Sun, 6 Aug 2023 04:15:00 +0900 Subject: [PATCH] treat model validator as an instance method when mode='after' (#779) * treat model validator as an instance method when mode='after' * Fix unittest * Fix unittest --- src/com/koxudaxi/pydantic/Pydantic.kt | 32 ++++++++++++++----- .../pydantic/PydanticIgnoreInspection.kt | 14 ++++---- .../koxudaxi/pydantic/PydanticInspection.kt | 11 +++---- .../koxudaxi/pydantic/PydanticTypeProvider.kt | 4 +-- .../PydanticTypedValidatorMethodHandler.kt | 3 +- .../ignoreinspection/validatorModeAfter.py | 9 ++++++ testData/inspectionv2/validatorField.py | 10 +++++- testData/inspectionv2/validatorSelf.py | 6 ++++ .../pydantic/PydanticIgnoreInspectionTest.kt | 4 +++ 9 files changed, 65 insertions(+), 28 deletions(-) create mode 100644 testData/ignoreinspection/validatorModeAfter.py diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index 1c3049f7..2854b235 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -140,6 +140,10 @@ val V2_VALIDATOR_QUALIFIED_NAMES = listOf( MODEL_VALIDATOR_SHORT_QUALIFIED_NAME ) +val MODEL_VALIDATOR_QUALIFIED_NAMES = listOf( + MODEL_VALIDATOR_QUALIFIED_NAME, + MODEL_VALIDATOR_SHORT_QUALIFIED_NAME +) val FIELD_VALIDATOR_Q_NAMES = listOf( VALIDATOR_Q_NAME, VALIDATOR_SHORT_Q_NAME, @@ -241,13 +245,9 @@ internal fun isSubClassOfCustomBaseModel(pyClass: PyClass, context: TypeEvalCont internal val PyClass.isBaseSettings: Boolean get() = qualifiedName == BASE_SETTINGS_Q_NAME -internal fun hasDecorator(pyDecoratable: PyDecoratable, refNames: List): Boolean { - return pyDecoratable.decoratorList?.decorators?.mapNotNull { it.callee as? PyReferenceExpression }?.any { - PyResolveUtil.resolveImportedElementQNameLocally(it).any { decoratorQualifiedName -> - refNames.any { refName -> decoratorQualifiedName == refName } - } - } ?: false -} +internal fun hasDecorator(pyDecoratable: PyDecoratable, refNames: List): Boolean = + pyDecoratable.decoratorList?.decorators?.any {it.include(refNames)} ?: false + internal val PyClass.isPydanticDataclass: Boolean get() = hasDecorator(this, DATA_CLASS_QUALIFIED_NAMES) @@ -269,11 +269,27 @@ internal fun isDataclassMissing(pyTargetExpression: PyTargetExpression): Boolean return pyTargetExpression.qualifiedName == DATACLASS_MISSING } -internal fun PyFunction.isValidatorMethod(pydanticVersion: KotlinVersion?): Boolean = +internal fun PyFunction.hasValidatorMethod(pydanticVersion: KotlinVersion?): Boolean = hasDecorator(this, if(pydanticVersion.isV2) V2_VALIDATOR_QUALIFIED_NAMES else VALIDATOR_QUALIFIED_NAMES) +internal fun PyDecorator.include(refNames: List): Boolean = (callee as? PyReferenceExpression)?.let { + PyResolveUtil.resolveImportedElementQNameLocally(it).any { decoratorQualifiedName -> + refNames.any { refName -> decoratorQualifiedName == refName } + } +} ?: false +internal val PyKeywordArgument.value: PyExpression? + get() = when (val value = valueExpression) { + is PyReferenceExpression -> (value.reference.resolve() as? PyTargetExpression)?.findAssignedValue() + else -> value + } +internal fun PyFunction.hasModelValidatorModeAfter(): Boolean = decoratorList?.decorators + ?.filter { it.include(MODEL_VALIDATOR_QUALIFIED_NAMES) } + ?.any { modelValidator -> + modelValidator.argumentList?.getKeywordArgument("mode") + ?.let { it.value as? PyStringLiteralExpression }?.stringValue == "after" + } ?: false internal val PyClass.isConfigClass: Boolean get() = name == "Config" diff --git a/src/com/koxudaxi/pydantic/PydanticIgnoreInspection.kt b/src/com/koxudaxi/pydantic/PydanticIgnoreInspection.kt index 9e475de9..f6f8a78a 100644 --- a/src/com/koxudaxi/pydantic/PydanticIgnoreInspection.kt +++ b/src/com/koxudaxi/pydantic/PydanticIgnoreInspection.kt @@ -2,9 +2,7 @@ package com.koxudaxi.pydantic import com.intellij.psi.PsiReference import com.jetbrains.python.inspections.PyInspectionExtension -import com.jetbrains.python.psi.PyElement -import com.jetbrains.python.psi.PyFunction -import com.jetbrains.python.psi.PyStringLiteralExpression +import com.jetbrains.python.psi.* import com.jetbrains.python.psi.types.TypeEvalContext class PydanticIgnoreInspection : PyInspectionExtension() { @@ -20,10 +18,10 @@ class PydanticIgnoreInspection : PyInspectionExtension() { } override fun ignoreMethodParameters(function: PyFunction, context: TypeEvalContext): Boolean { - return function.containingClass?.let { - isPydanticModel(it, - true, - context) && function.isValidatorMethod(PydanticCacheService.getVersion(function.project)) - } == true + val pyClass = function.containingClass ?: return false + if (!isPydanticModel(pyClass, true, context)) return false + if (!function.hasValidatorMethod(PydanticCacheService.getVersion(function.project))) return false + if (function.hasModelValidatorModeAfter()) return false + return true } } \ No newline at end of file diff --git a/src/com/koxudaxi/pydantic/PydanticInspection.kt b/src/com/koxudaxi/pydantic/PydanticInspection.kt index 3268e851..5e25f5da 100644 --- a/src/com/koxudaxi/pydantic/PydanticInspection.kt +++ b/src/com/koxudaxi/pydantic/PydanticInspection.kt @@ -36,7 +36,8 @@ class PydanticInspection : PyInspection() { super.visitPyFunction(node) if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return - if (!node.isValidatorMethod(pydanticCacheService.getOrPutVersion())) return + if (!node.hasValidatorMethod(pydanticCacheService.getOrPutVersion())) return + if (node.hasModelValidatorModeAfter()) return val paramList = node.parameterList val params = paramList.parameters val firstParam = params.firstOrNull() @@ -98,13 +99,9 @@ class PydanticInspection : PyInspection() { private fun inspectValidatorField(pyStringLiteralExpression: PyStringLiteralExpression) { if (pyStringLiteralExpression.reference?.resolve() != null) return val pyArgumentList = pyStringLiteralExpression.parent as? PyArgumentList ?: return - pyArgumentList.getKeywordArgument("check_fields")?.let { it -> - val checkFields = when (val value = it.valueExpression){ - is PyReferenceExpression -> (value.reference.resolve() as? PyTargetExpression)?.findAssignedValue() - else -> value - }?.let { PyEvaluator.evaluateAsBoolean(it) } + pyArgumentList.getKeywordArgument("check_fields")?.let { // ignore unresolved value - if (checkFields != true) return + if (PyEvaluator.evaluateAsBoolean(it.value)!= true) return } val stringValue = pyStringLiteralExpression.stringValue if (stringValue == "*") return diff --git a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt index e8ad6860..d7ed09d4 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt @@ -79,8 +79,8 @@ class PydanticTypeProvider : PyTypeProviderBase() { getRefTypeFromFieldName(name, context, pyClass) } - param.isSelf && func.isValidatorMethod(PydanticCacheService.getVersion(func.project) - ) -> { + param.isSelf && func.hasValidatorMethod(PydanticCacheService.getVersion(func.project)) && !func.hasModelValidatorModeAfter() + -> { val pyClass = func.containingClass ?: return null if (!isPydanticModel(pyClass, false, context)) return null context.getType(pyClass) diff --git a/src/com/koxudaxi/pydantic/PydanticTypedValidatorMethodHandler.kt b/src/com/koxudaxi/pydantic/PydanticTypedValidatorMethodHandler.kt index 8933a263..70c7606a 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypedValidatorMethodHandler.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypedValidatorMethodHandler.kt @@ -17,7 +17,6 @@ import com.jetbrains.python.PythonLanguage import com.jetbrains.python.codeInsight.PyCodeInsightSettings import com.jetbrains.python.psi.PyFunction import com.jetbrains.python.psi.impl.PyPsiUtils -import com.jetbrains.python.psi.types.TypeEvalContext import java.util.regex.Pattern class PydanticTypedValidatorMethodHandler : TypedHandlerDelegate() { @@ -53,7 +52,7 @@ class PydanticTypedValidatorMethodHandler : TypedHandlerDelegate() { val defNode = maybeDef.node if (defNode != null && defNode.elementType === PyTokenTypes.DEF_KEYWORD) { val pyFunction = token.parent as? PyFunction ?: return Result.CONTINUE - if (!pyFunction.isValidatorMethod(PydanticCacheService.getVersion(project))) return Result.CONTINUE + if (!pyFunction.hasValidatorMethod(PydanticCacheService.getVersion(project))) return Result.CONTINUE val settings = CodeStyle.getLanguageSettings(file, PythonLanguage.getInstance()) val textToType = StringBuilder() textToType.append("(") diff --git a/testData/ignoreinspection/validatorModeAfter.py b/testData/ignoreinspection/validatorModeAfter.py new file mode 100644 index 00000000..69e67c62 --- /dev/null +++ b/testData/ignoreinspection/validatorModeAfter.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel, model_validator + + +class A(BaseModel): + a: str + + @model_validator(mode='after') + def validate_a(self): + pass diff --git a/testData/inspectionv2/validatorField.py b/testData/inspectionv2/validatorField.py index ae25cc23..7731332a 100644 --- a/testData/inspectionv2/validatorField.py +++ b/testData/inspectionv2/validatorField.py @@ -55,7 +55,15 @@ def validate_c(**kwargs): def validate_c(**kwargs): pass - @model_validator('x') + @model_validator(mode='before') + def validate_model_before(cls): + pass + + @model_validator(mode='after') + def validate_model_after(self): + pass + + @model_validator() def validate_model(cls): pass diff --git a/testData/inspectionv2/validatorSelf.py b/testData/inspectionv2/validatorSelf.py index f190a49f..37462bda 100644 --- a/testData/inspectionv2/validatorSelf.py +++ b/testData/inspectionv2/validatorSelf.py @@ -36,7 +36,13 @@ def validate_e(): pass + @model_validator(mode='after') + def validate_model_after(self): + pass + @model_validator(mode='before') + def validate_model_before(): + pass def dummy(self): pass diff --git a/testSrc/com/koxudaxi/pydantic/PydanticIgnoreInspectionTest.kt b/testSrc/com/koxudaxi/pydantic/PydanticIgnoreInspectionTest.kt index 55550190..9cce24a7 100644 --- a/testSrc/com/koxudaxi/pydantic/PydanticIgnoreInspectionTest.kt +++ b/testSrc/com/koxudaxi/pydantic/PydanticIgnoreInspectionTest.kt @@ -72,4 +72,8 @@ open class PydanticIgnoreInspectionTest : PydanticTestCase() { fun testDecoratorField() { doIgnoreUnresolvedReference(false) } + + fun testValidatorModeAfter() { + doIgnoreMethodParametersTest(false) + } }