Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Improve field type translation #44

Merged
merged 2 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions ormdantic/generator/_table.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Module providing PydanticSQLTableGenerator."""
import uuid
from datetime import date, datetime
from types import UnionType
from typing import Any, get_args, get_origin

from pydantic import BaseModel, ConstrainedStr
from pydantic import BaseModel
from pydantic.fields import ModelField
from sqlalchemy import (
JSON,
Boolean,
Column,
Date,
DateTime,
Float,
ForeignKey,
Integer,
Expand Down Expand Up @@ -87,19 +91,28 @@ def _get_column(
raise TypeConversionError(field.type_) # pragma: no cover
if get_origin(field.outer_type_) == dict:
return Column(field_name, JSON, **kwargs)
if issubclass(field.type_, BaseModel):
return Column(field_name, JSON, **kwargs)
if field.type_ is uuid.UUID:
col_type = (
postgresql.UUID if self._engine.name == "postgres" else String(36)
)
return Column(field_name, col_type, **kwargs)
if field.type_ is str or issubclass(field.type_, ConstrainedStr):
if issubclass(field.type_, BaseModel):
return Column(field_name, JSON, **kwargs)
if issubclass(field.type_, str):
return Column(field_name, String(field.field_info.max_length), **kwargs)
if field.type_ is int:
return Column(field_name, Integer, **kwargs)
if field.type_ is float:
if issubclass(field.type_, float):
return Column(field_name, Float, **kwargs)
if issubclass(field.type_, int):
# bool is a subclass of int -> nested check
if issubclass(field.type_, bool):
return Column(field_name, Boolean, **kwargs)
return Column(field_name, Integer, **kwargs)
if issubclass(field.type_, date):
# datetime is a subclass of date -> nested check
if issubclass(field.type_, datetime):
return Column(field_name, DateTime, **kwargs)
return Column(field_name, Date, **kwargs)

# Catchall for dict/list or any other.
return Column(field_name, JSON, **kwargs)

Expand Down
35 changes: 35 additions & 0 deletions tests/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import unittest
from datetime import date, datetime
from typing import Any
from uuid import UUID, uuid4

Expand Down Expand Up @@ -37,6 +38,10 @@ class Flavor(BaseModel):
name: str = Field(..., max_length=63)
strength: int | None = None
coffee: Coffee | UUID | None = None
created_at: date = Field(default_factory=date.today)
updated_at: date = Field(default_factory=date.today)
expire: datetime = Field(default_factory=datetime.now)
exist: bool = False


@database.table(pk="id")
Expand All @@ -52,6 +57,7 @@ class Coffee(BaseModel):
ice: list # type: ignore
size: Money
attributes: dict[str, Any] | None = None
exist: bool = False


@database.table(pk="id")
Expand Down Expand Up @@ -98,6 +104,24 @@ async def test_insert_and_find_one(self) -> None:
mocha.dict(), (await database[Flavor].find_one(mocha.id)).dict() # type: ignore
)

async def test_insert_and_find_one_date(self) -> None:
# Test Date and Time fields
flavor = Flavor(name="mocha", created_at=date(2021, 1, 1))
mocha = await database[Flavor].insert(flavor)
# Find new record and compare.
self.assertDictEqual(
mocha.dict(), (await database[Flavor].find_one(mocha.id)).dict() # type: ignore
)

async def test_insert_and_find_one_bool(self) -> None:
# Insert record.
flavor = Flavor(name="mocha", exist=True)
mocha = await database[Flavor].insert(flavor)
# Find new record and compare.
self.assertDictEqual(
mocha.dict(), (await database[Flavor].find_one(mocha.id)).dict() # type: ignore
)

async def test_find_many(self) -> None:
# Insert 3 records.
mocha1 = await database[Flavor].insert(Flavor(name="mocha"))
Expand Down Expand Up @@ -139,6 +163,17 @@ async def test_update(self) -> None:
# Find the updated record.
self.assertEqual(flavor.name, (await database[Flavor].find_one(flavor.id)).name) # type: ignore

async def test_update_datetime(self) -> None:
# Insert record.
flavor = await database[Flavor].insert(
Flavor(name="mocha", expire=datetime(2021, 1, 1, 1, 1, 1))
)
# Update record.
flavor.expire = datetime(2021, 1, 1, 1, 1, 2)
await database[Flavor].update(flavor)
# Find the updated record.
self.assertEqual(flavor.expire, (await database[Flavor].find_one(flavor.id)).expire) # type: ignore

async def test_upsert(self) -> None:
# Upsert record as insert.
flavor = await database[Flavor].upsert(Flavor(name="vanilla"))
Expand Down