Skip to content

Commit

Permalink
Improve code style
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi committed May 7, 2021
1 parent cc5bec9 commit 3d6c2c9
Showing 1 changed file with 22 additions and 40 deletions.
62 changes: 22 additions & 40 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,11 @@ class PydanticTypeProvider : PyTypeProviderBase() {
context: TypeEvalContext,
anchor: PsiElement?,
): Ref<PyType>? {
if (referenceTarget is PyTargetExpression) {
val pyClass = getPyClassByAttribute(referenceTarget.parent) ?: return null
if (!isPydanticModel(pyClass, false, context)) return null
val name = referenceTarget.name ?: return null
getRefTypeFromFieldName(name, context, pyClass)?.let { return it }
}
return null
if (referenceTarget !is PyTargetExpression) return null
val pyClass = getPyClassByAttribute(referenceTarget.parent) ?: return null
if (!isPydanticModel(pyClass, false, context)) return null
val name = referenceTarget.name ?: return null
return getRefTypeFromFieldName(name, context, pyClass)
}

override fun getParameterType(param: PyNamedParameter, func: PyFunction, context: TypeEvalContext): Ref<PyType>? {
Expand Down Expand Up @@ -106,7 +104,6 @@ class PydanticTypeProvider : PyTypeProviderBase() {
getGenericTypeMap(pyClass, context)
)
?.let { parameter -> Ref.create(parameter.getType(context)) }

}


Expand Down Expand Up @@ -173,10 +170,8 @@ class PydanticTypeProvider : PyTypeProviderBase() {
if (!isGenericModel && (rootOperandType as? PyCustomType)?.classQName != GENERIC_Q_NAME) return@flatMap emptyList()

when (val indexExpression = pySubscriptionExpression.indexExpression) {
is PyTupleExpression -> indexExpression.elements
.map { context.getType(it) }.filterIsInstance<PyGenericType>().toList()
is PyGenericType -> listOf(context.getType(indexExpression))
is PyTypedElement -> (context.getType(indexExpression) as? PyGenericType)?.let { listOf(it) }
is PyTupleExpression -> indexExpression.elements.map { context.getType(it) }.toList()
is PyTypedElement -> listOf(context.getType(indexExpression))
else -> null
} ?: emptyList()
}.filterIsInstance<PyGenericType>().distinct()
Expand Down Expand Up @@ -209,19 +204,16 @@ class PydanticTypeProvider : PyTypeProviderBase() {
.map {
when {
it is PyClass -> getPydanticTypeForClass(it, context, true, pyCallExpression)
it is PyParameter && it.isSelf -> {
it is PyParameter && it.isSelf ->
PsiTreeUtil.getParentOfType(it, PyFunction::class.java)
?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD }
?.let {
it.containingClass?.let {
getPydanticTypeForClass(it,
context,
true,
pyCallExpression
)
}
?.containingClass?.let {
getPydanticTypeForClass(it,
context,
true,
pyCallExpression
)
}
}
it is PyNamedParameter -> it.getArgumentType(context)?.let { pyType ->
getPyClassTypeByPyTypes(pyType).filter { pyClassType ->
pyClassType.isDefinition
Expand Down Expand Up @@ -278,9 +270,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
pyTargetExpression: PyTargetExpression,
context: TypeEvalContext,
): PyClass? {
val pyCallableType = getPydanticDynamicModelTypeForTargetExpression(pyTargetExpression, context)
?: return null
return pyCallableType.pyClass
return getPydanticDynamicModelTypeForTargetExpression(pyTargetExpression, context)?.pyClass
}

private fun getPydanticDynamicModelTypeForTargetExpression(
Expand All @@ -303,8 +293,6 @@ class PydanticTypeProvider : PyTypeProviderBase() {
.filterIsInstance<PyFunction>()
.map { it.takeIf { pyFunction -> isPydanticCreateModel(pyFunction) } }.firstOrNull()
?: return null


return getPydanticDynamicModelTypeForFunction(pyFunction, arguments, context)
}

Expand Down Expand Up @@ -423,14 +411,13 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}

private fun getBaseSetting(pyClass: PyClass, context: TypeEvalContext): PyClass? {
pyClass.getSuperClasses(context).forEach {
return if (isBaseSetting(it)) {
return pyClass.getSuperClasses(context).mapNotNull {
if (isBaseSetting(it)) {
it
} else {
getBaseSetting(it, context)
}
}
return null
}.firstOrNull()
}

fun getGenericTypeMap(
Expand All @@ -443,11 +430,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {

// class Response(GenericModel, Generic[TypeA, TypeB]): pass
val pyClassGenericTypeMap = pyTypingTypeProvider.getGenericSubstitutions(pyClass, context)
.mapNotNull { (key, value) ->
if (key is PyGenericType && value is PyType) {
Pair(key, value)
} else null
}.toMap()
.mapNotNull { (key, value) -> key to value }.filterIsInstance<Pair<PyGenericType, PyType>>().toMap()

// Response[TypeA]
val pySubscriptionExpression = when (val firstChild = pyCallExpression?.firstChild) {
Expand Down Expand Up @@ -561,10 +544,9 @@ class PydanticTypeProvider : PyTypeProviderBase() {
// get type from annotation
else -> getTypeForParameter(field, context)
}?.let {
if (genericTypeMap == null) {
it
} else {
PyTypeChecker.substitute(it, genericTypeMap, context)
when (genericTypeMap) {
null -> it
else -> PyTypeChecker.substitute(it, genericTypeMap, context)
}
}

Expand Down

0 comments on commit 3d6c2c9

Please sign in to comment.