Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SQLModel #450

Merged
merged 2 commits into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Changelog

## [Unreleased]
### Features
- Support SQLModel [[#450](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/450)]

## [0.3.11]
### Features
- Support IntelliJ IDEA 2022.1 [[#436](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/436)]
## 0.3.11
### Features
- Support IntelliJ IDEA 2022.1 [[#436](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull/436)]

### BugFixes
### BugFixes
- Fix Null Pointer Exception in PydanticTypeCheckerInspection [[#431](https://github.com/koxudaxi/pydantic-pycharm-plugin/pull//431)]

## 0.3.10
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
pluginGroup = com.koxudaxi.pydantic
pluginName = Pydantic
# SemVer format -> https://semver.org
pluginVersion = 0.3.11
pluginVersion = 0.3.12

# See https://plugins.jetbrains.com/docs/intellij/build-number-ranges.html
# for insight into build numbers and IntelliJ Platform versions.
Expand Down
37 changes: 30 additions & 7 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const val ROOT_VALIDATOR_SHORT_Q_NAME = "pydantic.root_validator"
const val SCHEMA_Q_NAME = "pydantic.schema.Schema"
const val FIELD_Q_NAME = "pydantic.fields.Field"
const val DATACLASS_FIELD_Q_NAME = "dataclasses.field"
const val SQL_MODEL_FIELD_Q_NAME = "sqlmodel.main.Field"
const val DEPRECATED_SCHEMA_Q_NAME = "pydantic.fields.Schema"
const val BASE_SETTINGS_Q_NAME = "pydantic.env_settings.BaseSettings"
const val VERSION_Q_NAME = "pydantic.version.VERSION"
Expand All @@ -58,6 +59,15 @@ const val GENERIC_Q_NAME = "typing.Generic"
const val TYPE_Q_NAME = "typing.Type"
const val TUPLE_Q_NAME = "typing.Tuple"

const val SQL_MODEL_Q_NAME = "sqlmodel.main.SQLModel"

val CUSTOM_BASE_MODEL_Q_NAMES = listOf(
SQL_MODEL_Q_NAME
)

val CUSTOM_MODEL_FIELD_Q_NAMES = listOf(
SQL_MODEL_FIELD_Q_NAME
)
val VERSION_QUALIFIED_NAME = QualifiedName.fromDottedString(VERSION_Q_NAME)

val BASE_CONFIG_QUALIFIED_NAME = QualifiedName.fromDottedString(BASE_CONFIG_Q_NAME)
Expand All @@ -76,6 +86,8 @@ val DATA_CLASS_QUALIFIED_NAME = QualifiedName.fromDottedString(DATA_CLASS_Q_NAME

val DATA_CLASS_SHORT_QUALIFIED_NAME = QualifiedName.fromDottedString(DATA_CLASS_SHORT_Q_NAME)

val SQL_MODEL_QUALIFIED_NAME = QualifiedName.fromDottedString(SQL_MODEL_Q_NAME)

val DATA_CLASS_QUALIFIED_NAMES = listOf(
DATA_CLASS_QUALIFIED_NAME,
DATA_CLASS_SHORT_QUALIFIED_NAME
Expand Down Expand Up @@ -147,14 +159,16 @@ fun getPyClassByPyKeywordArgument(pyKeywordArgument: PyKeywordArgument, context:
}

fun isPydanticModel(pyClass: PyClass, includeDataclass: Boolean, context: TypeEvalContext): Boolean {
return (isSubClassOfPydanticBaseModel(pyClass,
context) || isSubClassOfPydanticGenericModel(pyClass,
context) || (includeDataclass && pyClass.isPydanticDataclass)) && !pyClass.isPydanticBaseModel
&& !pyClass.isPydanticGenericModel && !pyClass.isBaseSettings
return ((isSubClassOfPydanticBaseModel(pyClass,
context) && !pyClass.isPydanticCustomBaseModel) || isSubClassOfPydanticGenericModel(pyClass,
context) || (includeDataclass && pyClass.isPydanticDataclass) || isSubClassOfCustomBaseModel(pyClass,
context)) && !pyClass.isPydanticBaseModel
&& !pyClass.isPydanticGenericModel && !pyClass.isBaseSettings && !pyClass.isPydanticCustomBaseModel
}

val PyClass.isPydanticBaseModel: Boolean get() = qualifiedName == BASE_MODEL_Q_NAME

val PyClass.isPydanticCustomBaseModel: Boolean get() = qualifiedName in CUSTOM_BASE_MODEL_Q_NAMES

val PyClass.isPydanticGenericModel: Boolean get() = qualifiedName == GENERIC_MODEL_Q_NAME

Expand All @@ -171,6 +185,10 @@ internal fun isSubClassOfBaseSetting(pyClass: PyClass, context: TypeEvalContext)
return pyClass.isSubclass(BASE_SETTINGS_Q_NAME, context)
}

internal fun isSubClassOfCustomBaseModel(pyClass: PyClass, context: TypeEvalContext): Boolean {
return CUSTOM_BASE_MODEL_Q_NAMES.any { pyClass.isSubclass(it, context) }
}

internal val PyClass.isBaseSettings: Boolean get() = qualifiedName == BASE_SETTINGS_Q_NAME


Expand All @@ -191,9 +209,9 @@ internal fun isPydanticSchema(pyClass: PyClass, context: TypeEvalContext): Boole

internal val PyFunction.isPydanticField: Boolean get() = qualifiedName == FIELD_Q_NAME || qualifiedName == DEPRECATED_SCHEMA_Q_NAME


internal val PyFunction.isDataclassField: Boolean get() = qualifiedName == DATACLASS_FIELD_Q_NAME

internal val PyFunction.isCustomModelField: Boolean get() = qualifiedName in CUSTOM_MODEL_FIELD_Q_NAMES

internal val PyFunction.isPydanticCreateModel: Boolean get() = qualifiedName == CREATE_MODEL

Expand All @@ -217,7 +235,7 @@ internal fun isPydanticRegex(stringLiteralExpression: StringLiteralExpression):
val referenceExpression = pyCallExpression.callee as? PyReferenceExpression ?: return false
val context = TypeEvalContext.userInitiated(referenceExpression.project, referenceExpression.containingFile)
return getResolvedPsiElements(referenceExpression, context)
.filterIsInstance<PyFunction>().any { pyFunction -> pyFunction.isPydanticField || pyFunction.isConStr }
.filterIsInstance<PyFunction>().any { pyFunction -> pyFunction.isPydanticField || pyFunction.isConStr || pyFunction.isCustomModelField }
}

internal fun getClassVariables(pyClass: PyClass, context: TypeEvalContext): Sequence<PyTargetExpression> {
Expand Down Expand Up @@ -292,6 +310,11 @@ val PsiElement.isDataclassField: Boolean
pyFunction.isDataclassField
}

val PsiElement.isCustomModelField: Boolean
get() = validatePsiElementByFunction(this) { pyFunction: PyFunction ->
pyFunction.isCustomModelField
}

val PsiElement.isDataclassMissing: Boolean get() = validatePsiElementByFunction(this, ::isDataclassMissing)

val Project.sdk: Sdk? get() = pythonSdk ?: modules.mapNotNull { PythonSdkUtil.findPythonSdk(it) }.firstOrNull()
Expand Down Expand Up @@ -559,7 +582,7 @@ internal fun getFieldFromPyExpression(
if (!getResolvedPsiElements(callee, context).any {
when {
versionZero -> isPydanticSchemaByPsiElement(it, context)
else -> it.isPydanticField
else -> it.isPydanticField || it.isCustomModelField
}
}) return null
return psiElement
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
.any {
when {
versionZero -> isPydanticSchemaByPsiElement(it, context)
else -> it.isPydanticField
else -> it.isPydanticField || it.isCustomModelField
}

}
Expand Down
3 changes: 3 additions & 0 deletions testData/mock/stub/sqlmodel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .main import SQLModel as SQLModel
from .main import Field as Field
from .main import Relationship as Relationship
106 changes: 106 additions & 0 deletions testData/mock/stub/sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import *

def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
...

def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]:
return lambda a: a


class FieldInfo(PydanticFieldInfo):
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
...

@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
__config__: Type[BaseConfig]
__fields__: Dict[str, ModelField]

# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__setattr__(cls, name, value)
else:
super().__setattr__(name, value)

def __delattr__(cls, name: str) -> None:
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__delattr__(cls, name)
else:
super().__delattr__(name)

# From Pydantic
def __new__(
cls,
name: str,
bases: Tuple[Type[Any], ...],
class_dict: Dict[str, Any],
**kwargs: Any,
) -> Any:
...

def __init__(
cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
) -> None:
...


_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")


class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
__name__: ClassVar[str]
metadata: ClassVar[MetaData]

class Config:
orm_mode = True

def __new__(cls, *args: Any, **kwargs: Any) -> Any:
...

def __init__(__pydantic_self__, **data: Any) -> None:
...
18 changes: 18 additions & 0 deletions testData/typeinspectionv18/sqlModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import *

from sqlmodel import Field, SQLModel


class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name = Field(default="dummy", primary_key=True)
age: Optional[int] = None

hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador")
hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)

hero_4 = Hero(secret_name="test", <warning descr="null">)</warning>

hero_5 = Hero(<warning descr="Expected type 'str', got 'int' instead">name=123</warning>, <warning descr="Expected type 'str', got 'int' instead">secret_name=456</warning>, <warning descr="Expected type 'Optional[int]', got 'str' instead">age="abc"</warning>)
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,8 @@ open class PydanticTypeInspectionV18Test : PydanticInspectionBase("v18") {
fun testGenericModel() {
doTest()
}

fun testSqlModel() {
doTest()
}
}