From 1b986a2f6b0c2955b5fa954676d7bd4dc197b128 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Fri, 7 May 2021 22:12:19 +0900 Subject: [PATCH] Add unittest --- src/com/koxudaxi/pydantic/Pydantic.kt | 2 +- .../pydantic/PydanticCompletionContributor.kt | 7 +- .../pydantic/PydanticDataclassTypeProvider.kt | 2 +- .../koxudaxi/pydantic/PydanticTypeProvider.kt | 105 +++++++--- testData/completionv18/genericField.py | 21 ++ .../completionv18/genericKeywordArgument.py | 21 ++ testData/mock/pydanticv18/generics.py | 38 ++++ testData/typeinspectionv18/genericModel.py | 191 ++++++++++++++++++ .../pydantic/PydanticCompletionV18Test.kt | 80 ++++++++ .../pydantic/PydanticTypeInspectionV18Test.kt | 4 + 10 files changed, 436 insertions(+), 35 deletions(-) create mode 100644 testData/completionv18/genericField.py create mode 100644 testData/completionv18/genericKeywordArgument.py create mode 100644 testData/mock/pydanticv18/generics.py create mode 100644 testData/typeinspectionv18/genericModel.py create mode 100644 testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt diff --git a/src/com/koxudaxi/pydantic/Pydantic.kt b/src/com/koxudaxi/pydantic/Pydantic.kt index a713eed9..1f90c497 100644 --- a/src/com/koxudaxi/pydantic/Pydantic.kt +++ b/src/com/koxudaxi/pydantic/Pydantic.kt @@ -55,7 +55,7 @@ const val ANNOTATED_Q_NAME = "typing.Annotated" const val CLASSVAR_Q_NAME = "typing.ClassVar" const val GENERIC_Q_NAME = "typing.Generic" const val TYPE_Q_NAME = "typing.Type" - +const val TUPLE_Q_NAME = "typing.Tuple" val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME) diff --git a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt index 694a3ec8..96d1e59b 100644 --- a/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt +++ b/src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt @@ -12,6 +12,7 @@ import com.jetbrains.python.documentation.PythonDocumentationProvider.getTypeHin import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.PyEvaluator import com.jetbrains.python.psi.types.PyClassType +import com.jetbrains.python.psi.types.PyGenericType import com.jetbrains.python.psi.types.PyType import com.jetbrains.python.psi.types.TypeEvalContext import javax.swing.Icon @@ -70,7 +71,7 @@ class PydanticCompletionContributor : CompletionContributor() { pydanticVersion: KotlinVersion?, config: HashMap, isDataclass: Boolean, - genericTypeMap: Map, + genericTypeMap: Map?, ): String { val parameter = typeProvider.fieldToParameter(pyTargetExpression, @@ -105,7 +106,7 @@ class PydanticCompletionContributor : CompletionContributor() { config: HashMap, excludes: HashSet?, isDataclass: Boolean, - genericTypeMap: Map, + genericTypeMap: Map?, ) { val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, typeEvalContext) getClassVariables(pyClass, typeEvalContext) @@ -138,7 +139,7 @@ class PydanticCompletionContributor : CompletionContributor() { pyClass: PyClass, typeEvalContext: TypeEvalContext, ellipsis: PyNoneLiteralExpression, config: HashMap, - genericTypeMap: Map, + genericTypeMap: Map?, excludes: HashSet? = null, isDataclass: Boolean, ) { diff --git a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt index 10dd9fe4..7094a41d 100644 --- a/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticDataclassTypeProvider.kt @@ -53,7 +53,7 @@ class PydanticDataclassTypeProvider : PyTypeProviderBase() { return when { callSite is PyCallExpression && definition -> dataclassCallableType - definition -> (dataclassType.declarationElement as? PyTypedElement)?.let { context.getType(it) } + definition -> dataclassType.toClass() else -> dataclassType } } diff --git a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt index 24d18a05..9b9d07ca 100644 --- a/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt +++ b/src/com/koxudaxi/pydantic/PydanticTypeProvider.kt @@ -109,29 +109,51 @@ class PydanticTypeProvider : PyTypeProviderBase() { } + + private fun getPyType(pyExpression: PyExpression, context: TypeEvalContext): PyType? { + return when (val type = context.getType(pyExpression)) { + is PyClassLikeType -> type.toInstance() + else -> type + } + } + private fun getInjectedGenericType( pyExpression: PyExpression, context: TypeEvalContext, ): PyType? { - return when (pyExpression) { - is PySubscriptionExpression -> { - val typingType = (pyExpression.rootOperand as? PyReferenceExpression) - ?.let { pyReferenceExpression -> - getResolvedPsiElements(pyReferenceExpression, context) - .filterIsInstance() - .any { it.qualifiedName == TYPE_Q_NAME } + if (pyExpression is PySubscriptionExpression) { + val rootOperand = (pyExpression.rootOperand as? PyReferenceExpression) + ?.let { pyReferenceExpression -> + getResolvedPsiElements(pyReferenceExpression, context) + .asSequence() + .filterIsInstance() + .firstOrNull() + } + when (val qualifiedName = rootOperand?.qualifiedName) { + TYPE_Q_NAME -> return (pyExpression.indexExpression as? PyTypedElement)?.let { context.getType(it) } + in listOf(TUPLE_Q_NAME, UNION_Q_NAME, OPTIONAL_Q_NAME) -> { + val indexExpression = pyExpression.indexExpression + when (indexExpression) { + is PyTupleExpression -> indexExpression.elements + .map { element -> getInjectedGenericType(element, context) } + is PySubscriptionExpression -> listOf(getInjectedGenericType(indexExpression, context)) + is PyTypedElement -> listOf(getPyType(indexExpression, context)) + else -> null + }?.let { + return when (qualifiedName) { + UNION_Q_NAME -> PyUnionType.union(it) + OPTIONAL_Q_NAME -> PyUnionType.union(it + PyNoneType.INSTANCE) + else -> PyTupleType.create(indexExpression as PsiElement, it) + } } - return if (typingType == true) { - (pyExpression.indexExpression as? PyTypedElement)?.let { context.getType(it) } - } else { - (context.getType(pyExpression) as? PyClassLikeType)?.toInstance() } } - else -> (context.getType(pyExpression) as? PyClassLikeType)?.toInstance() } + return getPyType(pyExpression, context) } - private fun collectGenericTypes(pyClass: PyClass, context: TypeEvalContext): List { + + private fun collectGenericTypes(pyClass: PyClass, context: TypeEvalContext): List { return pyClass.superClassExpressions .mapNotNull { when (it) { @@ -154,9 +176,10 @@ class PydanticTypeProvider : PyTypeProviderBase() { is PyTupleExpression -> indexExpression.elements .map { context.getType(it) }.filterIsInstance().toList() is PyGenericType -> listOf(context.getType(indexExpression)) + is PyTypedElement -> (context.getType(indexExpression) as? PyGenericType)?.let { listOf(it) } else -> null - } - }.filterNotNull().distinct() + } ?: emptyList() + }.filterIsInstance().distinct() } override fun prepareCalleeTypeForCall( @@ -193,7 +216,9 @@ class PydanticTypeProvider : PyTypeProviderBase() { it.containingClass?.let { getPydanticTypeForClass(it, context, - pyCallExpression = pyCallExpression) + true, + pyCallExpression + ) } } } @@ -216,10 +241,12 @@ class PydanticTypeProvider : PyTypeProviderBase() { ?.filter { pyClassType -> pyClassType.isDefinition } ?.filterNot { pyClassType -> pyClassType is PydanticDynamicModelClassType } ?.map { filteredPyClassType -> - getPydanticTypeForClass(filteredPyClassType.pyClass, + getPydanticTypeForClass( + filteredPyClassType.pyClass, context, true, - pyCallExpression) + pyCallExpression + ) }?.firstOrNull() } ?: getPydanticDynamicModelTypeForTargetExpression(it, context)?.pyCallableType else -> null @@ -410,11 +437,19 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyClass: PyClass, context: TypeEvalContext, pyCallExpression: PyCallExpression? = null, - ): Map { - if (!PyTypingTypeProvider.isGeneric(pyClass, context)) return emptyMap() - if (!(isSubClassOfPydanticGenericModel(pyClass, context) && !isPydanticGenericModel(pyClass))) return emptyMap() - val pyClassGenericTypeMap = - pyTypingTypeProvider.getGenericSubstitutions(pyClass, context).filterValues { it is PyType } + ): Map? { + if (!PyTypingTypeProvider.isGeneric(pyClass, context)) return null + if (!(isSubClassOfPydanticGenericModel(pyClass, context) && !isPydanticGenericModel(pyClass))) return null + + // class Response(GenericModel, Generic[TypeA, TypeB]): pass + val pyClassGenericTypeMap = pyTypingTypeProvider.getGenericSubstitutions(pyClass, context) + .mapNotNull { (key, value) -> + if (key is PyGenericType && value is PyType) { + Pair(key, value) + } else null + }.toMap() + + // Response[TypeA] val pySubscriptionExpression = when (val firstChild = pyCallExpression?.firstChild) { is PySubscriptionExpression -> firstChild is PyReferenceExpression -> getResolvedPsiElements(firstChild, context) @@ -422,11 +457,13 @@ class PydanticTypeProvider : PyTypeProviderBase() { ?.let { it as? PyTargetExpression } ?.findAssignedValue() as? PySubscriptionExpression else -> null - } ?: return pyClassGenericTypeMap + } ?: return pyClassGenericTypeMap.takeIf { it.isNotEmpty() } + // Response[TypeA, TypeB]() val injectedTypes = (pySubscriptionExpression.indexExpression as? PyTupleExpression) ?.elements ?.map { getInjectedGenericType(it, context) } + // Response[TypeA]() ?: listOf((pySubscriptionExpression.indexExpression?.let { getInjectedGenericType(it, context) })) @@ -434,9 +471,9 @@ class PydanticTypeProvider : PyTypeProviderBase() { this.putAll(collectGenericTypes(pyClass, context) .take(injectedTypes.size) .mapIndexed { index, genericType -> genericType to injectedTypes[index] } - .filterIsInstance>().toMap() + .filterIsInstance>().toMap() ) - } + }.takeIf { it.isNotEmpty() } } fun getPydanticTypeForClass( @@ -502,14 +539,15 @@ class PydanticTypeProvider : PyTypeProviderBase() { pyClass: PyClass, pydanticVersion: KotlinVersion?, config: HashMap, - genericTypeMap: Map, + genericTypeMap: Map?, typed: Boolean = true, isDataclass: Boolean = false, ): PyCallableParameter? { if (!isValidField(field, context)) return null if (!hasAnnotationValue(field) && !field.hasAssignedValue()) return null // skip fields that are invalid syntax - val defaultValueFromField = getDefaultValueForParameter(field, ellipsis, context, pydanticVersion, isDataclass) + val defaultValueFromField = + getDefaultValueForParameter(field, ellipsis, context, pydanticVersion, isDataclass) val defaultValue = when { isSubClassOfBaseSetting(pyClass, context) -> ellipsis else -> defaultValueFromField @@ -523,7 +561,11 @@ class PydanticTypeProvider : PyTypeProviderBase() { // get type from annotation else -> getTypeForParameter(field, context) }?.let { - genericTypeMap[it] ?: it + if (genericTypeMap == null) { + it + } else { + PyTypeChecker.substitute(it, genericTypeMap, context) + } } return PyCallableParameterImpl.nonPsi( @@ -712,7 +754,10 @@ class PydanticTypeProvider : PyTypeProviderBase() { } } - private fun getDefaultValueForDataclass(assignedValue: PyCallExpression, context: TypeEvalContext): PyExpression? { + private fun getDefaultValueForDataclass( + assignedValue: PyCallExpression, + context: TypeEvalContext, + ): PyExpression? { val defaultValue = getDefaultValueForDataclass(assignedValue, context, "default") val defaultFactoryValue = getDefaultValueForDataclass(assignedValue, context, "default_factory") return when { diff --git a/testData/completionv18/genericField.py b/testData/completionv18/genericField.py new file mode 100644 index 00000000..89d1183e --- /dev/null +++ b/testData/completionv18/genericField.py @@ -0,0 +1,21 @@ +from typing import TypeVar, Type, List, Dict, Generic, Optional +from pydantic.generics import GenericModel + + +AT = TypeVar('AT') +BT = TypeVar('BT') +CT = TypeVar('CT') +DT = TypeVar('DT') +ET = TypeVar('ET') + + +class A(GenericModel, Generic[AT, BT, CT, DT]): + a: Type[AT] + b: List[BT] + c: Dict[CT, DT] + +class B(A[int, BT, CT, DT], Generic[BT, CT, DT, ET]): + hij: Optional[ET] + + +B[str, float, bytes, bool](). \ No newline at end of file diff --git a/testData/completionv18/genericKeywordArgument.py b/testData/completionv18/genericKeywordArgument.py new file mode 100644 index 00000000..615863fa --- /dev/null +++ b/testData/completionv18/genericKeywordArgument.py @@ -0,0 +1,21 @@ +from typing import TypeVar, Type, List, Dict, Generic, Optional +from pydantic.generics import GenericModel + + +AT = TypeVar('AT') +BT = TypeVar('BT') +CT = TypeVar('CT') +DT = TypeVar('DT') +ET = TypeVar('ET') + + +class A(GenericModel, Generic[AT, BT, CT, DT]): + a: Type[AT] + b: List[BT] + c: Dict[CT, DT] + +class B(A[int, BT, CT, DT], Generic[BT, CT, DT, ET]): + hij: Optional[ET] + + +B[str, float, bytes, bool]() \ No newline at end of file diff --git a/testData/mock/pydanticv18/generics.py b/testData/mock/pydanticv18/generics.py new file mode 100644 index 00000000..1e5eac45 --- /dev/null +++ b/testData/mock/pydanticv18/generics.py @@ -0,0 +1,38 @@ +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Generic, + Iterator, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + get_type_hints, +) + +from .main import BaseModel + +_generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {} +GenericModelT = TypeVar('GenericModelT', bound='GenericModel') +TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type + + +class GenericModel(BaseModel): + __slots__ = () + __concrete__: ClassVar[bool] = False + + if TYPE_CHECKING: + # Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with + # `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of + # `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below. + __parameters__: ClassVar[Tuple[TypeVarType, ...]] + + # Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings + def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]: + pass \ No newline at end of file diff --git a/testData/typeinspectionv18/genericModel.py b/testData/typeinspectionv18/genericModel.py new file mode 100644 index 00000000..fdd07866 --- /dev/null +++ b/testData/typeinspectionv18/genericModel.py @@ -0,0 +1,191 @@ +from typing import Generic, TypeVar, Optional, List, Tuple, Any, Dict, Type, Union + +from pydantic import BaseModel +from pydantic.generics import GenericModel + +DataT = TypeVar('DataT') + + +class Error(BaseModel): + code: int + message: str + + +class DataModel(BaseModel): + numbers: List[int] + people: List[str] + + +class Response(GenericModel, Generic[DataT]): + data: Optional[DataT] + error: Optional[Error] + + +Response[int](data=1, error=None) + +Response[str](data=1, error=None) + +TypeX = TypeVar('TypeX') + +class BaseClass(GenericModel, Generic[TypeX]): + X: TypeX + + +class ChildClass(BaseClass[TypeX], Generic[TypeX]): + # Inherit from Generic[TypeX] + pass + + +ChildClass[int](X=1) + +ChildClass[str](X=1) + + + +TypeX = TypeVar('TypeX') +TypeY = TypeVar('TypeY') +TypeZ = TypeVar('TypeZ') + + +class BaseClass(GenericModel, Generic[TypeX, TypeY]): + x: TypeX + y: TypeY + + +class ChildClass(BaseClass[int, TypeY], Generic[TypeY, TypeZ]): + z: TypeZ + + +# Replace TypeY by str +ChildClass[str, float](x=1, y='y', z=3.1) + +ChildClass[float, bytes](x=b'1', y='y', z=1_3) + +DataT = TypeVar('DataT') + + +class Response(GenericModel, Generic[DataT]): + data: DataT + + @classmethod + def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str: + return f'{params[0].__name__.title()}Response' + + +Response[int](data=1) + +Response[str](data='a') + + +Response[str](data=1) +Response[float](data='a') + + +T = TypeVar('T') + + +class InnerT(GenericModel, Generic[T]): + inner: T + + +class OuterT(GenericModel, Generic[T]): + outer: T + nested: InnerT[T] + + +nested = InnerT[int](inner=1) +OuterT[int](outer=1, nested=nested) + +nested = InnerT[str](inner='a') +OuterT[int](outer='a', nested=nested) + +AT = TypeVar('AT') +BT = TypeVar('BT') + + +class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + +Model(a='a', b='a') +#> a='a' b='a' + +IntT = TypeVar('IntT', bound=int) +typevar_model = Model[int, IntT] +typevar_model(a=1, b=1) + + +typevar_model(a='a', b='a') + +concrete_model = typevar_model[int] +concrete_model(a=1, b=1) + + +CT = TypeVar('CT') +DT = TypeVar('DT') +ET = TypeVar('ET') +FT = TypeVar('FT') + +class Model(GenericModel, Generic[CT, DT, ET, FT]): + a: Type[CT] + b: List[DT] + c: Dict[ET, FT] + +Model[int, int, str, int](a=int, b=[2], c={'c': 3}) + +Model[int, int, int, int](a=1, b=2, c=3) + +class Model(GenericModel, Generic[CT, DT]): + a: List[Type[CT]] + b: Union[List[DT], float] + +Model[int, int](a=[int], b=[2]) + +Model[int, int](a=1, b='2') + +class Model(GenericModel, Generic[CT, DT, ET, FT]): + a: CT + b: DT + c: ET + d: FT + +Model[Type[int], List[Type[int]], Optional[Type[int]], Tuple[Type[int]]](a=int, b=[int], c=int, d=(int,)) + +Model[Type[int], List[Type[int]], Optional[Type[int]], Tuple[Type[int]]](a=1, b=[2], c=[3], d=(4, )) + +class Model(GenericModel, Generic[CT, Broken]): + a: CT + b: Broken + +Model[int, int, int](a=1, b=2) + +Model[str, str, str](a=1, b=2) + +class Model(GenericModel, Generic[CT]): + a: CT + +Model[Union[int, float]](a=1) + +Model[Union[int, float]](a='1') + +class Model(GenericModel, Generic[CT, DT]): + a: CT + b: DT + +def x(b: ET) -> ET: + pass + +Model[x(int), Optional[x(int)]](a=1, b=2) + +Model[x(int), Optional[x(int)]](a='1', b='2') + +class Model(GenericModel, Generic[CT, DT]): + a: CT + b: DT + +y = int + +Model[y, Optional[y]](a=1, b=2) + +Model[y, Optional[y]](a='1', b='2') diff --git a/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt b/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt new file mode 100644 index 00000000..dce728cf --- /dev/null +++ b/testSrc/com/koxudaxi/pydantic/PydanticCompletionV18Test.kt @@ -0,0 +1,80 @@ +package com.koxudaxi.pydantic + +import com.intellij.codeInsight.lookup.LookupElementPresentation +import com.jetbrains.python.psi.PyTargetExpression + + +open class PydanticCompletionV18Test : PydanticTestCase(version = "v18") { + + + private fun doFieldTest(fieldNames: List>, additionalModules: List? = null) { + configureByFile(additionalModules) + val excludes = listOf( + "__annotations__", + "__base__", + "__bases__", + "__basicsize__", + "__dict__", + "__dictoffset__", + "__flags__", + "__itemsize__", + "__mro__", + "__name__", + "__qualname__", + "__slots__", + "__text_signature__", + "__weakrefoffset__", + "Ellipsis", + "EnvironmentError", + "IOError", + "NotImplemented", + "List", + "Type", + "Annotated", + "MISSING", + "WindowsError", + "__concrete__", + "__parameters__", + "___slots__", + "Generic", + "Dict", + "Optional", + ) + val actual = myFixture!!.completeBasic().filter { + it!!.psiElement is PyTargetExpression + }.filterNot { + excludes.contains(it!!.lookupString) + }.mapNotNull { + Pair(it!!.lookupString, LookupElementPresentation.renderElement(it).typeText ?: "null") + } + assertEquals(fieldNames, actual) + } + + fun testGenericField() { + doFieldTest( + listOf( + Pair("a", "Type[int] A"), + Pair("b", "List[str] A"), + Pair("c", "Dict[float, bytes] A"), + Pair("hij", "Optional[bool]=None B"), + ) + ) + } + + + fun testGenericKeywordArgument() { + doFieldTest( + listOf( + Pair("a=", "Type[int] A"), + Pair("b=", "List[str] A"), + Pair("c=", "Dict[float, bytes] A"), + Pair("hij=", "Optional[bool]=None B"), + Pair("AT", "null"), + Pair("BT", "null"), + Pair("CT", "null"), + Pair("DT", "null"), + Pair("ET", "null")) + ) + } + +} diff --git a/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt b/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt index e4e9467b..16e28232 100644 --- a/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt +++ b/testSrc/com/koxudaxi/pydantic/PydanticTypeInspectionV18Test.kt @@ -20,4 +20,8 @@ open class PydanticTypeInspectionV18Test : PydanticInspectionBase("v18") { fun testBaseSetting() { doTest() } + + fun testGenericModel() { + doTest() + } } \ No newline at end of file