Skip to content

Commit

Permalink
Fix default value by variable for Field is not recognized
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi committed Jul 1, 2021
1 parent fecc583 commit e984f31
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 10 deletions.
11 changes: 9 additions & 2 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,15 @@ internal fun getTypeExpressionFromAnnotated(annotated: PyExpression): PyExpressi
?.getOrNull(0)
?.let { it as? PyExpression }

internal fun getDefaultFromField(field: PyCallExpression): PyExpression? = field.getKeywordArgument("default")
?: field.getArgument(0, PyExpression::class.java).takeIf { it?.name == null }
internal fun getDefaultFromField(field: PyCallExpression, context: TypeEvalContext): PyExpression? =
field.getKeywordArgument("default")
?: field.getArgument(0, PyExpression::class.java)?.let {
when {
it is PyReferenceExpression -> getResolvedPsiElements(it, context).firstOrNull() as? PyExpression
it.name == null -> it
else -> null
}
}

internal fun getDefaultFactoryFromField(field: PyCallExpression): PyExpression? =
field.getKeywordArgument("default_factory")
Expand Down
6 changes: 3 additions & 3 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class PydanticInspection : PyInspection() {
if (qualifiedName != ANNOTATED_Q_NAME) return

val annotatedField = getFieldFromAnnotated(annotationValue, myTypeEvalContext) ?: return
val default = getDefaultFromField(annotatedField)
val default = getDefaultFromField(annotatedField, myTypeEvalContext)
if (default != null) {
registerProblem(
default.parent,
Expand All @@ -204,7 +204,7 @@ class PydanticInspection : PyInspection() {
val assignedValueField =
assignedValue?.let { getFieldFromPyExpression(assignedValue, myTypeEvalContext, null) }
if (assignedValueField != null) {
val default: PyExpression? = getDefaultFromField(assignedValueField)
val default: PyExpression? = getDefaultFromField(assignedValueField, myTypeEvalContext)
val defaultFactory: PyExpression? = getDefaultFactoryFromField(assignedValueField)
if (!validateDefaultAndDefaultFactory(default, defaultFactory)) return
}
Expand All @@ -221,7 +221,7 @@ class PydanticInspection : PyInspection() {
return
}
val annotatedField = getFieldFromAnnotated(annotationValue, myTypeEvalContext) ?: return
val default = getDefaultFromField(annotatedField)
val default = getDefaultFromField(annotatedField, myTypeEvalContext)
val defaultFactory = getDefaultFactoryFromField(annotatedField)
if (!validateDefaultAndDefaultFactory(assignedValue, defaultFactory)) return
if (default != null) {
Expand Down
6 changes: 3 additions & 3 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -709,19 +709,19 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}
.let {
return when {
it -> getDefaultValue(assignedValue)
it -> getDefaultValue(assignedValue, context)
else -> assignedValue
}
}
}
}

private fun getDefaultValue(assignedValue: PyCallExpression): PyExpression? {
private fun getDefaultValue(assignedValue: PyCallExpression, typeEvalContext: TypeEvalContext): PyExpression? {
getDefaultFactoryFromField(assignedValue)
?.let {
return assignedValue
}
return getDefaultFromField(assignedValue)?.takeIf { it.text != "..." }
return getDefaultFromField(assignedValue, typeEvalContext)?.takeIf { it.text != "..." }
}

private fun getDefaultValueForDataclass(
Expand Down
9 changes: 8 additions & 1 deletion testData/typeinspection/fieldField.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

from pydantic import BaseModel, Field

NUMBER = 123

class A(BaseModel):
a: int = Field(int(123))
b = Field(123)
c = Field(default=int(123))
d: int = Field(...)
e: int = Field(NUMBER)
f: int = Field(default=NUMBER)
g = Field(NUMBER)
h: int = Field(NUMBER)
i: int = Field(default=NUMBER)
j = Field(NUMBER)

A(a=int(123), b=int(123), c=int(123), d=int(123))
A(a=int(123), b=int(123), c=int(123), d=int(123), e=int(123), f=int(123), g=int(123))
6 changes: 5 additions & 1 deletion testData/typeinspection/fieldFieldInvalid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

from pydantic import BaseModel, Field

NUMBER = 123

class A(BaseModel):
a: int = Field(int(123))
b = Field(int(123))
c = Field(default=int(123))
e: int = Field(NUMBER)
f: int = Field(default=NUMBER)
g = Field(NUMBER)

A(<warning descr="Expected type 'int', got 'str' instead">a=str('123')</warning>, <warning descr="Expected type 'int', got 'str' instead">b=str('123')</warning>, <warning descr="Expected type 'int', got 'str' instead">c=str('123')</warning>)
A(<warning descr="Expected type 'int', got 'str' instead">a=str('123')</warning>, <warning descr="Expected type 'int', got 'str' instead">b=str('123')</warning>, <warning descr="Expected type 'int', got 'str' instead">c=str('123')</warning>, <warning descr="Expected type 'int', got 'str' instead">e=str('123')</warning>, <warning descr="Expected type 'int', got 'str' instead">f=str('123')</warning>, <warning descr="Expected type 'int', got 'str' instead">g=str('123')</warning>)

0 comments on commit e984f31

Please sign in to comment.