diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index 4a3ad1e5..1279c273 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -281,8 +281,8 @@ fun getPydanticVersion(project: Project, context: TypeEvalContext): KotlinVersio val version = getPsiElementByQualifiedName(VERSION_QUALIFIED_NAME, project, context) as? PyTargetExpression ?: return null val versionString = (version.findAssignedValue()?.lastChild?.firstChild?.nextSibling as? PyStringLiteralExpression)?.stringValue - ?: return null - return pydanticVersionCache.getOrPut(versionString, { + ?:(version.findAssignedValue() as? PyStringLiteralExpressionImpl)?.stringValue ?: return null + return pydanticVersionCache.getOrPut(versionString) { val versionList = versionString.split(VERSION_SPLIT_PATTERN).map { it.toIntOrNull() ?: 0 } val pydanticVersion = when { versionList.size == 1 -> KotlinVersion(versionList[0], 0) @@ -292,7 +292,7 @@ fun getPydanticVersion(project: Project, context: TypeEvalContext): KotlinVersio } ?: KotlinVersion(0, 0) pydanticVersionCache[versionString] = pydanticVersion pydanticVersion - }) + } } fun isValidField(field: PyTargetExpression, context: TypeEvalContext): Boolean { diff --git a/src/com/koxudaxi/pydantic/PydanticDynamicModel.kt b/src/com/koxudaxi/pydantic/PydanticDynamicModel.kt index 898d0ad9..0a17dc35 100644 --- a/src/com/koxudaxi/pydantic/PydanticDynamicModel.kt +++ b/src/com/koxudaxi/pydantic/PydanticDynamicModel.kt @@ -1,15 +1,39 @@ package com.koxudaxi.pydantic +import com.intellij.icons.AllIcons import com.intellij.lang.ASTNode +import com.jetbrains.python.codeInsight.PyCustomMember import com.jetbrains.python.psi.PyClass +import com.jetbrains.python.psi.PyElement +import com.jetbrains.python.psi.PyExpression import com.jetbrains.python.psi.impl.PyClassImpl +import com.jetbrains.python.psi.types.PyCallableParameter import com.jetbrains.python.psi.types.PyClassLikeType import com.jetbrains.python.psi.types.TypeEvalContext -class PydanticDynamicModel(astNode: ASTNode, val baseModel: PyClass) : PyClassImpl(astNode) { +class PydanticDynamicModel(astNode: ASTNode, val baseModel: PyClass, val attributes: Map) : PyClassImpl(astNode) { + val members: List = attributes.values.map { it.pyCustomMember } + private val memberResolver: Map = attributes.entries.filterNot { it.value.isInAncestor } .associate { it.key to it.value.pyElement } + + fun resolveMember(name: String): PyElement? = memberResolver[name] + override fun getSuperClassTypes(context: TypeEvalContext): MutableList { return baseModel.getType(context)?.let { mutableListOf(it) } ?: mutableListOf() } + data class Attribute(val pyCallableParameter: PyCallableParameter, val pyCustomMember: PyCustomMember, val pyElement: PyElement, val isInAncestor: Boolean) + + companion object { + fun createAttribute(name: String, parameter: PyCallableParameter, originalPyExpression: PyExpression, context: TypeEvalContext, isInAncestor: Boolean): Attribute { + val type = parameter.getType(context) + return Attribute(parameter, + PyCustomMember(name, null) { type } + .toPsiElement(originalPyExpression) + .withIcon(AllIcons.Nodes.Field), + originalPyExpression, + isInAncestor + ) + } + } } \ No newline at end of file diff --git a/src/com/koxudaxi/pydantic/PydanticDynamicModelClassType.kt b/src/com/koxudaxi/pydantic/PydanticDynamicModelClassType.kt index 8fd85018..d5191345 100644 --- a/src/com/koxudaxi/pydantic/PydanticDynamicModelClassType.kt +++ b/src/com/koxudaxi/pydantic/PydanticDynamicModelClassType.kt @@ -1,10 +1,15 @@ package com.koxudaxi.pydantic - -import com.jetbrains.python.codeInsight.PyCustomMember -import com.jetbrains.python.psi.* +import com.jetbrains.python.psi.types.PyCallableType +import com.jetbrains.python.psi.types.PyCallableTypeImpl import com.jetbrains.python.psi.types.PyClassTypeImpl -class PydanticDynamicModelClassType(source: PyClass, isDefinition: Boolean, val members: List, private val memberResolver: Map) : PyClassTypeImpl(source, isDefinition) { - fun resolveMember(name: String): PyElement? = memberResolver[name] +class PydanticDynamicModelClassType(private val source: PydanticDynamicModel, isDefinition: Boolean) : + PyClassTypeImpl(source, isDefinition) { + + val pyCallableType: PyCallableType + get() = PyCallableTypeImpl( + source.attributes.values.map { attribute -> attribute.pyCallableParameter }, + this.toInstance() + ) } \ No newline at end of file diff --git a/src/com/koxudaxi/pydantic/PydanticDynamicModelMemberProvider.kt b/src/com/koxudaxi/pydantic/PydanticDynamicModelMemberProvider.kt index 816e89d6..36a9772c 100644 --- a/src/com/koxudaxi/pydantic/PydanticDynamicModelMemberProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticDynamicModelMemberProvider.kt @@ -7,14 +7,19 @@ import com.jetbrains.python.psi.types.* class PydanticDynamicModelMemberProvider : PyClassMembersProviderBase() { override fun resolveMember(type: PyClassType, name: String, location: PsiElement?, resolveContext: PyResolveContext): PsiElement? { - if (type is PydanticDynamicModelClassType) { - type.resolveMember(name)?.let { return it } - } + val pyClass = type.pyClass + if (pyClass is PydanticDynamicModel && !type.isDefinition) + pyClass.resolveMember(name)?.let { return it } return super.resolveMember(type, name, location, resolveContext) } override fun getMembers(clazz: PyClassType?, location: PsiElement?, context: TypeEvalContext): MutableCollection { - if (clazz !is PydanticDynamicModelClassType) return mutableListOf() - return clazz.members.toMutableList() + if (clazz == null || clazz.isDefinition) return mutableListOf() + val pyClass = clazz.pyClass + return if (pyClass is PydanticDynamicModel) { + pyClass.members.toMutableList() + } else { + mutableListOf() + } } } diff --git a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt index 37e34798..4e278ecb 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt @@ -1,11 +1,8 @@ package com.koxudaxi.pydantic -import com.intellij.icons.AllIcons import com.intellij.openapi.util.Ref import com.intellij.psi.PsiElement import com.intellij.psi.util.PsiTreeUtil -import com.intellij.psi.util.QualifiedName -import com.jetbrains.python.codeInsight.PyCustomMember import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.* import com.jetbrains.python.psi.types.* @@ -13,20 +10,35 @@ import com.koxudaxi.pydantic.PydanticConfigService.Companion.getInstance import one.util.streamex.StreamEx class PydanticTypeProvider : PyTypeProviderBase() { - - override fun getReferenceExpressionType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? { + override fun getReferenceExpressionType( + referenceExpression: PyReferenceExpression, + context: TypeEvalContext + ): PyType? { return getPydanticTypeForCallee(referenceExpression, context) } - override fun getCallType(pyFunction: PyFunction, callSite: PyCallSiteExpression, context: TypeEvalContext): Ref? { + override fun getCallType( + pyFunction: PyFunction, + callSite: PyCallSiteExpression, + context: TypeEvalContext + ): Ref? { return when (pyFunction.qualifiedName) { - CON_LIST_Q_NAME -> Ref.create(createConListPyType(callSite, context) - ?: PyCollectionTypeImpl.createTypeByQName(callSite as PsiElement, LIST_Q_NAME, true)) + CON_LIST_Q_NAME -> Ref.create( + createConListPyType(callSite, context) + ?: PyCollectionTypeImpl.createTypeByQName(callSite as PsiElement, LIST_Q_NAME, true) + ) + CREATE_MODEL -> Ref.create( + getPydanticDynamicModelTypeForFunction(pyFunction, callSite.getArguments(null), context) + ) else -> null } } - override fun getReferenceType(referenceTarget: PsiElement, context: TypeEvalContext, anchor: PsiElement?): Ref? { + override fun getReferenceType( + referenceTarget: PsiElement, + context: TypeEvalContext, + anchor: PsiElement? + ): Ref? { if (referenceTarget is PyTargetExpression) { val pyClass = getPyClassByAttribute(referenceTarget.parent) ?: return null if (!isPydanticModel(pyClass, false, context)) return null @@ -53,8 +65,15 @@ class PydanticTypeProvider : PyTypeProviderBase() { } } - private fun getRefTypeFromFieldNameInPyClass(name: String, pyClass: PyClass, context: TypeEvalContext, ellipsis: PyNoneLiteralExpression, pydanticVersion: KotlinVersion?): Ref? { - return pyClass.findClassAttribute(name, false, context)?.let { return getRefTypeFromField(it, ellipsis, context, pyClass, pydanticVersion) } + private fun getRefTypeFromFieldNameInPyClass( + name: String, + pyClass: PyClass, + context: TypeEvalContext, + ellipsis: PyNoneLiteralExpression, + pydanticVersion: KotlinVersion? + ): Ref? { + return pyClass.findClassAttribute(name, false, context) + ?.let { return getRefTypeFromField(it, ellipsis, context, pyClass, pydanticVersion) } } private fun getRefTypeFromFieldName(name: String, context: TypeEvalContext, pyClass: PyClass): Ref? { @@ -62,54 +81,73 @@ class PydanticTypeProvider : PyTypeProviderBase() { val pydanticVersion = getPydanticVersion(pyClass.project, context) return getRefTypeFromFieldNameInPyClass(name, pyClass, context, ellipsis, pydanticVersion) - ?: pyClass.getAncestorClasses(context) - .filter { isPydanticModel(it, false, context) } - .mapNotNull { ancestor -> - getRefTypeFromFieldNameInPyClass(name, ancestor, context, ellipsis, pydanticVersion) - }.firstOrNull() + ?: pyClass.getAncestorClasses(context) + .filter { isPydanticModel(it, false, context) } + .mapNotNull { ancestor -> + getRefTypeFromFieldNameInPyClass(name, ancestor, context, ellipsis, pydanticVersion) + }.firstOrNull() } - private fun getRefTypeFromField(pyTargetExpression: PyTargetExpression, ellipsis: PyNoneLiteralExpression, - context: TypeEvalContext, pyClass: PyClass, - pydanticVersion: KotlinVersion?): Ref? { - return fieldToParameter(pyTargetExpression, ellipsis, context, pyClass, pydanticVersion, getConfig(pyClass, context, true)) - ?.let { parameter -> Ref.create(parameter.getType(context)) } + private fun getRefTypeFromField( + pyTargetExpression: PyTargetExpression, ellipsis: PyNoneLiteralExpression, + context: TypeEvalContext, pyClass: PyClass, + pydanticVersion: KotlinVersion? + ): Ref? { + return fieldToParameter( + pyTargetExpression, + ellipsis, + context, + pyClass, + pydanticVersion, + getConfig(pyClass, context, true) + ) + ?.let { parameter -> Ref.create(parameter.getType(context)) } } - private fun getPydanticTypeForCallee(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? { + private fun getPydanticTypeForCallee( + referenceExpression: PyReferenceExpression, + context: TypeEvalContext + ): PyType? { if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null val resolveResults = getResolveElements(referenceExpression, context) return PyUtil.filterTopPriorityResults(resolveResults) - .asSequence() - .map { - when { - it is PyClass -> getPydanticTypeForClass(it, context, true) - it is PyParameter && it.isSelf -> { - PsiTreeUtil.getParentOfType(it, PyFunction::class.java) - ?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD } - ?.let { it.containingClass?.let { getPydanticTypeForClass(it, context) } } - } - it is PyNamedParameter -> it.getArgumentType(context)?.let { pyType -> - getPyClassTypeByPyTypes(pyType).filter { pyClassType -> - pyClassType.isDefinition - }.map { filteredPyClassType -> getPydanticTypeForClass(filteredPyClassType.pyClass, context, true) }.firstOrNull() - } - it is PyTargetExpression -> (it as? PyTypedElement) - ?.let { pyTypedElement -> - context.getType(pyTypedElement) - ?.let { pyType -> getPyClassTypeByPyTypes(pyType) } - ?.filter { pyClassType -> pyClassType.isDefinition } - ?.map { filteredPyClassType -> - getPydanticTypeForClass(filteredPyClassType.pyClass, context, true) - }?.firstOrNull() - } ?: getPydanticDynamicModelTypeForTargetExpression(it, context, true) - else -> null + .asSequence() + .map { + when { + it is PyClass -> getPydanticTypeForClass(it, context, true) + it is PyParameter && it.isSelf -> { + PsiTreeUtil.getParentOfType(it, PyFunction::class.java) + ?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD } + ?.let { it.containingClass?.let { getPydanticTypeForClass(it, context) } } + } + it is PyNamedParameter -> it.getArgumentType(context)?.let { pyType -> + getPyClassTypeByPyTypes(pyType).filter { pyClassType -> + pyClassType.isDefinition + }.map { filteredPyClassType -> + getPydanticTypeForClass( + filteredPyClassType.pyClass, + context, + true + ) + }.firstOrNull() } + it is PyTargetExpression -> (it as? PyTypedElement) + ?.let { pyTypedElement -> + context.getType(pyTypedElement) + ?.let { pyType -> getPyClassTypeByPyTypes(pyType) } + ?.filter { pyClassType -> pyClassType.isDefinition } + ?.filterNot { pyClassType -> pyClassType is PydanticDynamicModelClassType } + ?.map { filteredPyClassType -> + getPydanticTypeForClass(filteredPyClassType.pyClass, context, true) + }?.firstOrNull() + } ?: getPydanticDynamicModelTypeForTargetExpression(it, context)?.pyCallableType + else -> null } - .firstOrNull() + } + .firstOrNull() } @@ -121,100 +159,150 @@ class PydanticTypeProvider : PyTypeProviderBase() { // TODO support PySubscriptionExpression val typeArgumentListType = context.getType(typeArgumentList) ?: return null val typeArgumentListReturnType = (typeArgumentListType as? PyCallableType)?.getReturnType(context) - ?: return null - return PyCollectionTypeImpl.createTypeByQName(pyCallExpression as PsiElement, LIST_Q_NAME, true, listOf(typeArgumentListReturnType)) + ?: return null + return PyCollectionTypeImpl.createTypeByQName( + pyCallExpression as PsiElement, + LIST_Q_NAME, + true, + listOf(typeArgumentListReturnType) + ) + } + + + private fun getPydanticDynamicModelPyClass( + pyTargetExpression: PyTargetExpression, + context: TypeEvalContext + ): PyClass? { + val pyCallableType = getPydanticDynamicModelTypeForTargetExpression(pyTargetExpression, context) + ?: return null + return pyCallableType.pyClass } - private fun getPydanticDynamicModelTypeForTargetExpression(pyTargetExpression: PyTargetExpression, context: TypeEvalContext, init: Boolean = false): PyCallableType? { + private fun getPydanticDynamicModelTypeForTargetExpression( + pyTargetExpression: PyTargetExpression, + context: TypeEvalContext + ): PydanticDynamicModelClassType? { val pyCallExpression = pyTargetExpression.findAssignedValue() as? PyCallExpression ?: return null - return getPydanticDynamicModelTypeForTargetExpression(pyCallExpression, context, init) + return getPydanticDynamicModelTypeForTargetExpression(pyCallExpression, context) } - private fun getPydanticDynamicModelTypeForTargetExpression(pyCallExpression: PyCallExpression, context: TypeEvalContext, init: Boolean = false): PyCallableType? { - val argumentList = pyCallExpression.argumentList ?: return null + private fun getPydanticDynamicModelTypeForTargetExpression( + pyCallExpression: PyCallExpression, + context: TypeEvalContext + ): PydanticDynamicModelClassType? { + val arguments = pyCallExpression.arguments.toList() + if (arguments.isEmpty()) return null val referenceExpression = (pyCallExpression.callee as? PyReferenceExpression) ?: return null val resolveResults = getResolveElements(referenceExpression, context) - val pyFunction = PyUtil.filterTopPriorityResults(resolveResults).asSequence().filterIsInstance().map { it.takeIf { pyFunction -> isPydanticCreateModel(pyFunction) } }.firstOrNull() - ?: return null + val pyFunction = PyUtil.filterTopPriorityResults(resolveResults) + .asSequence() + .filterIsInstance() + .map { it.takeIf { pyFunction -> isPydanticCreateModel(pyFunction) } }.firstOrNull() + ?: return null - return getPydanticDynamicModelTypeForFunction(pyFunction, argumentList, context, init) + return getPydanticDynamicModelTypeForFunction(pyFunction, arguments, context) } - private fun getPydanticDynamicModelTypeForFunction(pyFunction: PyFunction, pyArgumentList: PyArgumentList, context: TypeEvalContext, init: Boolean = false): PyCallableType? { + private fun getPydanticDynamicModelTypeForFunction( + pyFunction: PyFunction, + pyArguments: List, + context: TypeEvalContext + ): PydanticDynamicModelClassType? { val project = pyFunction.project - val typed = !init || getInstance(project).currentInitTyped - val collected = linkedMapOf>() + val typed = getInstance(project).currentInitTyped + val pydanticVersion = getPydanticVersion(pyFunction.project, context) + val collected = linkedMapOf() + val newVersion = pydanticVersion == null || pydanticVersion.isAtLeast(1, 5) + val modelNameParameterName = if (newVersion) "__model_name" else "model_name" + + val keywordArguments: Map = pyArguments + .filter { it is PyKeywordArgument || (it as? PyStarArgumentImpl)?.isKeyword == true } + .mapNotNull { it.name?.let { name -> name to it } } + .toMap() + val modelNameArgument = if (pyArguments.size == keywordArguments.size) { + // TODO: Support model name on StartArgument + (keywordArguments[modelNameParameterName] as? PyKeywordArgument)?.valueExpression + } else { + pyArguments.firstOrNull() + } ?: return null + val modelName = when (modelNameArgument) { + is PyReferenceExpression -> PyUtil.filterTopPriorityResults(getResolveElements(modelNameArgument, context)) + .filterIsInstance() + .map { it.findAssignedValue() } + .firstOrNull() + .let { PyPsiUtils.strValue(it) } + else -> PyPsiUtils.strValue(modelNameArgument) + } ?: return null // TODO get config // val config = getConfig(pyClass, context, true) - val baseClass = when (val baseArgument = pyArgumentList.getKeywordArgument("__base__")?.valueExpression) { - is PyReferenceExpression -> { - PyUtil.filterTopPriorityResults(getResolveElements(baseArgument, context)) - .filterIsInstance().firstOrNull { isPydanticModel(it, false, context) } - } - is PyClass -> baseArgument - else -> null - }?.let { baseClass -> - val baseClassCollected = linkedMapOf>() - (context.getType(baseClass) as? PyClassLikeType).let { baseClassType -> - for (currentType in StreamEx.of(baseClassType).append(baseClass.getAncestorTypes(context))) { - if (currentType !is PyClassType) continue - val current = currentType.pyClass - if (!isPydanticModel(current, false, context)) continue - getClassVariables(current, context) - .map { Pair(fieldToParameter(it, context, hashMapOf(), typed), it) } - .filter { (parameter, _) -> parameter?.name?.let { !collected.containsKey(it) } ?: false } - .forEach { (parameter, field) -> - parameter?.name?.let { name -> - val type = parameter.getType(context) - val member = PyCustomMember(name, null) { type } - .toPsiElement(field) - .withIcon(AllIcons.Nodes.Field) - baseClassCollected[name] = Triple(parameter, member, field) - } + // TODO: Support __base__ on StartArgument + val baseClass = + when (val baseArgument = (keywordArguments["__base__"] as? PyKeywordArgument)?.valueExpression) { + is PyReferenceExpression -> { + PyUtil.filterTopPriorityResults(getResolveElements(baseArgument, context)) + .map { + when (it) { + is PyTargetExpression -> getPydanticDynamicModelPyClass(it, context) + is PyClass -> it.takeIf { isPydanticModel(it, false, context) } + else -> null } + }.firstOrNull() } + is PyClass -> baseArgument.takeIf { isPydanticModel(baseArgument, false, context) } as? PyClass + else -> null } - baseClassCollected.entries.reversed().forEach { - collected[it.key] = it.value - } - baseClass - } ?: getPydanticBaseModel(project, context) ?: return null - var modelNameIsPositionalArgument = true - val modelNameArgument = pyArgumentList.getKeywordArgument("__model_name")?.valueExpression?.apply { - modelNameIsPositionalArgument = false - } ?: pyArgumentList.arguments.firstOrNull() ?: return null - val modelName = when (modelNameArgument) { - is PyReferenceExpression -> PyUtil.filterTopPriorityResults(getResolveElements(modelNameArgument, context)) - .filterIsInstance() - .map { it.findAssignedValue() } - .firstOrNull() - .let { PyPsiUtils.strValue(it) } - else -> PyPsiUtils.strValue(modelNameArgument) - } ?: return null - val langLevel = LanguageLevel.forElement(pyFunction) - val dynamicModelClassText = "class ${modelName}: pass" - val modelClass = PydanticDynamicModel(PyElementGenerator.getInstance(project).createFromText(langLevel, PyClass::class.java, dynamicModelClassText).node, baseClass) - val argumentWithoutModelName = when (modelNameIsPositionalArgument) { - true -> pyArgumentList.arguments.asSequence().drop(1) - else -> pyArgumentList.arguments.asSequence() - } - argumentWithoutModelName - .filter { it is PyKeywordArgument || (it as? PyStarArgumentImpl)?.isKeyword == true } - .filter { isValidFieldName(it.name) || it.name != "model_name" } - .forEach { - val parameter = fieldToParameter(it, context, hashMapOf(), typed)!! - parameter.name?.let { name -> - val type = parameter.getType(context) - val member = PyCustomMember(name, null) { type } - .toPsiElement(it) - .withIcon(AllIcons.Nodes.Field) - collected[name] = Triple(parameter, member, it) + ?.let { baseClass -> + val baseClassCollected = linkedMapOf() + (context.getType(baseClass) as? PyClassLikeType).let { baseClassType -> + for (currentType in StreamEx.of(baseClassType).append(baseClass.getAncestorTypes(context))) { + if (currentType !is PyClassType) continue + val current = currentType.pyClass + if (!isPydanticModel(current, false, context)) continue + if (current is PydanticDynamicModel) { + baseClassCollected.putAll(current.attributes) + continue + } + baseClassCollected.putAll(getClassVariables(current, context) + .mapNotNull { it to fieldToParameter(it, context, hashMapOf(), typed) } + .mapNotNull { (field, parameter) -> + parameter?.name?.let { name -> Triple(field, parameter, name) } + } + .filterNot { (_, _, name) -> collected.containsKey(name) } + .map { (field, parameter, name) -> + name to PydanticDynamicModel.createAttribute(name, + parameter, + field, + context, + true) + } + ) + + } } - } + collected.putAll(baseClassCollected.entries.reversed().map { it.key to it.value }) + baseClass + } ?: getPydanticBaseModel(project, context) ?: return null + + collected.putAll(keywordArguments + .filter { (name, _) -> isValidFieldName(name) && !name.startsWith('_') } + .filter { (name, _) -> (newVersion || name != "model_name") } + .map { (name, field) -> + val parameter = fieldToParameter(field, context, hashMapOf(), typed)!! + name to PydanticDynamicModel.createAttribute(name, parameter, field, context, false) + } + ) - val modelClassType = PydanticDynamicModelClassType(modelClass, false, collected.values.map { it.second }, collected.entries.map { it.key to it.value.third }.toMap()) - return PyCallableTypeImpl(collected.values.map { it.first }, modelClassType.toInstance()) + return PydanticDynamicModelClassType( + PydanticDynamicModel( + PyElementGenerator.getInstance(project) + .createFromText(LanguageLevel.forElement(pyFunction), + PyClass::class.java, + "class ${modelName}: pass").node, + baseClass, + collected + ), + true) } fun getPydanticTypeForClass(pyClass: PyClass, context: TypeEvalContext, init: Boolean = false): PyCallableType? { @@ -233,10 +321,10 @@ class PydanticTypeProvider : PyTypeProviderBase() { if (!isPydanticModel(current, false, context)) continue getClassVariables(current, context) - .filterNot { isUntouchedClass(it.findAssignedValue(), config, context) } - .mapNotNull { fieldToParameter(it, ellipsis, context, current, pydanticVersion, config, typed) } - .filter { parameter -> parameter.name?.let { !collected.containsKey(it) } ?: false } - .forEach { parameter -> collected[parameter.name!!] = parameter } + .filterNot { isUntouchedClass(it.findAssignedValue(), config, context) } + .mapNotNull { fieldToParameter(it, ellipsis, context, current, pydanticVersion, config, typed) } + .filter { parameter -> parameter.name?.let { !collected.containsKey(it) } ?: false } + .forEach { parameter -> collected[parameter.name!!] = parameter } } return PyCallableTypeImpl(collected.values.reversed(), clsType.toInstance()) } @@ -246,14 +334,16 @@ class PydanticTypeProvider : PyTypeProviderBase() { } - internal fun fieldToParameter(field: PyTargetExpression, - ellipsis: PyNoneLiteralExpression, - context: TypeEvalContext, - pyClass: PyClass, - pydanticVersion: KotlinVersion?, - config: HashMap, - typed: Boolean = true, - isDataclass: Boolean = false): PyCallableParameter? { + internal fun fieldToParameter( + field: PyTargetExpression, + ellipsis: PyNoneLiteralExpression, + context: TypeEvalContext, + pyClass: PyClass, + pydanticVersion: KotlinVersion?, + config: HashMap, + typed: Boolean = true, + isDataclass: Boolean = false + ): PyCallableParameter? { if (!isValidField(field, context)) return null if (!hasAnnotationValue(field) && !field.hasAssignedValue()) return null // skip fields that are invalid syntax @@ -276,16 +366,18 @@ class PydanticTypeProvider : PyTypeProviderBase() { } return PyCallableParameterImpl.nonPsi( - getFieldName(field, context, config, pydanticVersion), - typeForParameter, - defaultValue + getFieldName(field, context, config, pydanticVersion), + typeForParameter, + defaultValue ) } - internal fun fieldToParameter(field: PyExpression, - context: TypeEvalContext, - config: HashMap, - typed: Boolean = true): PyCallableParameter? { + internal fun fieldToParameter( + field: PyExpression, + context: TypeEvalContext, + config: HashMap, + typed: Boolean = true + ): PyCallableParameter? { var type: PyType? = null var defaultValue: PyExpression? = null when (val tupleValue = PsiTreeUtil.findChildOfType(field, PyTupleExpression::class.java)) { @@ -310,24 +402,28 @@ class PydanticTypeProvider : PyTypeProviderBase() { } return PyCallableParameterImpl.nonPsi( - field.name, + field.name, // getFieldName(field, context, config, pydanticVersion), - typeForParameter, - defaultValue + typeForParameter, + defaultValue ) } - private fun getTypeForParameter(field: PyTargetExpression, - context: TypeEvalContext): PyType? { + private fun getTypeForParameter( + field: PyTargetExpression, + context: TypeEvalContext + ): PyType? { return context.getType(field) } - private fun getDefaultValueForParameter(field: PyTargetExpression, - ellipsis: PyNoneLiteralExpression, - context: TypeEvalContext, - pydanticVersion: KotlinVersion?, - isDataclass: Boolean): PyExpression? { + private fun getDefaultValueForParameter( + field: PyTargetExpression, + ellipsis: PyNoneLiteralExpression, + context: TypeEvalContext, + pydanticVersion: KotlinVersion?, + isDataclass: Boolean + ): PyExpression? { val value = field.findAssignedValue() if (value is PyExpression) { @@ -335,7 +431,7 @@ class PydanticTypeProvider : PyTypeProviderBase() { } val annotationValue = field.annotation?.value ?: return null - fun parseAnnotation(pyExpression: PyExpression, context: TypeEvalContext) :PyExpression? { + fun parseAnnotation(pyExpression: PyExpression, context: TypeEvalContext): PyExpression? { val qualifiedName = getQualifiedName(pyExpression, context) ?: return null when (qualifiedName) { ANY_Q_NAME -> return ellipsis @@ -361,11 +457,13 @@ class PydanticTypeProvider : PyTypeProviderBase() { } - private fun getDefaultValueByAssignedValue(field: PyTargetExpression, - ellipsis: PyNoneLiteralExpression, - context: TypeEvalContext, - pydanticVersion: KotlinVersion?, - isDataclass: Boolean): PyExpression? { + private fun getDefaultValueByAssignedValue( + field: PyTargetExpression, + ellipsis: PyNoneLiteralExpression, + context: TypeEvalContext, + pydanticVersion: KotlinVersion?, + isDataclass: Boolean + ): PyExpression? { val assignedValue = field.findAssignedValue()!! if (assignedValue.text == "...") { @@ -378,32 +476,32 @@ class PydanticTypeProvider : PyTypeProviderBase() { val resolveResults = getResolveElements(referenceExpression, context) if (isDataclass) { PyUtil.filterTopPriorityResults(resolveResults) - .any { - isDataclassFieldByPsiElement(it) - } - .let { - return when { - it -> getDefaultValueForDataclass(assignedValue, context) - else -> assignedValue - } + .any { + isDataclassFieldByPsiElement(it) + } + .let { + return when { + it -> getDefaultValueForDataclass(assignedValue, context) + else -> assignedValue } + } } else { val versionZero = pydanticVersion?.major == 0 PyUtil.filterTopPriorityResults(resolveResults) - .any { - when { - versionZero -> isPydanticSchemaByPsiElement(it, context) - else -> isPydanticFieldByPsiElement(it) - } - + .any { + when { + versionZero -> isPydanticSchemaByPsiElement(it, context) + else -> isPydanticFieldByPsiElement(it) } - .let { - return when { - it -> getDefaultValue(assignedValue) - else -> assignedValue - } + + } + .let { + return when { + it -> getDefaultValue(assignedValue) + else -> assignedValue } + } } } @@ -415,7 +513,11 @@ class PydanticTypeProvider : PyTypeProviderBase() { return getDefaultFromField(assignedValue)?.takeIf { it.text != "..." } } - private fun getDefaultValueForDataclass(assignedValue: PyCallExpression, context: TypeEvalContext, argumentName: String): PyExpression? { + private fun getDefaultValueForDataclass( + assignedValue: PyCallExpression, + context: TypeEvalContext, + argumentName: String + ): PyExpression? { val defaultValue = assignedValue.getKeywordArgument(argumentName) return when { defaultValue == null -> null diff --git a/testData/mock/pydanticv1/main.py b/testData/mock/pydanticv1/main.py index 44f96b01..5ce60cad 100644 --- a/testData/mock/pydanticv1/main.py +++ b/testData/mock/pydanticv1/main.py @@ -1,3 +1,5 @@ +from typing import * + class BaseModel: class Config: pass @@ -8,6 +10,7 @@ class Config: def from_orm(cls, obj): pass + class Extra(str): allow = 'allow' ignore = 'ignore' @@ -49,12 +52,12 @@ class BaseConfig: json_encoders = {} def create_model( - model_name: str, - *, - __config__: Type[BaseConfig] = None, - __base__: Type[BaseModel] = None, - __module__: Optional[str] = None, - __validators__: Dict[str, classmethod] = None, - **field_definitions: Any, + model_name: str, + *, + __config__: Type[BaseConfig] = None, + __base__: Type[BaseModel] = None, + __module__: Optional[str] = None, + __validators__: Dict[str, classmethod] = None, + **field_definitions: Any, ) -> Type[BaseModel]: pass \ No newline at end of file diff --git a/testData/mock/pydanticv18/__init__.py b/testData/mock/pydanticv18/__init__.py new file mode 100644 index 00000000..2a22843e --- /dev/null +++ b/testData/mock/pydanticv18/__init__.py @@ -0,0 +1,6 @@ +from .main import BaseModel, BaseConfig, create_model +from .class_validators import validator, root_validator +from .fields import Field, Schema +from .env_settings import BaseSettings +from .networks import * +from .types import * diff --git a/testData/mock/pydanticv18/class_validators.py b/testData/mock/pydanticv18/class_validators.py new file mode 100644 index 00000000..872d7a74 --- /dev/null +++ b/testData/mock/pydanticv18/class_validators.py @@ -0,0 +1,6 @@ +def validator(*args: str, **kwargs: str): + pass + + +def root_validator(_func=None, *, pre=False): + pass diff --git a/testData/mock/pydanticv18/dataclasses.py b/testData/mock/pydanticv18/dataclasses.py new file mode 100644 index 00000000..092e1cd8 --- /dev/null +++ b/testData/mock/pydanticv18/dataclasses.py @@ -0,0 +1,2 @@ +def dataclass(): + pass \ No newline at end of file diff --git a/testData/mock/pydanticv18/env_settings.py b/testData/mock/pydanticv18/env_settings.py new file mode 100644 index 00000000..0ef1a356 --- /dev/null +++ b/testData/mock/pydanticv18/env_settings.py @@ -0,0 +1,5 @@ +from .main import BaseModel + + +class BaseSettings(BaseModel): + pass diff --git a/testData/mock/pydanticv18/fields.py b/testData/mock/pydanticv18/fields.py new file mode 100644 index 00000000..f02946fc --- /dev/null +++ b/testData/mock/pydanticv18/fields.py @@ -0,0 +1,25 @@ +def Field( + default, + *, + alias: str = None, + title: str = None, + description: str = None, + const: bool = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + multiple_of: float = None, + min_items: int = None, + max_items: int = None, + min_length: int = None, + max_length: int = None, + regex: str = None, + **extra, +): + + pass + + +def Schema(*args, **kwargs): + return Field(*args, **kwargs) diff --git a/testData/mock/pydanticv18/main.py b/testData/mock/pydanticv18/main.py new file mode 100644 index 00000000..e7f188af --- /dev/null +++ b/testData/mock/pydanticv18/main.py @@ -0,0 +1,65 @@ +from typing import * + +class BaseModel: + class Config: + pass + + ___slots__ = () + + @classmethod + def from_orm(cls, obj): + pass + +Model = TypeVar('Model', bound='BaseModel') + + +class Extra(str): + allow = 'allow' + ignore = 'ignore' + forbid = 'forbid' + +class GetterDict: + pass + + +class json: + def loads(self): + pass + + def dumps(self): + pass + + +class BaseConfig: + title = None + anystr_strip_whitespace = False + min_anystr_length = None + max_anystr_length = None + validate_all = False + extra = Extra.ignore + allow_mutation = True + allow_population_by_field_name = False + use_enum_values = False + fields = {} + validate_assignment = False + error_msg_templates = {} + arbitrary_types_allowed = False + orm_mode: bool = False + getter_dict = GetterDict + alias_generator = None + keep_untouched = () + schema_extra = {} + json_loads = json.loads + json_dumps = json.dumps + json_encoders = {} + +def create_model( + __model_name: str, + *, + __config__: Type[BaseConfig] = None, + __base__: Type['Model'] = None, + __module__: str = __name__, + __validators__: Dict[str, classmethod] = None, + **field_definitions: Any, +) -> Type['Model']: + pass \ No newline at end of file diff --git a/testData/mock/pydanticv18/networks.py b/testData/mock/pydanticv18/networks.py new file mode 100644 index 00000000..a059166e --- /dev/null +++ b/testData/mock/pydanticv18/networks.py @@ -0,0 +1,4 @@ +class AnyUrl(str): + pass +class HttpUrl(AnyUrl): + pass \ No newline at end of file diff --git a/testData/mock/pydanticv18/types.py b/testData/mock/pydanticv18/types.py new file mode 100644 index 00000000..89ea7122 --- /dev/null +++ b/testData/mock/pydanticv18/types.py @@ -0,0 +1,15 @@ + +def conlist(item_type, *, min_items = None, max_items = None) Type[List[T]]: + pass + +def constr( + *, + strip_whitespace: bool = False, + to_lower: bool = False, + strict: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, +) -> Type[str]: + pass \ No newline at end of file diff --git a/testData/mock/pydanticv18/version.py b/testData/mock/pydanticv18/version.py new file mode 100644 index 00000000..ed905d10 --- /dev/null +++ b/testData/mock/pydanticv18/version.py @@ -0,0 +1,8 @@ +class StrictVersion: + def __init__(self, version): + self.version = version + + +__all__ = ['VERSION'] + +VERSION = '1.8' diff --git a/testData/typeinspection/dynamicModel.py b/testData/typeinspection/dynamicModel.py new file mode 100644 index 00000000..66cf12ab --- /dev/null +++ b/testData/typeinspection/dynamicModel.py @@ -0,0 +1,57 @@ +from pydantic import BaseModel, create_model + +DynamicFoobarModel = create_model('DynamicFoobarModel', foo=(str, ...), bar=123) + +class StaticFoobarModel(BaseModel): + foo: str + bar: int = 123 + + +DynamicFoobarModel(foo='name', bar=123) +DynamicFoobarModel(foo=123, bar='name') + +BarModel = create_model( + model_name='BarModel', + apple='russet', + banana='yellow', + __base__=StaticFoobarModel, +) + +BarModel(foo='name', bar=123, apple='green', banana='red') +BarModel(foo=123, bar='name', apple=123, banana=456) + +model_name = 'DynamicBarModel' +DynamicBarModel = create_model( + model_name, + apple='russet', + banana='yellow', + __base__=DynamicFoobarModel, +) + +DynamicBarModel(foo='name', bar=123, apple='green', banana='red') +DynamicBarModel(foo=123, bar='name', apple=123, banana=456) + +DynamicModifiedBarModel = create_model('DynamicModifiedFoobarModel', foo=(int, ...), bar='abc', __base__=DynamicBarModel) + +DynamicModifiedBarModel(foo=456, bar='efg', apple='green', banana='red') +DynamicModifiedBarModel(foo='123', bar=456, apple=123, banana=456) + + +DynamicBrokenModel = create_model( + 'DynamicBrokenModel', + apple='russet', + banana='yellow', + __base__=BrokenBase, +) + +DynamicBrokenModel(foo='name', bar=123, apple='green', banana='red') + +class PythonClass: + pass + +DynamicBrokenModel = create_model( + 'DynamicBrokenModel', + __base__=PythonClass, +) + +DynamicBrokenModel() \ No newline at end of file diff --git a/testData/typeinspectionv18/dynamicModel.py b/testData/typeinspectionv18/dynamicModel.py new file mode 100644 index 00000000..5bf9265b --- /dev/null +++ b/testData/typeinspectionv18/dynamicModel.py @@ -0,0 +1,37 @@ +from pydantic import BaseModel, create_model + +DynamicFoobarModel = create_model('DynamicFoobarModel', foo=(str, ...), bar=123) + +class StaticFoobarModel(BaseModel): + foo: str + bar: int = 123 + + +DynamicFoobarModel(foo='name', bar=123) +DynamicFoobarModel(foo=123, bar='name') + +BarModel = create_model( + __model_name='BarModel', + apple='russet', + banana='yellow', + __base__=StaticFoobarModel, +) + +BarModel(foo='name', bar=123, apple='green', banana='red') +BarModel(foo=123, bar='name', apple=123, banana=456) + +model_name = 'DynamicBarModel' +DynamicBarModel = create_model( + model_name, + apple='russet', + banana='yellow', + __base__=DynamicFoobarModel, +) + +DynamicBarModel(foo='name', bar=123, apple='green', banana='red') +DynamicBarModel(foo=123, bar='name', apple=123, banana=456) + +DynamicModifiedBarModel = create_model('DynamicModifiedFoobarModel', foo=(int, ...), bar='abc', __base__=DynamicBarModel) + +DynamicModifiedBarModel(foo=456, bar='efg', apple='green', banana='red') +DynamicModifiedBarModel(foo='123', bar=456, apple=123, banana=456) diff --git a/testSrc/com/koxudaxi/pydantic/PydanticInspectionBase.kt b/testSrc/com/koxudaxi/pydantic/PydanticInspectionBase.kt index d452b674..5bc65ff6 100644 --- a/testSrc/com/koxudaxi/pydantic/PydanticInspectionBase.kt +++ b/testSrc/com/koxudaxi/pydantic/PydanticInspectionBase.kt @@ -4,7 +4,7 @@ import com.jetbrains.python.inspections.PyInspection import kotlin.reflect.KClass -abstract class PydanticInspectionBase : PydanticTestCase() { +abstract class PydanticInspectionBase(version: String = "v1") : PydanticTestCase(version) { @Suppress("UNCHECKED_CAST") protected open val inspectionClass: KClass = PydanticInspection::class as KClass diff --git a/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionTest.kt b/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionTest.kt index c28dbfb1..386863ca 100644 --- a/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionTest.kt +++ b/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionTest.kt @@ -109,5 +109,9 @@ open class PydanticTypeInspectionTest : PydanticInspectionBase() { fun testSkipMember() { doTest() } + + fun testDynamicModel() { + doTest() + } } diff --git a/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt b/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt new file mode 100644 index 00000000..eaca8cdd --- /dev/null +++ b/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt @@ -0,0 +1,15 @@ +package com.koxudaxi.pydantic + +import com.jetbrains.python.inspections.PyInspection +import com.jetbrains.python.inspections.PyTypeCheckerInspection +import kotlin.reflect.KClass + + +open class PydanticTypeInspectionV18Test : PydanticInspectionBase("v18") { + + @Suppress("UNCHECKED_CAST") + override val inspectionClass: KClass = PyTypeCheckerInspection::class as KClass + fun testDynamicModel() { + doTest() + } +} \ No newline at end of file