Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ConfigDict for 232 #747

Merged
merged 1 commit into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ You can install the stable version on PyCharm's `Marketplace` (Preference -> Plu
* Insert unfilled arguments with a QuickFix for subclasses of `pydantic.BaseModel`
* Support typing.Annotated (PEP 593)
* Regex arguments in `Field` and `constr` are treated as Python's regex string literals
* Config/ConfigDict support
* Auto-completion for Config/ConfigDict
* Read Model config such `frozen=True` from Config/ConfigDict

#### pydantic.generics.GenericModel
* Support same features as `pydantic.BaseModel`
Expand All @@ -60,6 +63,11 @@ You can install the stable version on PyCharm's `Marketplace` (Preference -> Plu
* Support same features as `pydantic.BaseModel`
* (This plugin version 0.3.12 or later)

### Supported Pydantic major versions
- v0
- v1
- v2

## Contribute
We are waiting for your contributions to `pydantic-pycharm-plugin`.

Expand Down
8 changes: 8 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ I got interviewed about this plugin for [JetBrains' PyCharm Blog](https://blog.j
* Insert unfilled arguments with a QuickFix for subclasses of `pydantic.BaseModel`
* Support typing.Annotated (PEP 593)
* Regex arguments in `Field` and `constr` are treated as Python's regex string literals
* Config/ConfigDict support
* Auto-completion for Config/ConfigDict
* Read Model config such `frozen=True` from Config/ConfigDict

#### pydantic.generics.GenericModel
* Support same features as `pydantic.BaseModel`
Expand All @@ -42,6 +45,11 @@ I got interviewed about this plugin for [JetBrains' PyCharm Blog](https://blog.j
* Support same features as `pydantic.BaseModel`
* (This plugin version 0.3.12 or later)

### Supported Pydantic major versions
- v0
- v1
- v2

## Demo
![demo1](demo1.gif)

Expand Down
80 changes: 70 additions & 10 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ const val DEPRECATED_SCHEMA_Q_NAME = "pydantic.fields.Schema"
const val BASE_SETTINGS_Q_NAME = "pydantic.env_settings.BaseSettings"
const val VERSION_Q_NAME = "pydantic.version.VERSION"
const val BASE_CONFIG_Q_NAME = "pydantic.main.BaseConfig"
const val CONFIG_DICT_Q_NAME = "pydantic.config.ConfigDict"
const val CONFIG_DICT_SHORT_Q_NAME = "pydantic.ConfigDict"
const val CONFIG_DICT_DEFAULTS_Q_NAME = "pydantic._internal._config.config_defaults"
const val DATACLASS_MISSING = "dataclasses.MISSING"
const val CON_BYTES_Q_NAME = "pydantic.types.conbytes"
const val CON_DECIMAL_Q_NAME = "pydantic.types.condecimal"
Expand Down Expand Up @@ -80,6 +83,12 @@ val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME)

val BASE_CONFIG_QUALIFIED_NAME = QualifiedName.fromDottedString(BASE_CONFIG_Q_NAME)

val CONFIG_DICT_QUALIFIED_NAME = QualifiedName.fromDottedString(CONFIG_DICT_Q_NAME)

val CONFIG_DICT_DEFAULTS_QUALIFIED_NAME = QualifiedName.fromDottedString(CONFIG_DICT_DEFAULTS_Q_NAME)

val CONFIG_DICT_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(CONFIG_DICT_SHORT_Q_NAME)

val BASE_MODEL_QUALIFIED_NAME = QualifiedName.fromDottedString(BASE_MODEL_Q_NAME)

val VALIDATOR_QUALIFIED_NAME = QualifiedName.fromDottedString(VALIDATOR_Q_NAME)
Expand Down Expand Up @@ -155,11 +164,17 @@ val CONFIG_TYPES = mapOf(
"allow_mutation" to ConfigType.BOOLEAN,
"frozen" to ConfigType.BOOLEAN,
"keep_untouched" to ConfigType.LIST_PYTYPE,
"extra" to ConfigType.EXTRA
"extra" to ConfigType.EXTRA,
"populate_by_name" to ConfigType.BOOLEAN,
"from_attributes" to ConfigType.BOOLEAN,
)

const val CUSTOM_ROOT_FIELD = "__root__"

const val MODEL_FIELD_PREFIX = "model_"

const val MODEL_CONFIG_FIELD = "model_config"

fun PyTypedElement.getType(context: TypeEvalContext): PyType? = context.getType(this)


Expand Down Expand Up @@ -369,16 +384,15 @@ fun getPsiElementByQualifiedName(
return qualifiedName.resolveToElement(QNameResolveContext(contextAnchor, pythonSdk, context))
}

fun isValidField(field: PyTargetExpression, context: TypeEvalContext): Boolean {
if (field.name?.isValidFieldName != true) return false
fun isValidField(field: PyTargetExpression, context: TypeEvalContext, isV2: Boolean): Boolean {
if (field.name?.isValidFieldName(isV2) != true) return false

val annotationValue = field.annotation?.value ?: return true
// TODO Support a variable.
return getQualifiedName(annotationValue, context) != CLASSVAR_Q_NAME
}

val String.isValidFieldName: Boolean get() = !startsWith('_') || this == CUSTOM_ROOT_FIELD

fun String.isValidFieldName(isV2: Boolean): Boolean = (!startsWith('_') || this == CUSTOM_ROOT_FIELD) && !(isV2 && this.startsWith(MODEL_FIELD_PREFIX))

fun getConfigValue(name: String, value: Any?, context: TypeEvalContext): Any? {
if (value is PyReferenceExpression) {
Expand Down Expand Up @@ -425,6 +439,7 @@ fun validateConfig(pyClass: PyClass, context: TypeEvalContext): List<PsiElement>
return results
}


fun getConfig(
pyClass: PyClass,
context: TypeEvalContext,
Expand All @@ -443,6 +458,31 @@ fun getConfig(
}
}
}
if (version?.isV2 == true) {
val configDict = pyClass.findClassAttribute(MODEL_CONFIG_FIELD, false, context)?.findAssignedValue().let {
when (it) {
is PyReferenceExpression -> {
val targetExpression = getResolvedPsiElements(it, context).firstOrNull() ?: return@let null
(targetExpression as? PyTargetExpression)?.findAssignedValue() ?: return@let null
}
else -> it
}
}
when (configDict) {
is PyDictLiteralExpression -> configDict.elements.forEach { element ->
element.key.text.drop(1).dropLast(1).let { name ->
config[name] = getConfigValue(name, element.value, context)
}
}
is PyCallExpression -> configDict.arguments.forEach { argument ->
argument.name?.let {name ->
configDict.getKeywordArgument(name)?.let { value ->
config[name] = getConfigValue(name, value, context)
}
}
}
}
}
pyClass.nestedClasses.firstOrNull { it.isConfigClass }?.let {
it.classAttributes.forEach { attribute ->
attribute.findAssignedValue()?.let { value ->
Expand All @@ -462,10 +502,18 @@ fun getConfig(
}

if (setDefault) {
DEFAULT_CONFIG.forEach { (key, value) ->
if (!config.containsKey(key)) {
config[key] = getConfigValue(key, value, context)
if (version?.isV2 == true) {
PydanticCacheService.getConfigDictDefaults(pyClass.project, context)
?.filterNot { config.containsKey(it.key) }
?.forEach { (name, value) ->
config[name] = value
}
}
} else {
DEFAULT_CONFIG.forEach { (key, value) ->
if (!config.containsKey(key)) {
config[key] = getConfigValue(key, value, context)
}
}
}
return config
Expand Down Expand Up @@ -495,17 +543,29 @@ fun getPydanticBaseConfig(project: Project, context: TypeEvalContext): PyClass?
return getPyClassFromQualifiedName(BASE_CONFIG_QUALIFIED_NAME, project, context)
}

fun getPydanticConfigDictDefaults(project: Project, context: TypeEvalContext): PyCallExpression? {
val targetExpression = getPyTargetExpressionFromQualifiedName(CONFIG_DICT_DEFAULTS_QUALIFIED_NAME, project, context) ?: return null
return targetExpression.findAssignedValue() as? PyCallExpression
}

fun getPydanticBaseModel(project: Project, context: TypeEvalContext): PyClass? {
return getPyClassFromQualifiedName(BASE_MODEL_QUALIFIED_NAME, project, context)
}

fun getPyClassFromQualifiedName(qualifiedName: QualifiedName, project: Project, context: TypeEvalContext): PyClass? {
fun getPsiElementFromQualifiedName(qualifiedName: QualifiedName, project: Project, context: TypeEvalContext): PsiElement? {
val module = project.modules.firstOrNull() ?: return null
val pythonSdk = module.pythonSdk
val contextAnchor = ModuleBasedContextAnchor(module)
return qualifiedName.resolveToElement(QNameResolveContext(contextAnchor, pythonSdk, context)) as? PyClass
return qualifiedName.resolveToElement(QNameResolveContext(contextAnchor, pythonSdk, context))
}

fun getPyClassFromQualifiedName(qualifiedName: QualifiedName, project: Project, context: TypeEvalContext): PyClass? {
return getPsiElementFromQualifiedName(qualifiedName, project, context) as? PyClass
}

fun getPyTargetExpressionFromQualifiedName(qualifiedName: QualifiedName, project: Project, context: TypeEvalContext): PyTargetExpression? {
return getPsiElementFromQualifiedName(qualifiedName, project, context) as? PyTargetExpression
}
fun getPyClassByAttribute(pyPsiElement: PsiElement?): PyClass? {
return pyPsiElement?.parent?.parent as? PyClass
}
Expand Down
32 changes: 27 additions & 5 deletions src/com/koxudaxi/pydantic/PydanticCacheService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,25 @@ import com.jetbrains.python.sdk.pythonSdk
class PydanticCacheService(val project: Project) {
private var version: KotlinVersion? = null
private var allowedConfigKwargs: Set<String>? = null
private var configDictDefaults: Map<String, Any?>? = null

private fun getConfigDictDefaults(project: Project, context: TypeEvalContext): Map<String, Any?>? {
val configDictDefaults = getPydanticConfigDictDefaults(project, context) ?: return null
return configDictDefaults.arguments.filter { CONFIG_TYPES.containsKey(it.name) }
.mapNotNull {
val name = it.name ?: return@mapNotNull null
name to getConfigValue(name, configDictDefaults.getKeywordArgument(name), context)
}.toMap()
}

private fun getAllowedConfigKwargs(context: TypeEvalContext): Set<String>? {
val baseConfig = getPydanticBaseConfig(project, context) ?: return null
return baseConfig.classAttributes
.mapNotNull { it.name }
.filterNot { it.startsWith("__") && it.endsWith("__") }
.toSet()
val baseConfigAttributes = when {
isV2 -> getPydanticConfigDictDefaults(project, context)?.arguments?.mapNotNull { it.name }
else -> { getPydanticBaseConfig(project, context)?.classAttributes?.mapNotNull { it.name } }
} ?: return null
return baseConfigAttributes
.filterNot { it.startsWith("__") && it.endsWith("__") }
.toSet()
}
private fun getVersion(): KotlinVersion? {
val sdk = project.pythonSdk ?: return null
Expand Down Expand Up @@ -49,9 +61,15 @@ class PydanticCacheService(val project: Project) {
return getAllowedConfigKwargs(context).apply { allowedConfigKwargs = this }
}

private fun getOrConfigDictDefaults(project: Project, context: TypeEvalContext): Map<String, Any?>? {
if ( configDictDefaults != null) return configDictDefaults
return getConfigDictDefaults(project, context).apply { configDictDefaults = this }
}

private fun clear() {
version = null
allowedConfigKwargs = null
configDictDefaults = null
}

internal val isV2 get() = this.getOrPutVersion().isV2
Expand All @@ -72,6 +90,10 @@ class PydanticCacheService(val project: Project) {
return getInstance(project).getOrAllowedConfigKwargs(context)
}

fun getConfigDictDefaults(project: Project, context: TypeEvalContext): Map<String, Any?>? {
return getInstance(project).getOrConfigDictDefaults(project, context)
}

fun clear(project: Project) {
return getInstance(project).clear()
}
Expand Down
35 changes: 25 additions & 10 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class PydanticCompletionContributor : CompletionContributor() {
getClassVariables(pyClass, typeEvalContext)
.filter { it.name != null }
.filterNot { isUntouchedClass(it.findAssignedValue(), config, typeEvalContext) }
.filter { isValidField(it, typeEvalContext) }
.filter { isValidField(it, typeEvalContext, pydanticVersion.isV2) }
.filter { !isDataclass || isInInit(it) }
.forEach {
val elementName = getLookupNameFromFieldName(it, typeEvalContext, pydanticVersion, config, withEqual)
Expand Down Expand Up @@ -198,7 +198,7 @@ class PydanticCompletionContributor : CompletionContributor() {
if (!isPydanticModel(pyClass, true, typeEvalContext)) return

val fieldElements: HashSet<String> = HashSet()

val isV2 = PydanticCacheService.getVersion(pyClass.project).isV2
getAncestorPydanticModels(pyClass, true, typeEvalContext)
.forEach {
fieldElements.addAll(it.classAttributes
Expand All @@ -208,7 +208,7 @@ class PydanticCompletionContributor : CompletionContributor() {
typeEvalContext)
}
.filter { attribute ->
isValidField(attribute, typeEvalContext)
isValidField(attribute, typeEvalContext, isV2)
}
.mapNotNull { attribute -> attribute?.name })
}
Expand All @@ -217,7 +217,7 @@ class PydanticCompletionContributor : CompletionContributor() {

fieldElements.addAll(pyClass.classAttributes
.filterNot { isUntouchedClass(it.findAssignedValue(), config, typeEvalContext) }
.filter { isValidField(it, typeEvalContext) }
.filter { isValidField(it, typeEvalContext, isV2) }
.mapNotNull { attribute -> attribute?.name })

result.runRemainingContributors(parameters)
Expand Down Expand Up @@ -401,12 +401,27 @@ class PydanticCompletionContributor : CompletionContributor() {
context: ProcessingContext,
result: CompletionResultSet,
) {
val pydanticModel = getPydanticModelByAttribute(parameters.position.parent?.parent, true, parameters.getTypeEvalContext()) ?: return
if (pydanticModel.findNestedClass("Config", false) != null) return
val element = PrioritizedLookupElement.withGrouping(
LookupElementBuilder
.create("class Config:")
.withIcon(icon), 1)
val typeEvalContext = parameters.getTypeEvalContext()
val pydanticModel = getPydanticModelByAttribute(parameters.position.parent?.parent, true, typeEvalContext) ?: return
val element = when {
PydanticCacheService.getInstance(pydanticModel.project).isV2 -> {
if (pydanticModel.findClassAttribute(MODEL_CONFIG_FIELD, false, typeEvalContext) != null) return
PrioritizedLookupElement.withGrouping(
LookupElementBuilder
.create("$MODEL_CONFIG_FIELD = ConfigDict()").withInsertHandler { context, _ ->
context.editor.caretModel.moveCaretRelatively(-1, 0, false, false, false)
}
.withIcon(icon), 1)
}
else -> {
if (pydanticModel.findNestedClass("Config", false) != null) return
PrioritizedLookupElement.withGrouping(
LookupElementBuilder
.create("class Config:")
.withIcon(icon), 1
)
}
}
result.addElement(PrioritizedLookupElement.withPriority(element, 100.0))
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class PydanticInspection : PyInspection() {
.flatMap { pydanticModel ->
getClassVariables(pydanticModel, myTypeEvalContext)
.filter { it.name != null }
.filter { isValidField(it, myTypeEvalContext) }
.filter { isValidField(it, myTypeEvalContext, pydanticCacheService.isV2) }
.map { it.name }
}.toSet()
pyCallExpression.arguments
Expand Down Expand Up @@ -250,7 +250,7 @@ class PydanticInspection : PyInspection() {
private fun inspectWarnUntypedFields(node: PyAssignmentStatement) {
if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return
if (node.annotation != null) return
if ((node.leftHandSideExpression as? PyTargetExpressionImpl)?.text?.isValidFieldName != true) return
if ((node.leftHandSideExpression as? PyTargetExpressionImpl)?.text?.isValidFieldName(pydanticCacheService.isV2) != true) return
registerProblem(
node,
"Untyped fields disallowed", ProblemHighlightType.WARNING
Expand Down
4 changes: 2 additions & 2 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
} ?: getPydanticBaseModel(project, context) ?: return null

collected.putAll(keywordArguments
.filter { (name, _) -> name.isValidFieldName && !name.startsWith('_') }
.filter { (name, _) -> name.isValidFieldName(pydanticVersion.isV2) && !name.startsWith('_') }
.filter { (name, _) -> (newVersion || name != "model_name") }
.map { (name, field) ->
val parameter = dynamicModelFieldToParameter(field, context, typed)
Expand Down Expand Up @@ -544,7 +544,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
typed: Boolean = true,
isDataclass: Boolean = false,
): PyCallableParameter? {
if (!isValidField(field, context)) return null
if (!isValidField(field, context, pydanticVersion.isV2)) return null
if (!hasAnnotationValue(field) && !field.hasAssignedValue()) return null // skip fields that are invalid syntax

val defaultValueFromField =
Expand Down
5 changes: 5 additions & 0 deletions testData/completionv2/configDict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

from pydantic import BaseModel

class A(BaseModel):
m<caret>
Loading