Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GenericModel #289

Merged
merged 6 commits into from
May 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<h2>version 0.3.1</h2>
<p>Features</p>
<ul>
<li>Support GenericModel [#289]</li>
<li>Support frozen on config [#288]</li>
<li>Fix format [#287]</li>
<li>Improve handling pydantic version [#286]</li>
Expand Down
48 changes: 32 additions & 16 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.jetbrains.python.statistics.modules
import java.util.regex.Pattern

const val BASE_MODEL_Q_NAME = "pydantic.main.BaseModel"
const val GENERIC_MODEL_Q_NAME = "pydantic.generics.GenericModel"
const val DATA_CLASS_Q_NAME = "pydantic.dataclasses.dataclass"
const val DATA_CLASS_SHORT_Q_NAME = "pydantic.dataclass"
const val VALIDATOR_Q_NAME = "pydantic.class_validators.validator"
Expand Down Expand Up @@ -52,7 +53,9 @@ const val OPTIONAL_Q_NAME = "typing.Optional"
const val UNION_Q_NAME = "typing.Union"
const val ANNOTATED_Q_NAME = "typing.Annotated"
const val CLASSVAR_Q_NAME = "typing.ClassVar"

const val GENERIC_Q_NAME = "typing.Generic"
const val TYPE_Q_NAME = "typing.Type"
const val TUPLE_Q_NAME = "typing.Tuple"

val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME)

Expand Down Expand Up @@ -123,24 +126,38 @@ fun getPyClassByPyCallExpression(
is PyClassType -> type
else -> (callee.reference?.resolve() as? PyTypedElement)?.let { context.getType(it) } ?: return null
}
return getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, includeDataclass) }?.pyClass
return getPyClassTypeByPyTypes(pyType).firstOrNull {
isPydanticModel(it.pyClass,
includeDataclass,
context)
}?.pyClass
}

fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? {
val pyCallExpression = PsiTreeUtil.getParentOfType(pyKeywordArgument, PyCallExpression::class.java) ?: return null
return getPyClassByPyCallExpression(pyCallExpression, true, context)
}

fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext? = null): Boolean {
fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): Boolean {
return (isSubClassOfPydanticBaseModel(pyClass,
context) || (includeDataclass && isPydanticDataclass(pyClass))) && !isPydanticBaseModel(pyClass)
context) || isSubClassOfPydanticGenericModel(pyClass, context) || (includeDataclass && isPydanticDataclass(
pyClass))) && !isPydanticBaseModel(pyClass) && !isPydanticGenericModel(
pyClass)
}

fun isPydanticBaseModel(pyClass: PyClass): Boolean {
return pyClass.qualifiedName == BASE_MODEL_Q_NAME
}

internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalContext?): Boolean {
fun isPydanticGenericModel(pyClass: PyClass): Boolean {
return pyClass.qualifiedName == GENERIC_MODEL_Q_NAME
}

internal fun isSubClassOfPydanticGenericModel(pyClass: PyClass, context: TypeEvalContext): Boolean {
return pyClass.isSubclass(GENERIC_MODEL_Q_NAME, context)
}

internal fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalContext): Boolean {
return pyClass.isSubclass(BASE_MODEL_Q_NAME, context)
}

Expand Down Expand Up @@ -363,7 +380,7 @@ fun getConfig(
val version = pydanticVersion ?: PydanticVersionService.getVersion(pyClass.project, context)
pyClass.getAncestorClasses(context)
.reversed()
.filter { isPydanticModel(it, false) }
.filter { isPydanticModel(it, false, context) }
.map { getConfig(it, context, false, version) }
.forEach {
it.entries.forEach { entry ->
Expand Down Expand Up @@ -484,7 +501,8 @@ fun getPydanticUnFilledArguments(
context: TypeEvalContext,
): List<PyCallableParameter> {
val pydanticClass = pyClass ?: getPydanticPyClass(pyCallExpression, context) ?: return emptyList()
val pydanticType = pydanticTypeProvider.getPydanticTypeForClass(pydanticClass, context, true) ?: return emptyList()
val pydanticType = pydanticTypeProvider.getPydanticTypeForClass(pydanticClass, context, true, pyCallExpression)
?: return emptyList()
val currentArguments =
pyCallExpression.arguments.filter { it is PyKeywordArgument || (it as? PyStarArgumentImpl)?.isKeyword == true }
.mapNotNull { it.name }.toSet()
Expand All @@ -500,6 +518,7 @@ fun getPyTypeFromPyExpression(pyExpression: PyExpression, context: TypeEvalConte
is PyType -> pyExpression
is PyReferenceExpression -> {
getResolvedPsiElements(pyExpression, context)
.asSequence()
.filterIsInstance<PyClass>()
.map { pyClass -> pyClass.getType(context)?.getReturnType(context) }
.firstOrNull()
Expand Down Expand Up @@ -576,14 +595,11 @@ internal fun getDefaultFactoryFromField(field: PyCallExpression): PyExpression?
internal fun getQualifiedName(pyExpression: PyExpression, context: TypeEvalContext): String? {
return when (pyExpression) {
is PySubscriptionExpression -> pyExpression.qualifier?.let { getQualifiedName(it, context) }
is PyReferenceExpression -> {
return getResolvedPsiElements(pyExpression, context)
.filterIsInstance<PyQualifiedNameOwner>()
.mapNotNull { it.qualifiedName }
.firstOrNull()
}
else -> {
return null
}
is PyReferenceExpression -> return getResolvedPsiElements(pyExpression, context)
.asSequence()
.filterIsInstance<PyQualifiedNameOwner>()
.mapNotNull { it.qualifiedName }
.firstOrNull()
else -> return null
}
}
91 changes: 62 additions & 29 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import com.jetbrains.python.documentation.PythonDocumentationProvider.getTypeHin
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyEvaluator
import com.jetbrains.python.psi.types.PyClassType
import com.jetbrains.python.psi.types.PyGenericType
import com.jetbrains.python.psi.types.PyType
import com.jetbrains.python.psi.types.TypeEvalContext
import javax.swing.Icon

Expand Down Expand Up @@ -69,6 +71,7 @@ class PydanticCompletionContributor : CompletionContributor() {
pydanticVersion: KotlinVersion?,
config: HashMap<String, Any?>,
isDataclass: Boolean,
genericTypeMap: Map<PyGenericType, PyType>?,
): String {

val parameter = typeProvider.fieldToParameter(pyTargetExpression,
Expand All @@ -77,6 +80,7 @@ class PydanticCompletionContributor : CompletionContributor() {
pyClass,
pydanticVersion,
config,
genericTypeMap,
isDataclass = isDataclass)
val defaultValue = parameter?.defaultValue?.let {
when {
Expand All @@ -102,6 +106,7 @@ class PydanticCompletionContributor : CompletionContributor() {
config: HashMap<String, Any?>,
excludes: HashSet<String>?,
isDataclass: Boolean,
genericTypeMap: Map<PyGenericType, PyType>?,
) {
val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, typeEvalContext)
getClassVariables(pyClass, typeEvalContext)
Expand All @@ -121,7 +126,8 @@ class PydanticCompletionContributor : CompletionContributor() {
ellipsis,
pydanticVersion,
config,
isDataclass))
isDataclass,
genericTypeMap))
.withIcon(icon), 1)
results[elementName] = PrioritizedLookupElement.withPriority(element, 100.0)
}
Expand All @@ -133,17 +139,34 @@ class PydanticCompletionContributor : CompletionContributor() {
pyClass: PyClass, typeEvalContext: TypeEvalContext,
ellipsis: PyNoneLiteralExpression,
config: HashMap<String, Any?>,
genericTypeMap: Map<PyGenericType, PyType>?,
excludes: HashSet<String>? = null,
isDataclass: Boolean,
) {

val newElements: LinkedHashMap<String, LookupElement> = LinkedHashMap()

pyClass.getAncestorClasses(typeEvalContext)
.filter { isPydanticModel(it, true) }
.forEach { addFieldElement(it, newElements, typeEvalContext, ellipsis, config, excludes, isDataclass) }
.filter { isPydanticModel(it, true, typeEvalContext) }
.forEach {
addFieldElement(it,
newElements,
typeEvalContext,
ellipsis,
config,
excludes,
isDataclass,
genericTypeMap)
}

addFieldElement(pyClass, newElements, typeEvalContext, ellipsis, config, excludes, isDataclass)
addFieldElement(pyClass,
newElements,
typeEvalContext,
ellipsis,
config,
excludes,
isDataclass,
genericTypeMap)

result.runRemainingContributors(parameters)
{ completionResult ->
Expand All @@ -160,12 +183,12 @@ class PydanticCompletionContributor : CompletionContributor() {
excludes: HashSet<String>, config: HashMap<String, Any?>,
) {

if (!isPydanticModel(pyClass, true)) return
if (!isPydanticModel(pyClass, true, typeEvalContext)) return

val fieldElements: HashSet<String> = HashSet()

pyClass.getAncestorClasses(typeEvalContext)
.filter { isPydanticModel(it, true) }
.filter { isPydanticModel(it, true, typeEvalContext) }
.forEach {
fieldElements.addAll(it.classAttributes
.filterNot { attribute ->
Expand Down Expand Up @@ -226,26 +249,29 @@ class PydanticCompletionContributor : CompletionContributor() {
val pyArgumentList = parameters.position.parent?.parent as? PyArgumentList ?: return

val typeEvalContext = parameters.getTypeEvalContext()
val pyClassType =
(pyArgumentList.parent as? PyCallExpression)?.let { typeEvalContext.getType(it) } as? PyClassType
val pyCallExpression = pyArgumentList.parent as? PyCallExpression
val pyClass =
pyCallExpression?.let { (typeEvalContext.getType(it) as? PyClassType)?.pyClass }
?: return

if (!isPydanticModel(pyClassType.pyClass, true, typeEvalContext)) return
if (!isPydanticModel(pyClass, true, typeEvalContext)) return

val definedSet = pyArgumentList.children
.mapNotNull { (it as? PyKeywordArgument)?.name }
.map { "${it}=" }
.toHashSet()
val config = getConfig(pyClassType.pyClass, typeEvalContext, true)
val ellipsis = PyElementGenerator.getInstance(pyClassType.pyClass.project).createEllipsis()
addAllFieldElement(parameters,

addAllFieldElement(
parameters,
result,
pyClassType.pyClass,
pyClass,
typeEvalContext,
ellipsis,
config,
PyElementGenerator.getInstance(pyClass.project).createEllipsis(),
getConfig(pyClass, typeEvalContext, true),
typeProvider.getGenericTypeMap(pyClass, typeEvalContext, pyCallExpression),
definedSet,
isPydanticDataclass(pyClassType.pyClass))
isPydanticDataclass(pyClass),
)
}
}

Expand All @@ -267,25 +293,30 @@ class PydanticCompletionContributor : CompletionContributor() {
result: CompletionResultSet,
) {
val typeEvalContext = parameters.getTypeEvalContext()
val pyType =
(parameters.position.parent?.firstChild as? PyTypedElement)?.let { typeEvalContext.getType(it) }
?: return
val pyTypedElement = parameters.position.parent?.firstChild as? PyTypedElement ?: return

val pyClassType = getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, true) }
?: return
val config = getConfig(pyClassType.pyClass, typeEvalContext, true)
val pyType = typeEvalContext.getType(pyTypedElement) ?: return

val pyClassType =
getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass, true, typeEvalContext) }
?: return
val pyClass = pyClassType.pyClass
val config = getConfig(pyClass, typeEvalContext, true)
if (pyClassType.isDefinition) { // class
removeAllFieldElement(parameters, result, pyClassType.pyClass, typeEvalContext, excludeFields, config)
removeAllFieldElement(parameters, result, pyClass, typeEvalContext, excludeFields, config)
return
}
val ellipsis = PyElementGenerator.getInstance(pyClassType.pyClass.project).createEllipsis()
addAllFieldElement(parameters,
val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis()
addAllFieldElement(
parameters,
result,
pyClassType.pyClass,
pyClass,
typeEvalContext,
ellipsis,
config,
isDataclass = isPydanticDataclass(pyClassType.pyClass))
typeProvider.getGenericTypeMap(pyClass, typeEvalContext, pyTypedElement as? PyCallExpression),
isDataclass = isPydanticDataclass(pyClass),
)
}
}

Expand Down Expand Up @@ -328,8 +359,9 @@ class PydanticCompletionContributor : CompletionContributor() {
val configClass = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
if (!isConfigClass(configClass)) return
val pydanticModel = getPyClassByAttribute(configClass) ?: return
if (!isPydanticModel(pydanticModel, true)) return
val typeEvalContext = parameters.getTypeEvalContext()
if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return


val definedSet = configClass.classAttributes
.mapNotNull { it.name }
Expand All @@ -354,7 +386,8 @@ class PydanticCompletionContributor : CompletionContributor() {
result: CompletionResultSet,
) {
val pydanticModel = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
if (!isPydanticModel(pydanticModel, true)) return
val typeEvalContext = parameters.getTypeEvalContext()
if (!isPydanticModel(pydanticModel, true, typeEvalContext)) return
if (pydanticModel.findNestedClass("Config", false) != null) return
val element = PrioritizedLookupElement.withGrouping(
LookupElementBuilder
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() {

return when {
callSite is PyCallExpression && definition -> dataclassCallableType
definition -> (dataclassType.declarationElement as? PyTypedElement)?.let { context.getType(it) }
definition -> dataclassType.toClass()
else -> dataclassType
}
}
Expand Down
31 changes: 19 additions & 12 deletions src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
when (element) {
is PyTargetExpression -> {
val pyClass = element.containingClass ?: return false
if (isPydanticModel(pyClass, true)) return true
val content = TypeEvalContext.codeAnalysis(element.project, element.containingFile)
if (isPydanticModel(pyClass, true, content)) return true
}
is PyKeywordArgument -> {
val pyClass = getPyClassByPyKeywordArgument(element,
TypeEvalContext.codeAnalysis(element.project, element.containingFile)) ?: return false
if (isPydanticModel(pyClass, true)) return true
val context = TypeEvalContext.codeAnalysis(element.project, element.containingFile)
val pyClass = getPyClassByPyKeywordArgument(element, context) ?: return false
if (isPydanticModel(pyClass, true, context)) return true
}
}
return false
Expand Down Expand Up @@ -55,33 +56,39 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
element.name?.let { name ->
element.containingClass
?.let { pyClass ->
addAllElement(pyClass, name, added)
val content = TypeEvalContext.codeAnalysis(element.project, element.containingFile)
addAllElement(pyClass, name, added, content)
}
suggestAllNames(name, newName)
}
is PyKeywordArgument ->
element.name?.let { name ->
getPyClassByPyKeywordArgument(element,
TypeEvalContext.userInitiated(element.project, element.containingFile))
val context = TypeEvalContext.userInitiated(element.project, element.containingFile)
getPyClassByPyKeywordArgument(element, context)
?.let { pyClass ->
addAllElement(pyClass, name, added)
addAllElement(pyClass, name, added, context)
}
suggestAllNames(name, newName)
}
}
}

private fun addAllElement(pyClass: PyClass, elementName: String, added: MutableSet<PyClass>) {
private fun addAllElement(
pyClass: PyClass,
elementName: String,
added: MutableSet<PyClass>,
context: TypeEvalContext,
) {
added.add(pyClass)
addClassAttributes(pyClass, elementName)
addKeywordArguments(pyClass, elementName)
pyClass.getAncestorClasses(null)
.filter { isPydanticModel(it, true) && !added.contains(it) }
.forEach { addAllElement(it, elementName, added) }
.filter { isPydanticModel(it, true, context) && !added.contains(it) }
.forEach { addAllElement(it, elementName, added, context) }

PyClassInheritorsSearch.search(pyClass, true)
.filterNot { added.contains(it) }
.forEach { addAllElement(it, elementName, added) }
.forEach { addAllElement(it, elementName, added, context) }
}

private fun addClassAttributes(pyClass: PyClass, elementName: String) {
Expand Down
Loading