Skip to content

Commit

Permalink
inspect untyped fields (#93)
Browse files Browse the repository at this point in the history
* inspect untyped fields
  • Loading branch information
koxudaxi committed Dec 12, 2019
1 parent 1e161dd commit 352aa2d
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 24 deletions.
7 changes: 6 additions & 1 deletion resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
<idea-plugin url="https://github.com/koxudaxi/pydantic-pycharm-plugin">
<id>com.koxudaxi.pydantic</id>
<name>Pydantic</name>
<version>0.0.28</version>
<version>0.0.29</version>
<vendor email="koaxudai@gmail.com">Koudai Aono @koxudaxi</vendor>
<change-notes><![CDATA[
<h2>version 0.0.29</h2>
<p>Features, BugFixes</p>
<ul>
<li>Inspect untyped fields [#93] </li>
</ul>
<h2>version 0.0.28</h2>
<p>Features, BugFixes</p>
<ul>
Expand Down
5 changes: 5 additions & 0 deletions resources/inspectionDescriptions/PydanticInspection.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<html>
<body>
This inspection checks Pydantic models.
</body>
</html>
32 changes: 18 additions & 14 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ val CONFIG_TYPES = mapOf(
"allow_mutation" to Boolean
)

internal fun getPyClassByPyCallExpression(pyCallExpression: PyCallExpression, context: TypeEvalContext): PyClass? {
fun getPyClassByPyCallExpression(pyCallExpression: PyCallExpression, context: TypeEvalContext): PyClass? {
val callee = pyCallExpression.callee ?: return null
val pyType = when (val type = context.getType(callee)) {
is PyClass -> return type
Expand All @@ -65,16 +65,16 @@ internal fun getPyClassByPyCallExpression(pyCallExpression: PyCallExpression, co
return getPyClassTypeByPyTypes(pyType).firstOrNull { isPydanticModel(it.pyClass) }?.pyClass
}

internal fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? {
fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context: TypeEvalContext): PyClass? {
val pyCallExpression = PsiTreeUtil.getParentOfType(pyKeywordArgument, PyCallExpression::class.java) ?: return null
return getPyClassByPyCallExpression(pyCallExpression, context)
}

internal fun isPydanticModel(pyClass: PyClass, context: TypeEvalContext? = null): Boolean {
fun isPydanticModel(pyClass: PyClass, context: TypeEvalContext? = null): Boolean {
return (isSubClassOfPydanticBaseModel(pyClass, context) || isPydanticDataclass(pyClass)) && !isPydanticBaseModel(pyClass)
}

internal fun isPydanticBaseModel(pyClass: PyClass): Boolean {
fun isPydanticBaseModel(pyClass: PyClass): Boolean {
return pyClass.qualifiedName == BASE_MODEL_Q_NAME
}

Expand Down Expand Up @@ -154,13 +154,13 @@ private fun getAliasedFieldName(field: PyTargetExpression, context: TypeEvalCont
}


internal fun getResolveElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): Array<ResolveResult> {
fun getResolveElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): Array<ResolveResult> {
val resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context)
return referenceExpression.getReference(resolveContext).multiResolve(false)

}

internal fun getPyClassTypeByPyTypes(pyType: PyType): List<PyClassType> {
fun getPyClassTypeByPyTypes(pyType: PyType): List<PyClassType> {
return when (pyType) {
is PyUnionType ->
pyType.members
Expand All @@ -174,13 +174,13 @@ internal fun getPyClassTypeByPyTypes(pyType: PyType): List<PyClassType> {
}


internal fun isPydanticSchemaByPsiElement(psiElement: PsiElement, context: TypeEvalContext): Boolean {
fun isPydanticSchemaByPsiElement(psiElement: PsiElement, context: TypeEvalContext): Boolean {
PsiTreeUtil.getContextOfType(psiElement, PyClass::class.java)
?.let { return isPydanticSchema(it, context) }
return false
}

internal fun isPydanticFieldByPsiElement(psiElement: PsiElement): Boolean {
fun isPydanticFieldByPsiElement(psiElement: PsiElement): Boolean {
when (psiElement) {
is PyFunction -> return isPydanticField(psiElement)
else -> PsiTreeUtil.getContextOfType(psiElement, PyFunction::class.java)
Expand All @@ -189,7 +189,7 @@ internal fun isPydanticFieldByPsiElement(psiElement: PsiElement): Boolean {
return false
}

internal fun getPydanticVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
fun getPydanticVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
val module = project.modules.firstOrNull() ?: return null
val pythonSdk = module.pythonSdk
val contextAnchor = ModuleBasedContextAnchor(module)
Expand All @@ -210,11 +210,11 @@ internal fun getPydanticVersion(project: Project, context: TypeEvalContext): Kot
})
}

internal fun isValidFieldName(name: String): Boolean {
fun isValidFieldName(name: String): Boolean {
return name.first() != '_'
}

internal fun getConfigValue(name: String, value: Any?, context: TypeEvalContext): Any? {
fun getConfigValue(name: String, value: Any?, context: TypeEvalContext): Any? {
if (value is PyReferenceExpression) {
val resolveResults = getResolveElements(value, context)
val targetExpression = PyUtil.filterTopPriorityResults(resolveResults).firstOrNull() ?: return null
Expand All @@ -232,7 +232,7 @@ internal fun getConfigValue(name: String, value: Any?, context: TypeEvalContext)
return null
}

internal fun getConfig(pyClass: PyClass, context: TypeEvalContext, setDefault: Boolean): HashMap<String, Any?> {
fun getConfig(pyClass: PyClass, context: TypeEvalContext, setDefault: Boolean): HashMap<String, Any?> {
val config = hashMapOf<String, Any?>()
pyClass.getAncestorClasses(context)
.reversed()
Expand Down Expand Up @@ -265,7 +265,7 @@ internal fun getConfig(pyClass: PyClass, context: TypeEvalContext, setDefault: B
return config
}

internal fun getFieldName(field: PyTargetExpression,
fun getFieldName(field: PyTargetExpression,
context: TypeEvalContext,
config: HashMap<String, Any?>,
pydanticVersion: KotlinVersion?): String? {
Expand All @@ -286,9 +286,13 @@ internal fun getFieldName(field: PyTargetExpression,
}


internal fun getPydanticBaseConfig(project: Project, context: TypeEvalContext): PyClass? {
fun getPydanticBaseConfig(project: Project, context: TypeEvalContext): PyClass? {
val module = project.modules.firstOrNull() ?: return null
val pythonSdk = module.pythonSdk
val contextAnchor = ModuleBasedContextAnchor(module)
return BASE_CONFIG_QUALIFIED_NAME.resolveToElement(QNameResolveContext(contextAnchor, pythonSdk, context)) as? PyClass
}

fun getPyClassByAttribute(pyPsiElement: PsiElement?): PyClass? {
return pyPsiElement?.parent?.parent as? PyClass
}
6 changes: 3 additions & 3 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ class PydanticCompletionContributor : CompletionContributor() {


override fun addCompletions(parameters: CompletionParameters, context: ProcessingContext, result: CompletionResultSet) {
val configClass = parameters.position.parent?.parent?.parent?.parent as? PyClass ?: return
val configClass = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
if (!isConfigClass(configClass)) return
val pydanticModel = configClass.parent?.parent as? PyClass ?:return
val pydanticModel = getPyClassByAttribute(configClass) ?:return
if (!isPydanticModel(pydanticModel)) return
val typeEvalContext = parameters.getTypeEvalContext()

Expand All @@ -264,7 +264,7 @@ class PydanticCompletionContributor : CompletionContributor() {
override val icon: Icon = AllIcons.Nodes.Class

override fun addCompletions(parameters: CompletionParameters, context: ProcessingContext, result: CompletionResultSet) {
val pydanticModel = parameters.position.parent?.parent?.parent?.parent as? PyClass ?: return
val pydanticModel = getPyClassByAttribute(parameters.position.parent?.parent) ?: return
if (!isPydanticModel(pydanticModel)) return
if (pydanticModel.findNestedClass("Config", false) != null) return
val element = PrioritizedLookupElement.withGrouping(
Expand Down
32 changes: 27 additions & 5 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.koxudaxi.pydantic
import com.intellij.codeInspection.LocalInspectionToolSession
import com.intellij.codeInspection.ProblemHighlightType
import com.intellij.codeInspection.ProblemsHolder
import com.intellij.codeInspection.ui.MultipleCheckboxOptionsPanel
import com.intellij.psi.PsiElementVisitor
import com.jetbrains.python.PyBundle
import com.jetbrains.python.PyNames
Expand All @@ -16,19 +17,24 @@ import com.jetbrains.python.psi.impl.PyTargetExpressionImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.PyClassType
import com.jetbrains.python.psi.types.PyClassTypeImpl
import javax.swing.JComponent

var defaultWarnUntypedFields = false

class PydanticInspection : PyInspection() {
var warnUntypedFields = defaultWarnUntypedFields

override fun buildVisitor(holder: ProblemsHolder,
isOnTheFly: Boolean,
session: LocalInspectionToolSession): PsiElementVisitor = Visitor(holder, session)

private class Visitor(holder: ProblemsHolder, session: LocalInspectionToolSession) : PyInspectionVisitor(holder, session) {
inner class Visitor(holder: ProblemsHolder, session: LocalInspectionToolSession) : PyInspectionVisitor(holder, session) {

override fun visitPyFunction(node: PyFunction?) {
super.visitPyFunction(node)

val pyClass = node?.parent?.parent as? PyClass ?: return
if (node == null) return
val pyClass = getPyClassByAttribute(node) ?: return
if (!isPydanticModel(pyClass, myTypeEvalContext) || !isValidatorMethod(node)) return
val paramList = node.parameterList
val params = paramList.parameters
Expand All @@ -49,7 +55,6 @@ class PydanticInspection : PyInspection() {
super.visitPyCallExpression(node)

if (node == null) return

inspectPydanticModelCallableExpression(node)
inspectFromOrm(node)

Expand All @@ -59,9 +64,10 @@ class PydanticInspection : PyInspection() {
super.visitPyAssignmentStatement(node)

if (node == null) return

if (this@PydanticInspection.warnUntypedFields) {
inspectWarnUntypedFields(node)
}
inspectReadOnlyProperty(node)

}

private fun inspectPydanticModelCallableExpression(pyCallExpression: PyCallExpression) {
Expand Down Expand Up @@ -110,5 +116,21 @@ class PydanticInspection : PyInspection() {
ProblemHighlightType.GENERIC_ERROR)

}

private fun inspectWarnUntypedFields(node: PyAssignmentStatement){
val pyClass = getPyClassByAttribute(node) ?: return
if (!isPydanticModel(pyClass, myTypeEvalContext)) return
if (node.annotation != null) return

registerProblem(node,
"Untyped fields disallowed", ProblemHighlightType.WARNING)

}
}

override fun createOptionsPanel(): JComponent? {
val panel = MultipleCheckboxOptionsPanel(this)
panel.addCheckbox( "Warning untyped fields", "warnUntypedFields")
return panel
}
}
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {

override fun getReferenceType(referenceTarget: PsiElement, context: TypeEvalContext, anchor: PsiElement?): Ref<PyType>? {
if (referenceTarget is PyTargetExpression) {
val pyClass = referenceTarget.parent?.parent?.parent as? PyClass ?: return null
val pyClass = getPyClassByAttribute(referenceTarget.parent) ?: return null
if (!isPydanticModel(pyClass, context)) return null
val name = referenceTarget.name ?: return null
getRefTypeFromFieldName(name, context, pyClass)?.let { return it }
Expand Down
21 changes: 21 additions & 0 deletions testData/inspection/warnUntypedFields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pydantic import BaseModel


class A(BaseModel):
<warning descr="Untyped fields disallowed">a = '123'</warning>


class B(BaseModel):
b: str = '123'


class C:
c = '123'

class D:
d

def e():
ee = '123'

f = '123'
22 changes: 22 additions & 0 deletions testData/inspection/warnUntypedFieldsDisable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pydantic import BaseModel


class A(BaseModel):
a = '123'


class B(BaseModel):
b: str = '123'


class C:
c = '123'


class D:
d

def e():
ee = '123'

f = '123'
13 changes: 13 additions & 0 deletions testSrc/com/koxudaxi/pydantic/PydanticInspectionTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,17 @@ open class PydanticInspectionTest : PydanticInspectionBase() {
fun testReadOnlyProperty() {
doTest()
}

fun testWarnUntypedFieldsDisable() {
doTest()
}

fun testWarnUntypedFields() {
try {
defaultWarnUntypedFields = true
doTest()
} finally {
defaultWarnUntypedFields = false
}
}
}

0 comments on commit 352aa2d

Please sign in to comment.