diff --git a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt index 50722e48..72d15897 100644 --- a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt +++ b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt @@ -59,6 +59,7 @@ class PydanticCompletionContributor : CompletionContributor() { context: TypeEvalContext, pydanticVersion: KotlinVersion?, config: HashMap, + withEqual: Boolean ): String val typeProvider: PydanticTypeProvider = PydanticTypeProvider() @@ -110,6 +111,7 @@ class PydanticCompletionContributor : CompletionContributor() { excludes: HashSet?, isDataclass: Boolean, genericTypeMap: Map?, + withEqual: Boolean ) { val pydanticVersion = PydanticCacheService.getVersion(pyClass.project, typeEvalContext) getClassVariables(pyClass, typeEvalContext) @@ -118,7 +120,7 @@ class PydanticCompletionContributor : CompletionContributor() { .filter { isValidField(it, typeEvalContext) } .filter { !isDataclass || isInInit(it) } .forEach { - val elementName = getLookupNameFromFieldName(it, typeEvalContext, pydanticVersion, config) + val elementName = getLookupNameFromFieldName(it, typeEvalContext, pydanticVersion, config, withEqual) if (excludes == null || !excludes.contains(elementName)) { val typeText = getTypeText(pyClass, typeEvalContext, @@ -148,6 +150,7 @@ class PydanticCompletionContributor : CompletionContributor() { genericTypeMap: Map?, excludes: HashSet? = null, isDataclass: Boolean, + trimEqual: Boolean ) { val newElements: LinkedHashMap = LinkedHashMap() @@ -161,7 +164,8 @@ class PydanticCompletionContributor : CompletionContributor() { config, excludes, isDataclass, - genericTypeMap) + genericTypeMap, + !trimEqual) } addFieldElement(pyClass, @@ -171,11 +175,14 @@ class PydanticCompletionContributor : CompletionContributor() { config, excludes, isDataclass, - genericTypeMap) + genericTypeMap, + !trimEqual) result.runRemainingContributors(parameters) { completionResult -> - completionResult.lookupElement.lookupString + completionResult.lookupElement.lookupString.let { + if (trimEqual) it.trimEnd('=') else it + } .takeIf { name -> !newElements.containsKey(name) && (excludes == null || !excludes.contains(name)) } ?.let { result.passResult(completionResult) } } @@ -238,8 +245,10 @@ class PydanticCompletionContributor : CompletionContributor() { context: TypeEvalContext, pydanticVersion: KotlinVersion?, config: HashMap, + withEqual: Boolean ): String { - return "${getFieldName(field, context, config, pydanticVersion)}=" + val suffix = if(withEqual) "=" else "" + return "${getFieldName(field, context, config, pydanticVersion)}$suffix" } override val icon: Icon = AllIcons.Nodes.Parameter @@ -265,7 +274,9 @@ class PydanticCompletionContributor : CompletionContributor() { .mapNotNull { (it as? PyKeywordArgument)?.name } .map { "${it}=" } .toHashSet() - + val keyword = parameters.originalPosition?.text + val parameter = parameters.originalPosition?.parent?.text + val hasEqual = parameter?.startsWith("$keyword=") ?: false addAllFieldElement( parameters, result, @@ -276,6 +287,7 @@ class PydanticCompletionContributor : CompletionContributor() { typeProvider.getGenericTypeMap(pyClass, typeEvalContext, pyCallExpression), definedSet, pyClass.isPydanticDataclass, + hasEqual ) } } @@ -286,6 +298,7 @@ class PydanticCompletionContributor : CompletionContributor() { context: TypeEvalContext, pydanticVersion: KotlinVersion?, config: HashMap, + withEqual: Boolean ): String { return field.name!! } @@ -320,7 +333,8 @@ class PydanticCompletionContributor : CompletionContributor() { ellipsis, config, typeProvider.getGenericTypeMap(pyClass, typeEvalContext, pyTypedElement as? PyCallExpression), - isDataclass = pyClass.isPydanticDataclass, + isDataclass = pyClass.isPydanticDataclass , + trimEqual=false, ) } } diff --git a/testData/completionv18/insertedArgument.py b/testData/completionv18/insertedArgument.py new file mode 100644 index 00000000..2bc7388d --- /dev/null +++ b/testData/completionv18/insertedArgument.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel, Field + + +class A(BaseModel): + abc_efg: str = '123' + abc_xyz: str = '456' + +A(abc_efg='789') diff --git a/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt b/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt index ba297774..9dd39ef4 100644 --- a/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt +++ b/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt @@ -94,4 +94,12 @@ open class PydanticCompletionV18Test : PydanticTestCase(version = "v18") { ) ) } + fun testInsertedArgument() { + doFieldTest( + listOf( + Pair("abc_efg", "str='123' A"), + Pair("abc_xyz", "str='456' A") + ) + ) + } }