Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pydantic v2 validators for 232 #737

Merged
merged 3 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import com.jetbrains.extensions.QNameResolveContext
import com.jetbrains.extensions.resolveToElement
import com.jetbrains.python.PyNames
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
import com.jetbrains.python.packaging.PyPackageManagers
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyStarArgumentImpl
import com.jetbrains.python.psi.impl.PyTargetExpressionImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.PyResolveUtil
Expand All @@ -33,6 +33,11 @@ const val VALIDATOR_Q_NAME = "pydantic.class_validators.validator"
const val VALIDATOR_SHORT_Q_NAME = "pydantic.validator"
const val ROOT_VALIDATOR_Q_NAME = "pydantic.class_validators.root_validator"
const val ROOT_VALIDATOR_SHORT_Q_NAME = "pydantic.root_validator"
const val FIELD_VALIDATOR_Q_NAME = "pydantic.field_validator"
const val FIELD_VALIDATOR_SHORT_Q_NAME = "pydantic.functional_validators.field_validator"
const val MODEL_VALIDATOR_Q_NAME = "pydantic.model_validator"
const val MODEL_VALIDATOR_SHORT_Q_NAME = "pydantic.functional_validators.model_validator"

const val SCHEMA_Q_NAME = "pydantic.schema.Schema"
const val FIELD_Q_NAME = "pydantic.fields.Field"
const val DATACLASS_FIELD_Q_NAME = "dataclasses.field"
Expand Down Expand Up @@ -85,6 +90,14 @@ val ROOT_VALIDATOR_QUALIFIED_NAME = QualifiedName.fromDottedString(ROOT_VALIDATO

val ROOT_VALIDATOR_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(ROOT_VALIDATOR_SHORT_Q_NAME)

val FIELD_VALIDATOR_QUALIFIED_NAME = QualifiedName.fromDottedString(FIELD_VALIDATOR_Q_NAME)

val FIELD_VALIDATOR_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(FIELD_VALIDATOR_SHORT_Q_NAME)

val MODEL_VALIDATOR_QUALIFIED_NAME = QualifiedName.fromDottedString(MODEL_VALIDATOR_Q_NAME)

val MODEL_VALIDATOR_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(MODEL_VALIDATOR_SHORT_Q_NAME)

val DATA_CLASS_QUALIFIED_NAME = QualifiedName.fromDottedString(DATA_CLASS_Q_NAME)

val DATA_CLASS_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(DATA_CLASS_SHORT_Q_NAME)
Expand All @@ -103,6 +116,17 @@ val VALIDATOR_QUALIFIED_NAMES = listOf(
ROOT_VALIDATOR_SHORT_QUALIFIED_NAME
)

val V2_VALIDATOR_QUALIFIED_NAMES = listOf(
VALIDATOR_QUALIFIED_NAME,
VALIDATOR_SHORT_QUALIFIED_NAME,
ROOT_VALIDATOR_QUALIFIED_NAME,
ROOT_VALIDATOR_SHORT_QUALIFIED_NAME,
FIELD_VALIDATOR_QUALIFIED_NAME,
FIELD_VALIDATOR_SHORT_QUALIFIED_NAME,
MODEL_VALIDATOR_QUALIFIED_NAME,
MODEL_VALIDATOR_SHORT_QUALIFIED_NAME
)

val VERSION_SPLIT_PATTERN: Pattern = Pattern.compile("[.a-zA-Z]")!!

val pydanticVersionCache: HashMap<String, KotlinVersion> = hashMapOf()
Expand Down Expand Up @@ -210,7 +234,9 @@ internal fun isDataclassMissing(pyTargetExpression: PyTargetExpression): Boolean
return pyTargetExpression.qualifiedName == DATACLASS_MISSING
}

internal val PyFunction.isValidatorMethod: Boolean get() = hasDecorator(this, VALIDATOR_QUALIFIED_NAMES)
internal fun PyFunction.isValidatorMethod(pydanticVersion: KotlinVersion?): Boolean =
hasDecorator(this, if(pydanticVersion.isV2) V2_VALIDATOR_QUALIFIED_NAMES else VALIDATOR_QUALIFIED_NAMES)



internal val PyClass.isConfigClass: Boolean get() = name == "Config"
Expand Down Expand Up @@ -406,7 +432,7 @@ fun getConfig(
pydanticVersion: KotlinVersion? = null,
): HashMap<String, Any?> {
val config = hashMapOf<String, Any?>()
val version = pydanticVersion ?: PydanticCacheService.getVersion(pyClass.project, context)
val version = pydanticVersion ?: PydanticCacheService.getVersion(pyClass.project)
getAncestorPydanticModels(pyClass, false, context)
.reversed()
.map { getConfig(it, context, false, version) }
Expand Down Expand Up @@ -661,4 +687,8 @@ fun PyCallableType.getPydanticModel(includeDataclass: Boolean, context: TypeEval


val KotlinVersion?.isV2: Boolean
get() = this?.isAtLeast(2, 0) == true
get() = this?.isAtLeast(2, 0) == true

val Sdk.pydanticVersion: String?
get() = PyPackageManagers.getInstance()
.forSdk(this).packages?.find { it.name == "pydantic" }?.version
36 changes: 19 additions & 17 deletions src/com/koxudaxi/pydantic/PydanticCacheService.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package com.koxudaxi.pydantic

import com.intellij.openapi.project.Project
import com.jetbrains.python.psi.PyStringLiteralExpression
import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.impl.PyStringLiteralExpressionImpl
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.sdk.pythonSdk

class PydanticCacheService(val project: Project) {
private var version: KotlinVersion? = null
Expand All @@ -17,34 +15,35 @@ class PydanticCacheService(val project: Project) {
.filterNot { it.startsWith("__") && it.endsWith("__") }
.toSet()
}
private fun getVersion(context: TypeEvalContext): KotlinVersion? {
val version = getPsiElementByQualifiedName(VERSION_QUALIFIED_NAME, project, context) as? PyTargetExpression
?: return null
val versionString =
(version.findAssignedValue()?.lastChild?.firstChild?.nextSibling as? PyStringLiteralExpression)?.stringValue
?: (version.findAssignedValue() as? PyStringLiteralExpressionImpl)?.stringValue ?: return null
return setVersion(versionString)
private fun getVersion(): KotlinVersion? {
val sdk = project.pythonSdk ?: return null
val versionString = sdk.pydanticVersion ?: return null
return getOrPutVersionFromVersionCache(versionString)
}

private fun setVersion(version: String): KotlinVersion {
private fun getOrPutVersionFromVersionCache(version: String): KotlinVersion? {
return pydanticVersionCache.getOrPut(version) {
val versionList = version.split(VERSION_SPLIT_PATTERN).map { it.toIntOrNull() ?: 0 }
val pydanticVersion = when {
versionList.size == 1 -> KotlinVersion(versionList[0], 0)
versionList.size == 2 -> KotlinVersion(versionList[0], versionList[1])
versionList.size >= 3 -> KotlinVersion(versionList[0], versionList[1], versionList[2])
else -> null
} ?: KotlinVersion(0, 0)
} ?: return null
pydanticVersionCache[version] = pydanticVersion
pydanticVersion
}
}

private fun getOrPutVersion(context: TypeEvalContext): KotlinVersion? {
internal fun getOrPutVersion(): KotlinVersion? {
if (version != null) return version
return getVersion(context).apply { version = this }
return getVersion().apply { version = this }
}

internal fun setVersion(version: String): KotlinVersion? {
return getOrPutVersionFromVersionCache(version).also { this.version = it }
}

private fun getOrAllowedConfigKwargs(context: TypeEvalContext): Set<String>? {
if (allowedConfigKwargs != null) return allowedConfigKwargs
return getAllowedConfigKwargs(context).apply { allowedConfigKwargs = this }
Expand All @@ -55,16 +54,19 @@ class PydanticCacheService(val project: Project) {
allowedConfigKwargs = null
}

internal fun isV2(typeEvalContext: TypeEvalContext) = this.getOrPutVersion(typeEvalContext).isV2
internal val isV2 get() = this.getOrPutVersion().isV2

companion object {
fun getVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
return getInstance(project).getOrPutVersion(context)
fun getVersion(project: Project): KotlinVersion? {
return getInstance(project).getOrPutVersion()
}

fun setVersion(project: Project, version: String): KotlinVersion? {
return getInstance(project).setVersion(version)
}
fun getOrPutVersionFromVersionCache(project: Project, version: String): KotlinVersion? {
return getInstance(project).getOrPutVersionFromVersionCache(version)
}

fun getAllowedConfigKwargs(project: Project, context: TypeEvalContext): Set<String>? {
return getInstance(project).getOrAllowedConfigKwargs(context)
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class PydanticCompletionContributor : CompletionContributor() {
genericTypeMap: Map<PyGenericType, PyType>?,
withEqual: Boolean
) {
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, typeEvalContext)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project)
getClassVariables(pyClass, typeEvalContext)
.filter { it.name != null }
.filterNot { isUntouchedClass(it.findAssignedValue(), config, typeEvalContext) }
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticIgnoreInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class PydanticIgnoreInspection : PyInspectionExtension() {
return function.containingClass?.let {
isPydanticModel(it,
true,
context) && function.isValidatorMethod
context) && function.isValidatorMethod(PydanticCacheService.getVersion(function.project))
} == true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class PydanticInsertArgumentsQuickFix(private val onlyRequired: Boolean) : Local
}.nullize()?.toMap() ?: return null
val elementGenerator = PyElementGenerator.getInstance(project)
val ellipsis = elementGenerator.createEllipsis()
val pydanticVersion = PydanticCacheService.getVersion(project, context)
val pydanticVersion = PydanticCacheService.getVersion(project)
val fields = (listOf(pyClass) + getAncestorPydanticModels(pyClass, true, context)).flatMap {
it.classAttributes.filter { attribute -> unFilledArguments.contains(attribute.name) }
.mapNotNull { attribute -> attribute.name?.let { name -> name to attribute }}
Expand Down
8 changes: 4 additions & 4 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class PydanticInspection : PyInspection() {
super.visitPyFunction(node)

if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return
if (!node.isValidatorMethod) return
if (!node.isValidatorMethod(pydanticCacheService.getOrPutVersion())) return
val paramList = node.parameterList
val params = paramList.parameters
val firstParam = params.firstOrNull()
Expand Down Expand Up @@ -87,7 +87,7 @@ class PydanticInspection : PyInspection() {
override fun visitPyClass(node: PyClass) {
super.visitPyClass(node)

if(pydanticCacheService.isV2(myTypeEvalContext)) {
if(pydanticCacheService.isV2) {
inspectCustomRootFieldV2(node)
}
inspectConfig(node)
Expand Down Expand Up @@ -217,7 +217,7 @@ class PydanticInspection : PyInspection() {
}

private fun inspectConfig(pyClass: PyClass) {
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, myTypeEvalContext)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project)
if (pydanticVersion?.isAtLeast(1, 8) != true) return
if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return
validateConfig(pyClass, myTypeEvalContext)?.forEach {
Expand All @@ -237,7 +237,7 @@ class PydanticInspection : PyInspection() {
val pyClass = pyClassType.pyClass
val attributeName = (node.leftHandSideExpression as? PyTargetExpressionImpl)?.name ?: return
val config = getConfig(pyClass, myTypeEvalContext, true)
val version = PydanticCacheService.getVersion(pyClass.project, myTypeEvalContext)
val version = PydanticCacheService.getVersion(pyClass.project)
if (config["allow_mutation"] == false || (version?.isAtLeast(1, 8) == true && config["frozen"] == true)) {
registerProblem(
node,
Expand Down
10 changes: 4 additions & 6 deletions src/com/koxudaxi/pydantic/PydanticPackageManagerListener.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,18 @@ import com.intellij.openapi.project.ProjectManager
import com.intellij.openapi.projectRoots.Sdk
import com.intellij.openapi.util.Disposer
import com.jetbrains.python.packaging.PyPackageManager
import com.jetbrains.python.packaging.PyPackageManagers
import com.jetbrains.python.sdk.PythonSdkUtil
import com.jetbrains.python.statistics.sdks

class PydanticPackageManagerListener : PyPackageManager.Listener {
private fun updateVersion(sdk: Sdk) {
val version = PyPackageManagers.getInstance()
.forSdk(sdk).packages?.find { it.name == "pydantic" }?.version
val version = sdk.pydanticVersion
ProjectManager.getInstance().openProjects
.filter { it.sdks.contains(sdk) }
.forEach {
when (version) {
is String -> PydanticCacheService.setVersion(it, version)
else -> PydanticCacheService.clear(it)
PydanticCacheService.clear(it)
if (version is String) {
PydanticCacheService.getOrPutVersionFromVersionCache(it, version)
}
}
}
Expand Down
9 changes: 5 additions & 4 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class PydanticTypeProvider : PyTypeProviderBase() {
getRefTypeFromFieldName(name, context, pyClass)
}

param.isSelf && func.isValidatorMethod -> {
param.isSelf && func.isValidatorMethod(PydanticCacheService.getVersion(func.project)
) -> {
val pyClass = func.containingClass ?: return null
if (!isPydanticModel(pyClass, false, context)) return null
context.getType(pyClass)
Expand All @@ -103,7 +104,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
private fun getRefTypeFromFieldName(name: String, context: TypeEvalContext, pyClass: PyClass): PyType? {
val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis()

val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, context)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project)
return getRefTypeFromFieldNameInPyClass(name, pyClass, context, ellipsis, pydanticVersion)
?: getAncestorPydanticModels(pyClass, false, context).firstNotNullOfOrNull { ancestor ->
getRefTypeFromFieldNameInPyClass(name, ancestor, context, ellipsis, pydanticVersion)
Expand Down Expand Up @@ -298,7 +299,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
): PydanticDynamicModelClassType? {
val project = pyFunction.project
val typed = getInstance(project).currentInitTyped
val pydanticVersion = PydanticCacheService.getVersion(pyFunction.project, context)
val pydanticVersion = PydanticCacheService.getVersion(pyFunction.project)
val collected = linkedMapOf<String, PydanticDynamicModel.Attribute>()
val newVersion = pydanticVersion == null || pydanticVersion.isAtLeast(1, 5)
val modelNameParameterName = if (newVersion) "__model_name" else "model_name"
Expand Down Expand Up @@ -494,7 +495,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}
}
val genericTypeMap = getGenericTypeMap(pyClass, context, pyCallExpression)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, context)
val pydanticVersion = PydanticCacheService.getVersion(pyClass.project)
val config = getConfig(pyClass, context, true)
for (currentType in StreamEx.of(clsType).append(pyClass.getAncestorTypes(context))) {
if (currentType !is PyClassType) continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import com.jetbrains.python.PythonLanguage
import com.jetbrains.python.codeInsight.PyCodeInsightSettings
import com.jetbrains.python.psi.PyFunction
import com.jetbrains.python.psi.impl.PyPsiUtils
import com.jetbrains.python.psi.types.TypeEvalContext
import java.util.regex.Pattern

class PydanticTypedValidatorMethodHandler : TypedHandlerDelegate() {
Expand Down Expand Up @@ -52,7 +53,7 @@ class PydanticTypedValidatorMethodHandler : TypedHandlerDelegate() {
val defNode = maybeDef.node
if (defNode != null && defNode.elementType === PyTokenTypes.DEF_KEYWORD) {
val pyFunction = token.parent as? PyFunction ?: return Result.CONTINUE
if (!pyFunction.isValidatorMethod) return Result.CONTINUE
if (!pyFunction.isValidatorMethod(PydanticCacheService.getVersion(project))) return Result.CONTINUE
val settings = CodeStyle.getLanguageSettings(file, PythonLanguage.getInstance())
val textToType = StringBuilder()
textToType.append("(")
Expand Down
45 changes: 45 additions & 0 deletions testData/inspectionv2/validatorSelf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from pydantic import BaseModel, field_validator, model_validator

def check(func):
def inner():
func()
return inner

class A(BaseModel):
a: str
b: str
c: str
d: str
e: str

@field_validator('a')
def validate_a(<weak_warning descr="Usually first parameter of such methods is named 'cls'">self</weak_warning>):
pass

@field_validator('b')
def validate_b(<weak_warning descr="Usually first parameter of such methods is named 'cls'">fles</weak_warning>):
pass

@field_validator('c')
def validate_b(*args):
pass

@field_validator('d')
def validate_c(**kwargs):
pass

@field_validator('e')
def validate_e<error descr="Method must have a first parameter, usually called 'cls'">()</error>:
pass

@model_validator()
def validate_model<error descr="Method must have a first parameter, usually called 'cls'">()</error>:
pass


def dummy(self):
pass

@check
def task(self):
pass
2 changes: 1 addition & 1 deletion testData/mock/pydanticv2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ValidationInfo,
ValidatorFunctionWrapHandler,
)

from .field_validator import field_validator, model_validator
from . import dataclasses
from .analyzed_type import AnalyzedType
from .config import BaseConfig, ConfigDict, Extra
Expand Down
15 changes: 15 additions & 0 deletions testData/mock/pydanticv2/functional_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

def field_validator(
__field: str,
*fields: str,
mode: FieldValidatorModes = 'after',
check_fields: bool | None = None,
) -> Callable[[Any], Any]:
pass


def model_validator(
*,
mode: Literal['wrap', 'before', 'after'],
) -> Any:
pass
3 changes: 3 additions & 0 deletions testSrc/com/koxudaxi/pydantic/PydanticInspectionV2Test.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ open class PydanticInspectionV2Test : PydanticInspectionBase(version = "v2") {
pydanticConfigService.mypyWarnUntypedFields = false
doTest()
}
fun testValidatorSelf() {
doTest()
}
}
Loading