Skip to content

Commit

Permalink
Add PydanticDataclassTypeProvider.kt
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi committed Apr 13, 2021
1 parent 6451d04 commit 93fb36d
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 35 deletions.
1 change: 1 addition & 0 deletions resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@
</projectListeners>
<extensions defaultExtensionNs="Pythonid">
<typeProvider implementation="com.koxudaxi.pydantic.PydanticTypeProvider"/>
<typeProvider implementation="com.koxudaxi.pydantic.PydanticDataclassTypeProvider"/>
<inspectionExtension implementation="com.koxudaxi.pydantic.PydanticIgnoreInspection"/>
<pyDataclassParametersProvider implementation="com.koxudaxi.pydantic.PydanticParametersProvider"/>
<pyAnnotator implementation="com.koxudaxi.pydantic.PydanticAnnotator"/>
Expand Down
20 changes: 10 additions & 10 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,7 @@ internal fun isPydanticRegex(stringLiteralExpression: StringLiteralExpression):
val pyCallExpression = pyKeywordArgument.parent.parent as? PyCallExpression ?: return false
val referenceExpression = pyCallExpression.callee as? PyReferenceExpression ?: return false
val context = TypeEvalContext.userInitiated(referenceExpression.project, referenceExpression.containingFile)
val resolveResults = getResolveElements(referenceExpression, context)
return PyUtil.filterTopPriorityResults(resolveResults)
return getResolvedPsiElements(referenceExpression, context)
.filterIsInstance<PyFunction>()
.filter { pyFunction -> isPydanticField(pyFunction) || isConStr(pyFunction) }
.any()
Expand Down Expand Up @@ -238,6 +237,11 @@ fun getResolveElements(referenceExpression: PyReferenceExpression, context: Type

}


fun getResolvedPsiElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): List<PsiElement> {
return getResolveElements(referenceExpression, context).let { PyUtil.filterTopPriorityResults(it) }
}

fun getPyClassTypeByPyTypes(pyType: PyType): List<PyClassType> {
return when (pyType) {
is PyUnionType -> pyType.members.mapNotNull { it }.flatMap { getPyClassTypeByPyTypes(it) }
Expand Down Expand Up @@ -325,8 +329,7 @@ fun isValidFieldName(name: String?): Boolean {

fun getConfigValue(name: String, value: Any?, context: TypeEvalContext): Any? {
if (value is PyReferenceExpression) {
val resolveResults = getResolveElements(value, context)
val targetExpression = PyUtil.filterTopPriorityResults(resolveResults).firstOrNull() ?: return null
val targetExpression = getResolvedPsiElements(value, context).firstOrNull() ?: return null
val assignedValue = (targetExpression as? PyTargetExpression)?.findAssignedValue() ?: return null
return getConfigValue(name, assignedValue, context)
}
Expand Down Expand Up @@ -481,8 +484,7 @@ fun getPyTypeFromPyExpression(pyExpression: PyExpression, context: TypeEvalConte
return when (pyExpression) {
is PyType -> pyExpression
is PyReferenceExpression -> {
val resolveResults = getResolveElements(pyExpression, context)
PyUtil.filterTopPriorityResults(resolveResults)
getResolvedPsiElements(pyExpression, context)
.filterIsInstance<PyClass>()
.map { pyClass -> pyClass.getType(context)?.getReturnType(context) }
.firstOrNull()
Expand Down Expand Up @@ -522,9 +524,8 @@ internal fun getFieldFromPyExpression(
val callee = (psiElement as? PyCallExpression)
?.let { it.callee as? PyReferenceExpression }
?: return null
val results = getResolveElements(callee, context)
val versionZero = pydanticVersion?.major == 0
if (!PyUtil.filterTopPriorityResults(results).any {
if (!getResolvedPsiElements(callee, context).any {
when {
versionZero -> isPydanticSchemaByPsiElement(it, context)
else -> isPydanticFieldByPsiElement(it)
Expand Down Expand Up @@ -561,8 +562,7 @@ internal fun getQualifiedName(pyExpression: PyExpression, context: TypeEvalConte
return when (pyExpression) {
is PySubscriptionExpression -> pyExpression.qualifier?.let { getQualifiedName(it, context) }
is PyReferenceExpression -> {
val resolveResults = getResolveElements(pyExpression, context)
return PyUtil.filterTopPriorityResults(resolveResults)
return getResolvedPsiElements(pyExpression, context)
.filterIsInstance<PyQualifiedNameOwner>()
.mapNotNull { it.qualifiedName }
.firstOrNull()
Expand Down
83 changes: 83 additions & 0 deletions src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package com.koxudaxi.pydantic

import com.intellij.psi.PsiElement
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.*
import com.jetbrains.python.psi.types.*

/**
* `PydanticDataclassTypeProvider` gets actual pydantic dataclass types
*
* PyCharm 2021.1 detects decorated object type by type-hint of decorators.
* Unfortunately, pydantic.dataclasses.dataclass returns `Dataclass` type.
* `Dataclass` is not an actual model, which is Stub type for static-type checking.
* But, PyCharm can detect actual dataclass type by parsing the type definition.
* `PydanticDataclassTypeProvider` ignore `Dataclass` and get actual dataclass type using `PyDataclassTypeProvider`
*
*/
class PydanticDataclassTypeProvider : PyTypeProviderBase() {
private val pyDataclassTypeProvider = PyDataclassTypeProvider()

override fun getReferenceExpressionType(
referenceExpression: PyReferenceExpression,
context: TypeEvalContext,
): PyType? {
return getPydanticDataclass(referenceExpression, context)
}


private fun getDataclassCallableType(
referenceTarget: PsiElement,
context: TypeEvalContext,
callSite: PyCallExpression? = null,
): PyCallableType? {
return pyDataclassTypeProvider.getReferenceType(
referenceTarget,
context,
callSite ?: PyCallExpressionImpl(referenceTarget.node)
)?.get() as? PyCallableType
}

private fun getPydanticDataclassType(
referenceTarget: PsiElement,
context: TypeEvalContext,
pyReferenceExpression: PyReferenceExpression,
definition: Boolean,
): PyType? {
val callSite = PyCallExpressionNavigator.getPyCallExpressionByCallee(pyReferenceExpression)
val dataclassCallableType = getDataclassCallableType(referenceTarget, context, callSite) ?: return null
val dataclassType = (dataclassCallableType).getReturnType(context) as? PyClassType ?: return null
if (!isPydanticDataclass(dataclassType.pyClass)) return null

return when {
callSite is PyCallExpression && definition -> dataclassCallableType
definition -> (dataclassType.declarationElement as? PyTypedElement)?.let { context.getType(it) }
else -> dataclassType
}
}


private fun getPydanticDataclass(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? {
return getResolvedPsiElements(referenceExpression, context)
.asSequence()
.mapNotNull {
when {
it is PyClass && isPydanticDataclass(it) ->
getPydanticDataclassType(it, context, referenceExpression, true)
it is PyTargetExpression -> (it as? PyTypedElement)
?.let { pyTypedElement -> context.getType(pyTypedElement) }
?.let { pyType -> getPyClassTypeByPyTypes(pyType) }
?.filter { pyClassType -> isPydanticDataclass(pyClassType.pyClass) }
?.mapNotNull { pyClassType ->
getPydanticDataclassType(pyClassType.pyClass,
context,
referenceExpression,
pyClassType.isDefinition)
}
?.firstOrNull()
else -> null
}
}.firstOrNull()
}
}
46 changes: 21 additions & 25 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ import one.util.streamex.StreamEx
class PydanticTypeProvider : PyTypeProviderBase() {
override fun getReferenceExpressionType(
referenceExpression: PyReferenceExpression,
context: TypeEvalContext
context: TypeEvalContext,
): PyType? {
return getPydanticTypeForCallee(referenceExpression, context)
}

override fun getCallType(
pyFunction: PyFunction,
callSite: PyCallSiteExpression,
context: TypeEvalContext
context: TypeEvalContext,
): Ref<PyType>? {
return when (pyFunction.qualifiedName) {
CON_LIST_Q_NAME -> Ref.create(
Expand All @@ -37,7 +37,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
override fun getReferenceType(
referenceTarget: PsiElement,
context: TypeEvalContext,
anchor: PsiElement?
anchor: PsiElement?,
): Ref<PyType>? {
if (referenceTarget is PyTargetExpression) {
val pyClass = getPyClassByAttribute(referenceTarget.parent) ?: return null
Expand Down Expand Up @@ -70,7 +70,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
pyClass: PyClass,
context: TypeEvalContext,
ellipsis: PyNoneLiteralExpression,
pydanticVersion: KotlinVersion?
pydanticVersion: KotlinVersion?,
): Ref<PyType>? {
return pyClass.findClassAttribute(name, false, context)
?.let { return getRefTypeFromField(it, ellipsis, context, pyClass, pydanticVersion) }
Expand All @@ -91,7 +91,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
private fun getRefTypeFromField(
pyTargetExpression: PyTargetExpression, ellipsis: PyNoneLiteralExpression,
context: TypeEvalContext, pyClass: PyClass,
pydanticVersion: KotlinVersion?
pydanticVersion: KotlinVersion?,
): Ref<PyType>? {
return fieldToParameter(
pyTargetExpression,
Expand All @@ -107,13 +107,11 @@ class PydanticTypeProvider : PyTypeProviderBase() {

private fun getPydanticTypeForCallee(
referenceExpression: PyReferenceExpression,
context: TypeEvalContext
context: TypeEvalContext,
): PyType? {
if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null

val resolveResults = getResolveElements(referenceExpression, context)

return PyUtil.filterTopPriorityResults(resolveResults)
return getResolvedPsiElements(referenceExpression, context)
.asSequence()
.map {
when {
Expand Down Expand Up @@ -171,7 +169,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {

private fun getPydanticDynamicModelPyClass(
pyTargetExpression: PyTargetExpression,
context: TypeEvalContext
context: TypeEvalContext,
): PyClass? {
val pyCallableType = getPydanticDynamicModelTypeForTargetExpression(pyTargetExpression, context)
?: return null
Expand All @@ -180,21 +178,20 @@ class PydanticTypeProvider : PyTypeProviderBase() {

private fun getPydanticDynamicModelTypeForTargetExpression(
pyTargetExpression: PyTargetExpression,
context: TypeEvalContext
context: TypeEvalContext,
): PydanticDynamicModelClassType? {
val pyCallExpression = pyTargetExpression.findAssignedValue() as? PyCallExpression ?: return null
return getPydanticDynamicModelTypeForTargetExpression(pyCallExpression, context)
}

private fun getPydanticDynamicModelTypeForTargetExpression(
pyCallExpression: PyCallExpression,
context: TypeEvalContext
context: TypeEvalContext,
): PydanticDynamicModelClassType? {
val arguments = pyCallExpression.arguments.toList()
if (arguments.isEmpty()) return null
val referenceExpression = (pyCallExpression.callee as? PyReferenceExpression) ?: return null
val resolveResults = getResolveElements(referenceExpression, context)
val pyFunction = PyUtil.filterTopPriorityResults(resolveResults)
val pyFunction = getResolvedPsiElements(referenceExpression, context)
.asSequence()
.filterIsInstance<PyFunction>()
.map { it.takeIf { pyFunction -> isPydanticCreateModel(pyFunction) } }.firstOrNull()
Expand All @@ -207,7 +204,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
private fun getPydanticDynamicModelTypeForFunction(
pyFunction: PyFunction,
pyArguments: List<PyExpression>,
context: TypeEvalContext
context: TypeEvalContext,
): PydanticDynamicModelClassType? {
val project = pyFunction.project
val typed = getInstance(project).currentInitTyped
Expand All @@ -227,7 +224,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
pyArguments.firstOrNull()
} ?: return null
val modelName = when (modelNameArgument) {
is PyReferenceExpression -> PyUtil.filterTopPriorityResults(getResolveElements(modelNameArgument, context))
is PyReferenceExpression -> getResolvedPsiElements(modelNameArgument, context)
.filterIsInstance<PyTargetExpression>()
.map { it.findAssignedValue() }
.firstOrNull()
Expand All @@ -240,7 +237,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
val baseClass =
when (val baseArgument = (keywordArguments["__base__"] as? PyKeywordArgument)?.valueExpression) {
is PyReferenceExpression -> {
PyUtil.filterTopPriorityResults(getResolveElements(baseArgument, context))
getResolvedPsiElements(baseArgument, context)
.map {
when (it) {
is PyTargetExpression -> getPydanticDynamicModelPyClass(it, context)
Expand Down Expand Up @@ -342,7 +339,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
pydanticVersion: KotlinVersion?,
config: HashMap<String, Any?>,
typed: Boolean = true,
isDataclass: Boolean = false
isDataclass: Boolean = false,
): PyCallableParameter? {
if (!isValidField(field, context)) return null
if (!hasAnnotationValue(field) && !field.hasAssignedValue()) return null // skip fields that are invalid syntax
Expand Down Expand Up @@ -375,7 +372,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
private fun fieldToParameter(
field: PyExpression,
context: TypeEvalContext,
typed: Boolean = true
typed: Boolean = true,
): PyCallableParameter {
var type: PyType? = null
var defaultValue: PyExpression? = null
Expand Down Expand Up @@ -409,7 +406,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {

private fun getTypeForParameter(
field: PyTargetExpression,
context: TypeEvalContext
context: TypeEvalContext,
): PyType? {

return context.getType(field)
Expand All @@ -420,7 +417,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
ellipsis: PyNoneLiteralExpression,
context: TypeEvalContext,
pydanticVersion: KotlinVersion?,
isDataclass: Boolean
isDataclass: Boolean,
): PyExpression? {

val value = field.findAssignedValue()
Expand Down Expand Up @@ -460,7 +457,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
ellipsis: PyNoneLiteralExpression,
context: TypeEvalContext,
pydanticVersion: KotlinVersion?,
isDataclass: Boolean
isDataclass: Boolean,
): PyExpression? {
val assignedValue = field.findAssignedValue()!!

Expand Down Expand Up @@ -514,15 +511,14 @@ class PydanticTypeProvider : PyTypeProviderBase() {
private fun getDefaultValueForDataclass(
assignedValue: PyCallExpression,
context: TypeEvalContext,
argumentName: String
argumentName: String,
): PyExpression? {
val defaultValue = assignedValue.getKeywordArgument(argumentName)
return when {
defaultValue == null -> null
defaultValue.text == "..." -> null
defaultValue is PyReferenceExpression -> {
val resolveResults = getResolveElements(defaultValue, context)
PyUtil.filterTopPriorityResults(resolveResults).any { isDataclassMissingByPsiElement(it) }.let {
getResolvedPsiElements(defaultValue, context).any { isDataclassMissingByPsiElement(it) }.let {
return when {
it -> null
else -> defaultValue
Expand Down

0 comments on commit 93fb36d

Please sign in to comment.