Skip to content

Commit

Permalink
Support SQLModel
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi committed Mar 10, 2022
1 parent 9e32907 commit 35937ba
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 8 deletions.
37 changes: 30 additions & 7 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const val ROOT_VALIDATOR_SHORT_Q_NAME = "pydantic.root_validator"
const val SCHEMA_Q_NAME = "pydantic.schema.Schema"
const val FIELD_Q_NAME = "pydantic.fields.Field"
const val DATACLASS_FIELD_Q_NAME = "dataclasses.field"
const val SQL_MODEL_FIELD_Q_NAME = "sqlmodel.main.Field"
const val DEPRECATED_SCHEMA_Q_NAME = "pydantic.fields.Schema"
const val BASE_SETTINGS_Q_NAME = "pydantic.env_settings.BaseSettings"
const val VERSION_Q_NAME = "pydantic.version.VERSION"
Expand All @@ -58,6 +59,15 @@ const val GENERIC_Q_NAME = "typing.Generic"
const val TYPE_Q_NAME = "typing.Type"
const val TUPLE_Q_NAME = "typing.Tuple"

const val SQL_MODEL_Q_NAME = "sqlmodel.main.SQLModel"

val CUSTOM_BASE_MODEL_Q_NAMES = listOf(
SQL_MODEL_Q_NAME
)

val CUSTOM_MODEL_FIELD_Q_NAMES = listOf(
SQL_MODEL_FIELD_Q_NAME
)
val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME)

val BASE_CONFIG_QUALIFIED_NAME = QualifiedName.fromDottedString(BASE_CONFIG_Q_NAME)
Expand All @@ -76,6 +86,8 @@ val DATA_CLASS_QUALIFIED_NAME = QualifiedName.fromDottedString(DATA_CLASS_Q_NAME

val DATA_CLASS_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(DATA_CLASS_SHORT_Q_NAME)

val SQL_MODEL_QUALIFIED_NAME = QualifiedName.fromDottedString(SQL_MODEL_Q_NAME)

val DATA_CLASS_QUALIFIED_NAMES = listOf(
DATA_CLASS_QUALIFIED_NAME,
DATA_CLASS_SHORT_QUALIFIED_NAME
Expand Down Expand Up @@ -147,14 +159,16 @@ fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context:
}

fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): Boolean {
return (isSubClassOfPydanticBaseModel(pyClass,
context) || isSubClassOfPydanticGenericModel(pyClass,
context) || (includeDataclass && pyClass.isPydanticDataclass)) && !pyClass.isPydanticBaseModel
&& !pyClass.isPydanticGenericModel && !pyClass.isBaseSettings
return ((isSubClassOfPydanticBaseModel(pyClass,
context) && !pyClass.isPydanticCustomBaseModel) || isSubClassOfPydanticGenericModel(pyClass,
context) || (includeDataclass && pyClass.isPydanticDataclass) || isSubClassOfCustomBaseModel(pyClass,
context)) && !pyClass.isPydanticBaseModel
&& !pyClass.isPydanticGenericModel && !pyClass.isBaseSettings && !pyClass.isPydanticCustomBaseModel
}

val PyClass.isPydanticBaseModel: Boolean get() = qualifiedName == BASE_MODEL_Q_NAME

val PyClass.isPydanticCustomBaseModel: Boolean get() = qualifiedName in CUSTOM_BASE_MODEL_Q_NAMES

val PyClass.isPydanticGenericModel: Boolean get() = qualifiedName == GENERIC_MODEL_Q_NAME

Expand All @@ -171,6 +185,10 @@ internal fun isSubClassOfBaseSetting(pyClass: PyClass, context: TypeEvalContext)
return pyClass.isSubclass(BASE_SETTINGS_Q_NAME, context)
}

internal fun isSubClassOfCustomBaseModel(pyClass: PyClass, context: TypeEvalContext): Boolean {
return CUSTOM_BASE_MODEL_Q_NAMES.any { pyClass.isSubclass(it, context) }
}

internal val PyClass.isBaseSettings: Boolean get() = qualifiedName == BASE_SETTINGS_Q_NAME


Expand All @@ -191,9 +209,9 @@ internal fun isPydanticSchema(pyClass: PyClass, context: TypeEvalContext): Boole

internal val PyFunction.isPydanticField: Boolean get() = qualifiedName == FIELD_Q_NAME || qualifiedName == DEPRECATED_SCHEMA_Q_NAME


internal val PyFunction.isDataclassField: Boolean get() = qualifiedName == DATACLASS_FIELD_Q_NAME

internal val PyFunction.isCustomModelField: Boolean get() = qualifiedName in CUSTOM_MODEL_FIELD_Q_NAMES

internal val PyFunction.isPydanticCreateModel: Boolean get() = qualifiedName == CREATE_MODEL

Expand All @@ -217,7 +235,7 @@ internal fun isPydanticRegex(stringLiteralExpression: StringLiteralExpression):
val referenceExpression = pyCallExpression.callee as? PyReferenceExpression ?: return false
val context = TypeEvalContext.userInitiated(referenceExpression.project, referenceExpression.containingFile)
return getResolvedPsiElements(referenceExpression, context)
.filterIsInstance<PyFunction>().any { pyFunction -> pyFunction.isPydanticField || pyFunction.isConStr }
.filterIsInstance<PyFunction>().any { pyFunction -> pyFunction.isPydanticField || pyFunction.isConStr || pyFunction.isCustomModelField }
}

internal fun getClassVariables(pyClass: PyClass, context: TypeEvalContext): Sequence<PyTargetExpression> {
Expand Down Expand Up @@ -292,6 +310,11 @@ val PsiElement.isDataclassField: Boolean
pyFunction.isDataclassField
}

val PsiElement.isCustomModelField: Boolean
get() = validatePsiElementByFunction(this) { pyFunction: PyFunction ->
pyFunction.isCustomModelField
}

val PsiElement.isDataclassMissing: Boolean get() = validatePsiElementByFunction(this, ::isDataclassMissing)

val Project.sdk: Sdk? get() = pythonSdk ?: modules.mapNotNull { PythonSdkUtil.findPythonSdk(it) }.firstOrNull()
Expand Down Expand Up @@ -559,7 +582,7 @@ internal fun getFieldFromPyExpression(
if (!getResolvedPsiElements(callee, context).any {
when {
versionZero -> isPydanticSchemaByPsiElement(it, context)
else -> it.isPydanticField
else -> it.isPydanticField || it.isCustomModelField
}
}) return null
return psiElement
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
.any {
when {
versionZero -> isPydanticSchemaByPsiElement(it, context)
else -> it.isPydanticField
else -> it.isPydanticField || it.isCustomModelField
}

}
Expand Down
3 changes: 3 additions & 0 deletions testData/mock/stub/sqlmodel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .main import SQLModel as SQLModel
from .main import Field as Field
from .main import Relationship as Relationship
106 changes: 106 additions & 0 deletions testData/mock/stub/sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import *

def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
...

def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]:
return lambda a: a


class FieldInfo(PydanticFieldInfo):
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
...

@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
__config__: Type[BaseConfig]
__fields__: Dict[str, ModelField]

# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__setattr__(cls, name, value)
else:
super().__setattr__(name, value)

def __delattr__(cls, name: str) -> None:
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__delattr__(cls, name)
else:
super().__delattr__(name)

# From Pydantic
def __new__(
cls,
name: str,
bases: Tuple[Type[Any], ...],
class_dict: Dict[str, Any],
**kwargs: Any,
) -> Any:
...

def __init__(
cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
) -> None:
...


_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")


class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
__name__: ClassVar[str]
metadata: ClassVar[MetaData]

class Config:
orm_mode = True

def __new__(cls, *args: Any, **kwargs: Any) -> Any:
...

def __init__(__pydantic_self__, **data: Any) -> None:
...
18 changes: 18 additions & 0 deletions testData/typeinspectionv18/sqlModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import *

from sqlmodel import Field, SQLModel


class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name = Field(default="dummy", primary_key=True)
age: Optional[int] = None

hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador")
hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)

hero_4 = Hero(secret_name="test", <warning descr="null">)</warning>

hero_5 = Hero(<warning descr="Expected type 'str', got 'int' instead">name=123</warning>, <warning descr="Expected type 'str', got 'int' instead">secret_name=456</warning>, <warning descr="Expected type 'Optional[int]', got 'str' instead">age="abc"</warning>)
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,8 @@ open class PydanticTypeInspectionV18Test : PydanticInspectionBase("v18") {
fun testGenericModel() {
doTest()
}

fun testSqlModel() {
doTest()
}
}

0 comments on commit 35937ba

Please sign in to comment.