Skip to content

Commit

Permalink
Add unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi committed May 7, 2021
1 parent 0bbe7d4 commit 1b986a2
Show file tree
Hide file tree
Showing 10 changed files with 436 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const val ANNOTATED_Q_NAME = "typing.Annotated"
const val CLASSVAR_Q_NAME = "typing.ClassVar"
const val GENERIC_Q_NAME = "typing.Generic"
const val TYPE_Q_NAME = "typing.Type"

const val TUPLE_Q_NAME = "typing.Tuple"

val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME)

Expand Down
7 changes: 4 additions & 3 deletions src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.jetbrains.python.documentation.PythonDocumentationProvider.getTypeHin
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyEvaluator
import com.jetbrains.python.psi.types.PyClassType
import com.jetbrains.python.psi.types.PyGenericType
import com.jetbrains.python.psi.types.PyType
import com.jetbrains.python.psi.types.TypeEvalContext
import javax.swing.Icon
Expand Down Expand Up @@ -70,7 +71,7 @@ class PydanticCompletionContributor : CompletionContributor() {
pydanticVersion: KotlinVersion?,
config: HashMap<String, Any?>,
isDataclass: Boolean,
genericTypeMap: Map<PyType, PyType>,
genericTypeMap: Map<PyGenericType, PyType>?,
): String {

val parameter = typeProvider.fieldToParameter(pyTargetExpression,
Expand Down Expand Up @@ -105,7 +106,7 @@ class PydanticCompletionContributor : CompletionContributor() {
config: HashMap<String, Any?>,
excludes: HashSet<String>?,
isDataclass: Boolean,
genericTypeMap: Map<PyType, PyType>,
genericTypeMap: Map<PyGenericType, PyType>?,
) {
val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, typeEvalContext)
getClassVariables(pyClass, typeEvalContext)
Expand Down Expand Up @@ -138,7 +139,7 @@ class PydanticCompletionContributor : CompletionContributor() {
pyClass: PyClass, typeEvalContext: TypeEvalContext,
ellipsis: PyNoneLiteralExpression,
config: HashMap<String, Any?>,
genericTypeMap: Map<PyType, PyType>,
genericTypeMap: Map<PyGenericType, PyType>?,
excludes: HashSet<String>? = null,
isDataclass: Boolean,
) {
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() {

return when {
callSite is PyCallExpression && definition -> dataclassCallableType
definition -> (dataclassType.declarationElement as? PyTypedElement)?.let { context.getType(it) }
definition -> dataclassType.toClass()
else -> dataclassType
}
}
Expand Down
105 changes: 75 additions & 30 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -109,29 +109,51 @@ class PydanticTypeProvider : PyTypeProviderBase() {

}


private fun getPyType(pyExpression: PyExpression, context: TypeEvalContext): PyType? {
return when (val type = context.getType(pyExpression)) {
is PyClassLikeType -> type.toInstance()
else -> type
}
}

private fun getInjectedGenericType(
pyExpression: PyExpression,
context: TypeEvalContext,
): PyType? {
return when (pyExpression) {
is PySubscriptionExpression -> {
val typingType = (pyExpression.rootOperand as? PyReferenceExpression)
?.let { pyReferenceExpression ->
getResolvedPsiElements(pyReferenceExpression, context)
.filterIsInstance<PyQualifiedNameOwner>()
.any { it.qualifiedName == TYPE_Q_NAME }
if (pyExpression is PySubscriptionExpression) {
val rootOperand = (pyExpression.rootOperand as? PyReferenceExpression)
?.let { pyReferenceExpression ->
getResolvedPsiElements(pyReferenceExpression, context)
.asSequence()
.filterIsInstance<PyQualifiedNameOwner>()
.firstOrNull()
}
when (val qualifiedName = rootOperand?.qualifiedName) {
TYPE_Q_NAME -> return (pyExpression.indexExpression as? PyTypedElement)?.let { context.getType(it) }
in listOf(TUPLE_Q_NAME, UNION_Q_NAME, OPTIONAL_Q_NAME) -> {
val indexExpression = pyExpression.indexExpression
when (indexExpression) {
is PyTupleExpression -> indexExpression.elements
.map { element -> getInjectedGenericType(element, context) }
is PySubscriptionExpression -> listOf(getInjectedGenericType(indexExpression, context))
is PyTypedElement -> listOf(getPyType(indexExpression, context))
else -> null
}?.let {
return when (qualifiedName) {
UNION_Q_NAME -> PyUnionType.union(it)
OPTIONAL_Q_NAME -> PyUnionType.union(it + PyNoneType.INSTANCE)
else -> PyTupleType.create(indexExpression as PsiElement, it)
}
}
return if (typingType == true) {
(pyExpression.indexExpression as? PyTypedElement)?.let { context.getType(it) }
} else {
(context.getType(pyExpression) as? PyClassLikeType)?.toInstance()
}
}
else -> (context.getType(pyExpression) as? PyClassLikeType)?.toInstance()
}
return getPyType(pyExpression, context)
}

private fun collectGenericTypes(pyClass: PyClass, context: TypeEvalContext): List<PyType?> {

private fun collectGenericTypes(pyClass: PyClass, context: TypeEvalContext): List<PyGenericType?> {
return pyClass.superClassExpressions
.mapNotNull {
when (it) {
Expand All @@ -154,9 +176,10 @@ class PydanticTypeProvider : PyTypeProviderBase() {
is PyTupleExpression -> indexExpression.elements
.map { context.getType(it) }.filterIsInstance<PyGenericType>().toList()
is PyGenericType -> listOf(context.getType(indexExpression))
is PyTypedElement -> (context.getType(indexExpression) as? PyGenericType)?.let { listOf(it) }
else -> null
}
}.filterNotNull().distinct()
} ?: emptyList()
}.filterIsInstance<PyGenericType>().distinct()
}

override fun prepareCalleeTypeForCall(
Expand Down Expand Up @@ -193,7 +216,9 @@ class PydanticTypeProvider : PyTypeProviderBase() {
it.containingClass?.let {
getPydanticTypeForClass(it,
context,
pyCallExpression = pyCallExpression)
true,
pyCallExpression
)
}
}
}
Expand All @@ -216,10 +241,12 @@ class PydanticTypeProvider : PyTypeProviderBase() {
?.filter { pyClassType -> pyClassType.isDefinition }
?.filterNot { pyClassType -> pyClassType is PydanticDynamicModelClassType }
?.map { filteredPyClassType ->
getPydanticTypeForClass(filteredPyClassType.pyClass,
getPydanticTypeForClass(
filteredPyClassType.pyClass,
context,
true,
pyCallExpression)
pyCallExpression
)
}?.firstOrNull()
} ?: getPydanticDynamicModelTypeForTargetExpression(it, context)?.pyCallableType
else -> null
Expand Down Expand Up @@ -410,33 +437,43 @@ class PydanticTypeProvider : PyTypeProviderBase() {
pyClass: PyClass,
context: TypeEvalContext,
pyCallExpression: PyCallExpression? = null,
): Map<PyType, PyType> {
if (!PyTypingTypeProvider.isGeneric(pyClass, context)) return emptyMap()
if (!(isSubClassOfPydanticGenericModel(pyClass, context) && !isPydanticGenericModel(pyClass))) return emptyMap()
val pyClassGenericTypeMap =
pyTypingTypeProvider.getGenericSubstitutions(pyClass, context).filterValues { it is PyType }
): Map<PyGenericType, PyType>? {
if (!PyTypingTypeProvider.isGeneric(pyClass, context)) return null
if (!(isSubClassOfPydanticGenericModel(pyClass, context) && !isPydanticGenericModel(pyClass))) return null

// class Response(GenericModel, Generic[TypeA, TypeB]): pass
val pyClassGenericTypeMap = pyTypingTypeProvider.getGenericSubstitutions(pyClass, context)
.mapNotNull { (key, value) ->
if (key is PyGenericType && value is PyType) {
Pair(key, value)
} else null
}.toMap()

// Response[TypeA]
val pySubscriptionExpression = when (val firstChild = pyCallExpression?.firstChild) {
is PySubscriptionExpression -> firstChild
is PyReferenceExpression -> getResolvedPsiElements(firstChild, context)
.firstOrNull()
?.let { it as? PyTargetExpression }
?.findAssignedValue() as? PySubscriptionExpression
else -> null
} ?: return pyClassGenericTypeMap
} ?: return pyClassGenericTypeMap.takeIf { it.isNotEmpty() }

// Response[TypeA, TypeB]()
val injectedTypes = (pySubscriptionExpression.indexExpression as? PyTupleExpression)
?.elements
?.map { getInjectedGenericType(it, context) }
// Response[TypeA]()
?: listOf((pySubscriptionExpression.indexExpression?.let { getInjectedGenericType(it, context) }))


return pyClassGenericTypeMap.toMutableMap().apply {
this.putAll(collectGenericTypes(pyClass, context)
.take(injectedTypes.size)
.mapIndexed { index, genericType -> genericType to injectedTypes[index] }
.filterIsInstance<Pair<PyType, PyType>>().toMap()
.filterIsInstance<Pair<PyGenericType, PyType>>().toMap()
)
}
}.takeIf { it.isNotEmpty() }
}

fun getPydanticTypeForClass(
Expand Down Expand Up @@ -502,14 +539,15 @@ class PydanticTypeProvider : PyTypeProviderBase() {
pyClass: PyClass,
pydanticVersion: KotlinVersion?,
config: HashMap<String, Any?>,
genericTypeMap: Map<PyType, PyType>,
genericTypeMap: Map<PyGenericType, PyType>?,
typed: Boolean = true,
isDataclass: Boolean = false,
): PyCallableParameter? {
if (!isValidField(field, context)) return null
if (!hasAnnotationValue(field) && !field.hasAssignedValue()) return null // skip fields that are invalid syntax

val defaultValueFromField = getDefaultValueForParameter(field, ellipsis, context, pydanticVersion, isDataclass)
val defaultValueFromField =
getDefaultValueForParameter(field, ellipsis, context, pydanticVersion, isDataclass)
val defaultValue = when {
isSubClassOfBaseSetting(pyClass, context) -> ellipsis
else -> defaultValueFromField
Expand All @@ -523,7 +561,11 @@ class PydanticTypeProvider : PyTypeProviderBase() {
// get type from annotation
else -> getTypeForParameter(field, context)
}?.let {
genericTypeMap[it] ?: it
if (genericTypeMap == null) {
it
} else {
PyTypeChecker.substitute(it, genericTypeMap, context)
}
}

return PyCallableParameterImpl.nonPsi(
Expand Down Expand Up @@ -712,7 +754,10 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}
}

private fun getDefaultValueForDataclass(assignedValue: PyCallExpression, context: TypeEvalContext): PyExpression? {
private fun getDefaultValueForDataclass(
assignedValue: PyCallExpression,
context: TypeEvalContext,
): PyExpression? {
val defaultValue = getDefaultValueForDataclass(assignedValue, context, "default")
val defaultFactoryValue = getDefaultValueForDataclass(assignedValue, context, "default_factory")
return when {
Expand Down
21 changes: 21 additions & 0 deletions testData/completionv18/genericField.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import TypeVar, Type, List, Dict, Generic, Optional
from pydantic.generics import GenericModel


AT = TypeVar('AT')
BT = TypeVar('BT')
CT = TypeVar('CT')
DT = TypeVar('DT')
ET = TypeVar('ET')


class A(GenericModel, Generic[AT, BT, CT, DT]):
a: Type[AT]
b: List[BT]
c: Dict[CT, DT]

class B(A[int, BT, CT, DT], Generic[BT, CT, DT, ET]):
hij: Optional[ET]


B[str, float, bytes, bool]().<caret>
21 changes: 21 additions & 0 deletions testData/completionv18/genericKeywordArgument.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import TypeVar, Type, List, Dict, Generic, Optional
from pydantic.generics import GenericModel


AT = TypeVar('AT')
BT = TypeVar('BT')
CT = TypeVar('CT')
DT = TypeVar('DT')
ET = TypeVar('ET')


class A(GenericModel, Generic[AT, BT, CT, DT]):
a: Type[AT]
b: List[BT]
c: Dict[CT, DT]

class B(A[int, BT, CT, DT], Generic[BT, CT, DT, ET]):
hij: Optional[ET]


B[str, float, bytes, bool](<caret>)
38 changes: 38 additions & 0 deletions testData/mock/pydanticv18/generics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
get_type_hints,
)

from .main import BaseModel

_generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {}
GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type


class GenericModel(BaseModel):
__slots__ = ()
__concrete__: ClassVar[bool] = False

if TYPE_CHECKING:
# Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with
# `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of
# `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below.
__parameters__: ClassVar[Tuple[TypeVarType, ...]]

# Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings
def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]:
pass
Loading

0 comments on commit 1b986a2

Please sign in to comment.