Skip to content

Commit

Permalink
support features by parameters (#67)
Browse files Browse the repository at this point in the history
* support features by parameters
  • Loading branch information
koxudaxi authored Sep 19, 2019
1 parent 9232eaf commit 05706d1
Show file tree
Hide file tree
Showing 15 changed files with 237 additions and 94 deletions.
8 changes: 7 additions & 1 deletion resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
<idea-plugin url="https://github.com/koxudaxi/pydantic-pycharm-plugin">
<id>com.koxudaxi.pydantic</id>
<name>Pydantic</name>
<version>0.0.19</version>
<version>0.0.20</version>
<vendor email="koaxudai@gmail.com">Koudai Aono @koxudaxi</vendor>
<change-notes><![CDATA[
<h2>version 0.0.20</h2>
<p>Features, BugFixes</p>
<ul>
<li>Support all features by parameters [#67] </li>
<li>Fix to handle models which have __init__ or __new__ methods [#67] </li>
</ul>
<h2>version 0.0.19</h2>
<p>BugFixes</p>
<ul>
Expand Down
34 changes: 27 additions & 7 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import com.jetbrains.python.psi.impl.PyCallExpressionImpl
import com.jetbrains.python.psi.impl.PyTargetExpressionImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.PyResolveUtil
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.psi.types.*

const val BASE_MODEL_Q_NAME = "pydantic.main.BaseModel"
const val DATA_CLASS_Q_NAME = "pydantic.dataclasses.dataclass"
Expand All @@ -18,13 +18,13 @@ const val SCHEMA_Q_NAME = "pydantic.schema.Schema"
const val FIELD_Q_NAME = "pydantic.field.Field"
const val BASE_SETTINGS_Q_NAME = "pydantic.env_settings.BaseSettings"

internal fun getPyClassByPyCallExpression(pyCallExpression: PyCallExpression): PyClass? {
return pyCallExpression.callee?.reference?.resolve() as? PyClass
}

internal fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument): PyClass? {
internal fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? {
val pyCallExpression = PsiTreeUtil.getParentOfType(pyKeywordArgument, PyCallExpression::class.java) ?: return null
return getPyClassByPyCallExpression(pyCallExpression)
val callee = pyCallExpression.callee ?: return null
val pyClass = context.getType(callee)
if (pyClass is PyClass) return pyClass
val pyType = (callee.reference?.resolve() as? PyTypedElement)?.let { context.getType(it) } ?: return null
return getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass) }?.pyClass
}

internal fun isPydanticModel(pyClass: PyClass, context: TypeEvalContext? = null): Boolean {
Expand Down Expand Up @@ -100,3 +100,23 @@ internal fun getResolveElements(referenceExpression: PyReferenceExpression, cont
return referenceExpression.getReference(resolveContext).multiResolve(false)

}

internal fun getPyClassTypeByPyTypes(pyType: PyType): List<PyClassType> {
return when (pyType) {
is PyUnionType ->
pyType.members
.mapNotNull { it }
.flatMap {
getPyClassTypeByPyTypes(it)
}

is PyCollectionType ->
pyType.elementTypes
.mapNotNull { it }
.flatMap {
getPyClassTypeByPyTypes(it)
}
is PyClassType -> listOf(pyType)
else -> listOf()
}
}
74 changes: 13 additions & 61 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@ import com.intellij.codeInsight.lookup.LookupElement
import com.intellij.codeInsight.lookup.LookupElementBuilder
import com.intellij.icons.AllIcons
import com.intellij.patterns.PlatformPatterns.psiElement
import com.intellij.psi.util.PsiTreeUtil.getParentOfType
import com.intellij.util.ProcessingContext
import com.jetbrains.python.PyTokenTypes
import com.jetbrains.python.codeInsight.completion.getTypeEvalContext
import com.jetbrains.python.documentation.PythonDocumentationProvider.getTypeHint
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.PyClassType
import com.jetbrains.python.psi.types.PyUnionType
import com.jetbrains.python.psi.types.TypeEvalContext
import javax.swing.Icon

Expand Down Expand Up @@ -58,45 +55,6 @@ class PydanticCompletionContributor : CompletionContributor() {
return "${typeHint}$defaultValue ${pyClass.name}"
}

private fun getPyClassFromPyNamedParameter(pyNamedParameter: PyNamedParameter, typeEvalContext: TypeEvalContext): PyClass? {
return when (val pyClassTypes = pyNamedParameter.getArgumentType(typeEvalContext)) {
is PyClassType -> pyClassTypes.pyClass
is PyUnionType -> pyClassTypes.members.filterIsInstance<PyClassType>()
.map { pyClassType -> pyClassType.pyClass }
.firstOrNull()
else -> null
}
}

protected fun getPyClassByPyReferenceExpression(pyReferenceExpression: PyReferenceExpression, typeEvalContext: TypeEvalContext, parameters: CompletionParameters?, result: CompletionResultSet?): PyClass? {
val resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(typeEvalContext)
return pyReferenceExpression.multiFollowAssignmentsChain(resolveContext).mapNotNull {
when (val resolveElement = it.element) {
is PyClass -> {
if (parameters != null && result != null) {
removeAllFieldElement(parameters, result, resolveElement, typeEvalContext, excludeFields)
null
} else {
resolveElement
}
}
is PyCallExpression -> getPyClassByPyCallExpression(resolveElement)
is PyNamedParameter -> {
if ((parameters != null && result != null) && resolveElement.isSelf) {
getParentOfType(resolveElement, PyFunction::class.java)
?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD }
?.takeIf { it.containingClass is PyClass }
?.let {
removeAllFieldElement(parameters, result, it.containingClass!!, typeEvalContext, excludeFields)
return null
}
}
getPyClassFromPyNamedParameter(resolveElement, typeEvalContext)
}
else -> null
}
}.firstOrNull()
}

private fun addFieldElement(pyClass: PyClass, results: LinkedHashMap<String, LookupElement>,
typeEvalContext: TypeEvalContext,
Expand Down Expand Up @@ -177,22 +135,17 @@ class PydanticCompletionContributor : CompletionContributor() {
override fun addCompletions(parameters: CompletionParameters, context: ProcessingContext, result: CompletionResultSet) {
val pyArgumentList = parameters.position.parent!!.parent!! as PyArgumentList
val typeEvalContext = parameters.getTypeEvalContext()
val pyClassType = (typeEvalContext.getType(pyArgumentList.parent as PyCallExpression) as? PyClassType)
?: return

val pyClass = when (val pyCallableElement = pyArgumentList.parent!!) {
is PyReferenceExpression -> getPyClassByPyReferenceExpression(pyCallableElement, typeEvalContext, null, null)
?: return
is PyCallExpression -> getPyClassByPyCallExpression(pyCallableElement) ?: return
else -> return
}

if (!isPydanticModel(pyClass, typeEvalContext)) return
if (!isPydanticModel(pyClassType.pyClass, typeEvalContext)) return

val definedSet = pyArgumentList.children
.mapNotNull { (it as? PyKeywordArgument)?.name }
.map { "${it}=" }
.toHashSet()
val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis()
addAllFieldElement(parameters, result, pyClass, typeEvalContext, ellipsis, definedSet)
val ellipsis = PyElementGenerator.getInstance(pyClassType.pyClass.project).createEllipsis()
addAllFieldElement(parameters, result, pyClassType.pyClass, typeEvalContext, ellipsis, definedSet)
}
}

Expand All @@ -205,16 +158,15 @@ class PydanticCompletionContributor : CompletionContributor() {

override fun addCompletions(parameters: CompletionParameters, context: ProcessingContext, result: CompletionResultSet) {
val typeEvalContext = parameters.getTypeEvalContext()
val pyClass = when (val instance = parameters.position.parent.firstChild) {
is PyReferenceExpression -> getPyClassByPyReferenceExpression(instance, typeEvalContext, parameters, result)
?: return
is PyCallExpression -> getPyClassByPyCallExpression(instance) ?: return
else -> return
}
val pyType = typeEvalContext.getType(parameters.position.parent.firstChild as PyTypedElement) ?: return

if (!isPydanticModel(pyClass, typeEvalContext)) return
val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis()
addAllFieldElement(parameters, result, pyClass, typeEvalContext, ellipsis)
val pyClassType = getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass) } ?: return
if (pyClassType.isDefinition) { // class
removeAllFieldElement(parameters, result, pyClassType.pyClass, typeEvalContext, excludeFields)
return
}
val ellipsis = PyElementGenerator.getInstance(pyClassType.pyClass.project).createEllipsis()
addAllFieldElement(parameters, result, pyClassType.pyClass, typeEvalContext, ellipsis)
}
}
}
5 changes: 3 additions & 2 deletions src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyKeywordArgument
import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.search.PyClassInheritorsSearch
import com.jetbrains.python.psi.types.TypeEvalContext


class PydanticFieldRenameFactory : AutomaticRenamerFactory {
Expand All @@ -22,7 +23,7 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
if (isPydanticModel(pyClass)) return true
}
is PyKeywordArgument -> {
val pyClass = getPyClassByPyKeywordArgument(element) ?: return false
val pyClass = getPyClassByPyKeywordArgument(element, TypeEvalContext.codeAnalysis(element.project, element.containingFile)) ?: return false
if (isPydanticModel(pyClass)) return true
}
}
Expand Down Expand Up @@ -59,7 +60,7 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
}
is PyKeywordArgument ->
element.name?.let { name ->
getPyClassByPyKeywordArgument(element)
getPyClassByPyKeywordArgument(element, TypeEvalContext.userInitiated(element.project, element.containingFile))
?.let { pyClass ->
addAllElement(pyClass, name, added)
}
Expand Down
39 changes: 28 additions & 11 deletions src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ import com.intellij.psi.PsiReference
import com.intellij.psi.search.searches.ReferencesSearch
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.util.Processor
import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyKeywordArgument
import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.search.PyClassInheritorsSearch
import com.jetbrains.python.psi.types.TypeEvalContext


class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSearch.SearchParameters>() {
Expand All @@ -20,7 +18,7 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
is PyKeywordArgument -> run<RuntimeException> {
element.name
?.let { elementName ->
getPyClassByPyKeywordArgument(element)
getPyClassByPyKeywordArgument(element, TypeEvalContext.userInitiated(element.project, element.containingFile))
?.takeIf { pyClass -> isPydanticModel(pyClass) }
?.let { pyClass -> searchDirectReferenceField(pyClass, elementName, consumer) }
}
Expand All @@ -43,18 +41,37 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
return true
}

private fun searchKeywordArgumentByPsiReference(psiReference: PsiReference, elementName: String, consumer: Processor<in PsiReference>) {
PsiTreeUtil.getParentOfType(psiReference.element, PyCallExpression::class.java)
?.let { callee ->
callee.arguments.firstOrNull { it.name == elementName }?.let { consumer.process(it.reference) }
}
}

private fun searchKeywordArgument(pyClass: PyClass, elementName: String, consumer: Processor<in PsiReference>) {
ReferencesSearch.search(pyClass as PsiElement).forEach { psiReference ->
PsiTreeUtil.getParentOfType(psiReference.element, PyCallExpression::class.java)
?.let { callee ->
callee.arguments
.filterIsInstance<PyKeywordArgument>()
.filter { it.name == elementName }
.forEach { consumer.process(it.reference) }
searchKeywordArgumentByPsiReference(psiReference, elementName, consumer)

PsiTreeUtil.getParentOfType(psiReference.element, PyNamedParameter::class.java)
?.let { param ->
param.getArgumentType(TypeEvalContext.userInitiated(
psiReference.element.project,
psiReference.element.containingFile))
?.let { pyType ->
getPyClassTypeByPyTypes(pyType)
.firstOrNull { pyClassType -> isPydanticModel(pyClassType.pyClass) }
?.let {
ReferencesSearch.search(param as PsiElement).forEach {
searchKeywordArgumentByPsiReference(it, elementName, consumer)
}
}
}
}
}

}


private fun searchDirectReferenceField(pyClass: PyClass, elementName: String, consumer: Processor<in PsiReference>): Boolean {
if (searchField(pyClass, elementName, consumer)) return true

Expand Down
11 changes: 7 additions & 4 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.jetbrains.python.inspections.quickfix.RenameParameterQuickFix
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyReferenceExpressionImpl
import com.jetbrains.python.psi.impl.PyStarArgumentImpl
import com.jetbrains.python.psi.types.PyClassType

class PydanticInspection : PyInspection() {

Expand Down Expand Up @@ -44,14 +45,16 @@ class PydanticInspection : PyInspection() {
override fun visitPyCallExpression(node: PyCallExpression?) {
super.visitPyCallExpression(node)

val pyClass: PyClass = node?.let { getPyClassByPyCallExpression(node) } ?: return
if (!isPydanticModel(pyClass, myTypeEvalContext)) return
if ((node.callee as PyReferenceExpressionImpl).isQualified) return
if (node == null) return

val pyClassType = myTypeEvalContext.getType(node) as? PyClassType ?: return
if (!isPydanticModel(pyClassType.pyClass, myTypeEvalContext)) return
if ((node.callee as? PyReferenceExpressionImpl)?.isQualified == true) return
node.arguments
.filterNot { it is PyKeywordArgument || (it as? PyStarArgumentImpl)?.isKeyword == true }
.forEach {
registerProblem(it,
"class '${pyClass.name}' accepts only keyword arguments")
"class '${pyClassType.pyClass.name}' accepts only keyword arguments")
}
}
}
Expand Down
23 changes: 15 additions & 8 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ class PydanticTypeProvider : PyTypeProviderBase() {
?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD }
?.let { it.containingClass?.let { getPydanticTypeForClass(it, context) } }
}
it is PyNamedParameter -> it.getArgumentType(context)?.let { pyType ->
getPyClassTypeByPyTypes(pyType).filter { pyClassType ->
pyClassType.isDefinition
}.map { filteredPyClassType -> getPydanticTypeForClass(filteredPyClassType.pyClass, context) }.firstOrNull()
}
it is PyTargetExpression -> (it as? PyTypedElement)?.let { pyTypedElement ->
context.getType(pyTypedElement)
?.let { pyType -> getPyClassTypeByPyTypes(pyType) }
?.filter { pyClassType -> pyClassType.isDefinition }
?.map { filteredPyClassType ->
getPydanticTypeForClass(filteredPyClassType.pyClass, context)
}?.firstOrNull()
}
else -> null
}
}
Expand All @@ -86,20 +99,14 @@ class PydanticTypeProvider : PyTypeProviderBase() {
if (!isPydanticModel(pyClass, context)) return null
val clsType = (context.getType(pyClass) as? PyClassLikeType) ?: return null
val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis()
val resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context)

val collected = linkedMapOf<String, PyCallableParameter>()

for (currentType in StreamEx.of(clsType).append(pyClass.getAncestorTypes(context))) {
if (currentType == null ||
!currentType.resolveMember(PyNames.INIT, null, AccessDirection.READ, resolveContext, false).isNullOrEmpty() ||
!currentType.resolveMember(PyNames.NEW, null, AccessDirection.READ, resolveContext, false).isNullOrEmpty() ||
currentType !is PyClassType) {
continue
}
if ( currentType !is PyClassType) continue

val current = currentType.pyClass
if (!isPydanticModel(current, context)) return null
if (!isPydanticModel(current, context)) continue

getClassVariables(current, context)
.mapNotNull { fieldToParameter(it, ellipsis, context, current) }
Expand Down
14 changes: 14 additions & 0 deletions testData/completion/assignedClass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from builtins import *
from pydantic import BaseModel


class A(BaseModel):
abc: str
cde = str('abc')
efg: str = str('abc')

class B(A):
hij: str

a = A
a(<caret>)
13 changes: 13 additions & 0 deletions testData/completion/classFields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from builtins import *
from pydantic import BaseModel


class A(BaseModel):
abc: str
cde = str('abc')
efg: str = str('abc')

class B(A):
hij: str

B.<caret>
13 changes: 13 additions & 0 deletions testData/completion/parameterAnnotationType.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from builtins import *
from pydantic import BaseModel
from typing import Type


class A(BaseModel):
abc: str
cde: str = str('abc')
efg: str = str('abc')


def get_a(a: Type[A]):
a.<caret>
Loading

0 comments on commit 05706d1

Please sign in to comment.