Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pydantic.dataclasses.dataclass #43

Merged
merged 2 commits into from
Aug 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
* Refactor support for renaming fields for subclasses of `BaseModel`
* (If the field name is refactored from the model definition or `__init__` call keyword arguments, PyCharm will present a dialog offering the choice to automatically rename the keyword where it occurs in a model initialization call.
* Search related-fields by class attributes and keyword arguments of `__init__` with `Ctrl+B` and `Cmd+B`

#### pydantic.dataclasses.dataclass
Support same features as `pydantic.BaseModel`

## How to install:
### MarketPlace
Expand Down
3 changes: 2 additions & 1 deletion resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<h2>version 0.0.14</h2>
<p>Features</p>
<ul>
<li>Support pydantic.dataclasses.dataclass [#43] </li>
<li>Search related-fields by class attributes and keyword arguments of __init__. with Ctrl+B and Cmd+B [#42] </li>
</ul>
<h2>version 0.0.13</h2>
Expand Down Expand Up @@ -39,7 +40,7 @@
</li>
<li>pydantic.dataclasses.dataclass
<ul>
<li>The plugin has not supported dataclass yet.</li>
<li>Support same features as `pydantic.BaseModel`</li>
</ul>
</li>
</ul>
Expand Down
43 changes: 43 additions & 0 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.koxudaxi.pydantic

import com.intellij.psi.util.QualifiedName
import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyKeywordArgument
import com.jetbrains.python.psi.PyReferenceExpression
import com.jetbrains.python.psi.resolve.PyResolveUtil
import com.jetbrains.python.psi.types.TypeEvalContext


fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument): PyClass? {
val pyCallExpression = pyKeywordArgument.parent?.parent as? PyCallExpression ?: return null
return pyCallExpression.callee?.reference?.resolve() as? PyClass ?: return null
}

fun isPydanticModel(pyClass: PyClass, context: TypeEvalContext? = null): Boolean {
return isSubClassOfPydanticBaseModel(pyClass, context) || isPydanticDataclass(pyClass)
}

fun isPydanticBaseModel(pyClass: PyClass): Boolean {
return pyClass.qualifiedName == "pydantic.main.BaseModel"
}

fun isSubClassOfPydanticBaseModel(pyClass: PyClass, context: TypeEvalContext?): Boolean {
return pyClass.isSubclass("pydantic.main.BaseModel", context)
}

fun isPydanticDataclass(pyClass: PyClass): Boolean {
val decorators = pyClass.decoratorList?.decorators ?: return false
for (decorator in decorators) {
val callee = (decorator.callee as? PyReferenceExpression) ?: continue

for (decoratorQualifiedName in PyResolveUtil.resolveImportedElementQNameLocally(callee)) {
if (decoratorQualifiedName == QualifiedName.fromDottedString("pydantic.dataclasses.dataclass")) return true
}
}
return false
}

fun isPydanticField(pyClass: PyClass, context: TypeEvalContext? = null): Boolean {
return pyClass.isSubclass("pydantic.schema.Schema", context) || pyClass.isSubclass("pydantic.field.Field", context)
}
11 changes: 0 additions & 11 deletions src/com/koxudaxi/pydantic/PydanticBaseModel.kt

This file was deleted.

11 changes: 5 additions & 6 deletions src/com/koxudaxi/pydantic/PydanticFieldRenameFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
when (element) {
is PyTargetExpression -> {
val pyClass = element.containingClass ?: return false
if (pyClass.isSubclass("pydantic.main.BaseModel", null)) return true
if (isPydanticModel(pyClass)) return true
}
is PyKeywordArgument -> {
val pyClass = getPyClassByPyKeywordArgument(element) ?: return false
if (pyClass.isSubclass("pydantic.main.BaseModel", null)) return true
if (isPydanticModel(pyClass)) return true
}
}
return false
Expand Down Expand Up @@ -68,15 +68,15 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
addClassAttributes(pyClass, elementName)
addKeywordArguments(pyClass, elementName)
pyClass.getAncestorClasses(null).forEach { ancestorClass ->
if (ancestorClass.qualifiedName != "pydantic.main.BaseModel") {
if (ancestorClass.isSubclass("pydantic.main.BaseModel", null) &&
if (!isPydanticBaseModel(ancestorClass)) {
if (isPydanticModel(ancestorClass) &&
!added.contains(ancestorClass)) {
addAllElement(ancestorClass, elementName, added)
}
}
}
PyClassInheritorsSearch.search(pyClass, true).forEach { inheritorsPyClass ->
if (inheritorsPyClass.qualifiedName != "pydantic.main.BaseModel" && !added.contains(inheritorsPyClass)) {
if (!isPydanticBaseModel(inheritorsPyClass) && !added.contains(inheritorsPyClass)) {
addAllElement(inheritorsPyClass, elementName, added)
}
}
Expand All @@ -93,7 +93,6 @@ class PydanticFieldRenameFactory : AutomaticRenamerFactory {
callee?.arguments?.forEach { argument ->
if (argument is PyKeywordArgument && argument.name == elementName) {
myElements.add(argument)

}
}
}
Expand Down
16 changes: 8 additions & 8 deletions src/com/koxudaxi/pydantic/PydanticFieldSearchExecutor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.search.PyClassInheritorsSearch

private fun searchField(pyClass: PyClass, elementName: String, consumer: Processor<in PsiReference>): Boolean {
if (!pyClass.isSubclass("pydantic.main.BaseModel", null)) return false
if (!isPydanticModel(pyClass)) return false
val pyTargetExpression = pyClass.findClassAttribute(elementName, false, null) ?: return false
consumer.process(pyTargetExpression.reference)
return true
}

private fun searchKeywordArgument(pyClass: PyClass, elementName: String, consumer: Processor<in PsiReference>) {
if (!pyClass.isSubclass("pydantic.main.BaseModel", null)) return
if (!isPydanticModel(pyClass)) return
ReferencesSearch.search(pyClass as PsiElement).forEach { psiReference ->
val callee = PsiTreeUtil.getParentOfType(psiReference.element, PyCallExpression::class.java)
callee?.arguments?.forEach { argument ->
Expand All @@ -37,8 +37,8 @@ private fun searchDirectReferenceField(pyClass: PyClass, elementName: String, co
if (searchField(pyClass, elementName, consumer)) return true

pyClass.getAncestorClasses(null).forEach { ancestorClass ->
if (ancestorClass.qualifiedName != "pydantic.main.BaseModel") {
if (ancestorClass.isSubclass("pydantic.main.BaseModel", null)) {
if (!isPydanticBaseModel(ancestorClass)) {
if (isPydanticModel(ancestorClass)) {
if (searchDirectReferenceField(ancestorClass, elementName, consumer)) {
return true
}
Expand All @@ -54,12 +54,12 @@ private fun searchAllElementReference(pyClass: PyClass?, elementName: String, ad
searchField(pyClass, elementName, consumer)
searchKeywordArgument(pyClass, elementName, consumer)
pyClass.getAncestorClasses(null).forEach { ancestorClass ->
if (ancestorClass.qualifiedName != "pydantic.main.BaseModel" && !added.contains(ancestorClass)){
if (isPydanticBaseModel(ancestorClass) && !added.contains(ancestorClass)){
searchField(pyClass, elementName, consumer)
}
}
PyClassInheritorsSearch.search(pyClass, true).forEach { inheritorsPyClass ->
if (inheritorsPyClass.qualifiedName != "pydantic.main.BaseModel" && !added.contains(inheritorsPyClass)) {
if (!isPydanticBaseModel(inheritorsPyClass) && !added.contains(inheritorsPyClass)) {
searchAllElementReference(inheritorsPyClass, elementName, added, consumer)
}
}
Expand All @@ -72,13 +72,13 @@ class PydanticFieldSearchExecutor : QueryExecutorBase<PsiReference, ReferencesSe
is PyKeywordArgument -> run<RuntimeException> {
val elementName = element.name ?: return@run
val pyClass = getPyClassByPyKeywordArgument(element) ?: return@run
if (!pyClass.isSubclass("pydantic.main.BaseModel", null)) return@run
if (!isPydanticModel(pyClass)) return@run
searchDirectReferenceField(pyClass, elementName, consumer)
}
is PyTargetExpression -> run<RuntimeException> {
val elementName = element.name ?: return@run
val pyClass = element.containingClass ?: return@run
if (!pyClass.isSubclass("pydantic.main.BaseModel", null)) return@run
if (!isPydanticModel(pyClass)) return@run
searchAllElementReference(pyClass, elementName, mutableSetOf(), consumer)
}
}
Expand Down
69 changes: 0 additions & 69 deletions src/com/koxudaxi/pydantic/PydanticFieldStub.kt

This file was deleted.

19 changes: 0 additions & 19 deletions src/com/koxudaxi/pydantic/PydanticFieldStubType.kt

This file was deleted.

5 changes: 1 addition & 4 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@ import com.jetbrains.python.inspections.PyInspectionVisitor
import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyKeywordArgument
import com.jetbrains.python.psi.impl.PyClassImpl
import com.jetbrains.python.psi.impl.PyReferenceExpressionImpl
import com.jetbrains.python.psi.impl.PyStarArgumentImpl
import com.jetbrains.python.psi.impl.references.PyReferenceImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.PyClassTypeImpl

class PydanticInspection : PyInspection() {

Expand All @@ -29,7 +26,7 @@ class PydanticInspection : PyInspection() {

if (node != null) {
val pyClass: PyClass = (node.callee?.reference as? PyReferenceImpl)?.resolve() as? PyClass ?: return
if (!pyClass.isSubclass("pydantic.main.BaseModel", myTypeEvalContext)) return
if (!isPydanticModel(pyClass, myTypeEvalContext)) return
if ((node.callee as PyReferenceExpressionImpl).isQualified) return
for (argument in node.arguments) {
if (argument is PyKeywordArgument) {
Expand Down
11 changes: 0 additions & 11 deletions src/com/koxudaxi/pydantic/PydanticStub.kt

This file was deleted.

47 changes: 20 additions & 27 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}

val current = currentType.pyClass
if (!current.isSubclass("pydantic.main.BaseModel", context)) return null
if (!isPydanticModel(current, context)) return null

current
.classAttributes
Expand All @@ -107,14 +107,12 @@ class PydanticTypeProvider : PyTypeProviderBase() {
ellipsis: PyNoneLiteralExpression,
context: TypeEvalContext,
pyClass: PyClass): PyCallableParameter? {
val stub = field.stub
val fieldStub = if (stub == null) PydanticFieldStubImpl.create(field) else stub.getCustomStub(PydanticFieldStub::class.java)
if (fieldStub != null && !fieldStub.initValue()) return null
if (fieldStub == null && field.annotationValue == null && !field.hasAssignedValue()) return null // skip fields that are invalid syntax

if (field.annotationValue == null && !field.hasAssignedValue()) return null // skip fields that are invalid syntax

val defaultValue = when {
pyClass.isSubclass("pydantic.env_settings.BaseSettings", context) -> ellipsis
else -> getDefaultValueForParameter(field, fieldStub, ellipsis, context)
else -> getDefaultValueForParameter(field, ellipsis, context)
}

return PyCallableParameterImpl.nonPsi(field.name,
Expand All @@ -130,36 +128,31 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}

private fun getDefaultValueForParameter(field: PyTargetExpression,
fieldStub: PydanticFieldStub?,
ellipsis: PyNoneLiteralExpression,
context: TypeEvalContext): PyExpression? {
if (fieldStub == null) {
val value = field.findAssignedValue()
when {
value == null -> {
val annotation = (field.annotation?.value as? PySubscriptionExpressionImpl) ?: return null

when {
annotation.qualifier?.text == "Optional" -> return ellipsis
annotation.qualifier?.text == "Union" -> for (child in annotation.children) {
if (child is PyTupleExpression) {
for (type in child.children) {
if (type is PyNoneLiteralExpression) {
return ellipsis
}
val value = field.findAssignedValue()
when {
value == null -> {
val annotation = (field.annotation?.value as? PySubscriptionExpressionImpl) ?: return null

when {
annotation.qualifier?.text == "Optional" -> return ellipsis
annotation.qualifier?.text == "Union" -> for (child in annotation.children) {
if (child is PyTupleExpression) {
for (type in child.children) {
if (type is PyNoneLiteralExpression) {
return ellipsis
}
}
}
}
return value
}
field.hasAssignedValue() -> return getDefaultValueByAssignedValue(field, ellipsis, context)
else -> return null
return value
}
} else if (fieldStub.hasDefault() || fieldStub.hasDefaultFactory()) {
return ellipsis
field.hasAssignedValue() -> return getDefaultValueByAssignedValue(field, ellipsis, context)
else -> return null
}
return null
}

private fun getResolveElements(referenceExpression: PyReferenceExpression, context: TypeEvalContext): Array<ResolveResult> {
Expand All @@ -184,7 +177,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
.asSequence()
.forEach { it ->
val pyClass = PsiTreeUtil.getContextOfType(it, PyClass::class.java)
if (pyClass != null && pyClass.isSubclass("pydantic.schema.Schema", context)) {
if (pyClass != null && isPydanticField(pyClass, context)) {
val defaultValue = assignedValue.getKeywordArgument("default")
?: assignedValue.getArgument(0, PyExpression::class.java)
when {
Expand Down