Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
Merge pull request #811 from Avaiga/feature/use-sqlite3-for-sql-repo
Browse files Browse the repository at this point in the history
Refactor - SQLRepository uses sqlite3 instead of sqlalchemy
  • Loading branch information
toan-quach authored Nov 17, 2023
2 parents 405568c + 9b4ddf2 commit 89b0202
Show file tree
Hide file tree
Showing 17 changed files with 343 additions and 190 deletions.
8 changes: 4 additions & 4 deletions src/taipy/core/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Core:
Core service
"""

__is_running = False
_is_running = False
__lock_is_running = Lock()

__logger = _TaipyLogger._get_logger()
Expand All @@ -51,11 +51,11 @@ def run(self, force_restart=False):
This function checks the configuration, manages application's version,
and starts a dispatcher and lock the Config.
"""
if self.__class__.__is_running:
if self.__class__._is_running:
raise CoreServiceIsAlreadyRunning

with self.__class__.__lock_is_running:
self.__class__.__is_running = True
self.__class__._is_running = True

self.__update_core_section()
self.__manage_version()
Expand All @@ -79,7 +79,7 @@ def stop(self):
self.__logger.info("Core service has been stopped.")

with self.__class__.__lock_is_running:
self.__class__.__is_running = False
self.__class__._is_running = False

@staticmethod
def __update_core_section():
Expand Down
25 changes: 25 additions & 0 deletions src/taipy/core/_repository/_base_taipy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@

import dataclasses
import enum
import json
from typing import Any, Dict

from sqlalchemy import Table

from ._decoder import _Decoder
from ._encoder import _Encoder


class _BaseModel:
__table__: Table

def __iter__(self):
for attr, value in self.__dict__.items():
yield attr, value
Expand All @@ -26,3 +34,20 @@ def to_dict(self) -> Dict[str, Any]:
if isinstance(v, enum.Enum):
model_dict[k] = repr(v)
return model_dict

@staticmethod
def _serialize_attribute(value):
return json.dumps(value, ensure_ascii=False, cls=_Encoder)

@staticmethod
def _deserialize_attribute(value):
if isinstance(value, str):
return json.loads(value.replace("'", '"'), cls=_Decoder)
return value

@staticmethod
def from_dict(data: Dict[str, Any]):
pass

def to_list(self):
pass
4 changes: 1 addition & 3 deletions src/taipy/core/_repository/_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from enum import Enum
from typing import Any

from ..common.typing import Json


class _Encoder(json.JSONEncoder):
def _timedelta_to_str(self, obj: timedelta) -> str:
Expand All @@ -27,7 +25,7 @@ def _timedelta_to_str(self, obj: timedelta) -> str:
f"{int(total_seconds % 60)}s"
)

def default(self, o: Any) -> Json:
def default(self, o: Any):
if isinstance(o, Enum):
result = o.value
elif isinstance(o, datetime):
Expand Down
114 changes: 76 additions & 38 deletions src/taipy/core/_repository/_sql_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@
import pathlib
from typing import Any, Dict, Iterable, List, Optional, Type, Union

from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import NoResultFound

from .._repository._abstract_repository import _AbstractRepository
from ..common.typing import Converter, Entity, ModelType
from ..exceptions import ModelNotFound
from ._abstract_repository import _AbstractRepository
from .db._sql_session import _SQLSession
from .db._sql_connection import _SQLConnection


class _SQLRepository(_AbstractRepository[ModelType, Entity]):
def __init__(self, model_type: Type[ModelType], converter: Type[Converter], session=None):
def __init__(self, model_type: Type[ModelType], converter: Type[Converter]):
"""
Holds common methods to be used and extended when the need for saving
dataclasses in a SqlLite database.
Expand All @@ -33,68 +34,85 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter], sess
Attributes:
model_type: Generic dataclass.
converter: A class that handles conversion to and from a database backend
db: An SQLAlchemy session object
db: An sqlite3 session object
"""
SessionLocal = _SQLSession.init_db()
self.db = session or SessionLocal()
self.db = _SQLConnection.init_db()
self.model_type = model_type
self.converter = converter
self.table = self.model_type.__table__

###############################
# ## Inherited methods ## #
###############################
def _save(self, entity: Entity):
obj = self.converter._entity_to_model(entity)
if self.db.query(self.model_type).filter_by(id=obj.id).first():
self.__update_entry(obj)
if self._exists(entity.id): # type: ignore
self._update_entry(obj)
return
self.__insert_model(obj)

def _exists(self, entity_id: str):
return bool(self.db.query(self.model_type.id).filter_by(id=entity_id).first()) # type: ignore
query = self.table.select().filter_by(id=entity_id)
return bool(self.db.execute(str(query), [entity_id]).fetchone())

def _load(self, entity_id: str) -> Entity:
if entry := self.db.query(self.model_type).filter(self.model_type.id == entity_id).first(): # type: ignore
query = self.table.select().filter_by(id=entity_id)

if entry := self.db.execute(str(query.compile(dialect=sqlite.dialect())), [entity_id]).fetchone():
entry = self.model_type.from_dict(entry)
return self.converter._model_to_entity(entry)
raise ModelNotFound(str(self.model_type.__name__), entity_id)

def _load_all(self, filters: Optional[List[Dict]] = None) -> List[Entity]:
query = self.db.query(self.model_type)
query = self.table.select()
entities: List[Entity] = []

for f in filters or [{}]:
filtered_query = query.filter_by(**f)
try:
entities.extend([self.converter._model_to_entity(m) for m in filtered_query.all()])
entries = self.db.execute(
str(filtered_query.compile(dialect=sqlite.dialect())),
[self.__serialize_filter_values(val) for val in list(f.values())],
).fetchall()

entities.extend([self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries])
except NoResultFound:
continue
return entities

def _delete(self, entity_id: str):
number_of_deleted_entries = self.db.query(self.model_type).filter_by(id=entity_id).delete()
if not number_of_deleted_entries:
delete_query = self.table.delete().filter_by(id=entity_id)
cursor = self.db.execute(str(delete_query.compile(dialect=sqlite.dialect())), [entity_id])

if cursor.rowcount == 0:
raise ModelNotFound(str(self.model_type.__name__), entity_id)

self.db.commit()

def _delete_all(self):
self.db.query(self.model_type).delete()
self.db.execute(str(self.table.delete().compile(dialect=sqlite.dialect())))
self.db.commit()

def _delete_many(self, ids: Iterable[str]):
for entity_id in ids:
self._delete(entity_id)

def _delete_by(self, attribute: str, value: str):
self.db.query(self.model_type).filter_by(**{attribute: value}).delete()
delete_by_query = self.table.delete().filter_by(**{attribute: value})

self.db.execute(str(delete_by_query.compile(dialect=sqlite.dialect())), [value])
self.db.commit()

def _search(self, attribute: str, value: Any, filters: Optional[List[Dict]] = None) -> List[Entity]:
query = self.db.query(self.model_type).filter_by(**{attribute: value})
query = self.table.select().filter_by(**{attribute: value})

entities: List[Entity] = []
for f in filters or [{}]:
filtered_query = query.filter_by(**f)
entities.extend([self.converter._model_to_entity(m) for m in filtered_query.all()])
entries = self.db.execute(
str(query.filter_by(**f).compile(dialect=sqlite.dialect())),
[value] + [self.__serialize_filter_values(val) for val in list(f.values())],
).fetchall()
entities.extend([self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries])

return entities

Expand All @@ -104,27 +122,30 @@ def _export(self, entity_id: str, folder_path: Union[str, pathlib.Path]):
else:
folder = folder_path

export_dir = folder / self.model_type.__table__.name # type: ignore
export_dir = folder / self.table.name
if not export_dir.exists():
export_dir.mkdir(parents=True)

export_path = export_dir / f"{entity_id}.json"

entry = self.db.query(self.model_type).filter_by(id=entity_id).first()
if entry is None:
raise ModelNotFound(self.model_type, entity_id) # type: ignore
query = self.table.select().filter_by(id=entity_id)

with open(export_path, "w", encoding="utf-8") as export_file:
export_file.write(json.dumps(entry.to_dict()))
if entry := self.db.execute(str(query.compile(dialect=sqlite.dialect())), [entity_id]).fetchone():
with open(export_path, "w", encoding="utf-8") as export_file:
export_file.write(json.dumps(entry))
else:
raise ModelNotFound(self.model_type, entity_id) # type: ignore

###########################################
# ## Specific or optimized methods ## #
###########################################
def _get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
return self.db.query(self.model_type).offset(skip).limit(limit).all()
query = self.table.select().offset(skip).limit(limit)
return self.db.execute(str(query.compile(dialect=sqlite.dialect()))).fetchall()

def _get_by_config(self, config_id: Any) -> Optional[ModelType]:
return self.db.query(self.model_type).filter(self.model_type.config_id == config_id).first() # type: ignore
query = self.table.select().filter_by(config_id=config_id)
return self.db.execute(str(query.compile(dialect=sqlite.dialect())), [config_id]).fetchall()

def _get_by_config_and_owner_id(
self, config_id: str, owner_id: Optional[str], filters: Optional[List[Dict]] = None
Expand Down Expand Up @@ -153,27 +174,44 @@ def _get_by_configs_and_owner_ids(self, configs_and_owner_ids, filters: Optional
return res

def __get_entities_by_config_and_owner(
self, config_id: str, owner_id: Optional[str] = "", filters: Optional[List[Dict]] = None
) -> ModelType:
self, config_id: str, owner_id: Optional[str] = None, filters: Optional[List[Dict]] = None
) -> Optional[ModelType]:
if not filters:
filters = []
versions = [item.get("version") for item in filters if item.get("version")]

query = self.table.select().filter_by(config_id=config_id)
parameters: List = [config_id]

if owner_id:
query = self.db.query(self.model_type).filter_by(config_id=config_id).filter_by(owner_id=owner_id)
else:
query = self.db.query(self.model_type).filter_by(config_id=config_id).filter_by(owner_id=None)
parameters.append(owner_id)
query = query.filter_by(owner_id=owner_id)
query = str(query.compile(dialect=sqlite.dialect()))

if versions:
query = query.filter(self.model_type.version.in_(versions)) # type: ignore
return query.first()
table_name = self.table.name
query = query + f" AND {table_name}.version IN ({','.join(['?']*len(versions))})"
parameters.extend(versions)

if entry := self.db.execute(query, parameters).fetchone():
return self.model_type.from_dict(entry)
return None

#############################
# ## Private methods ## #
#############################
def __insert_model(self, model: ModelType):
self.db.add(model)
query = self.table.insert()
self.db.execute(str(query.compile(dialect=sqlite.dialect())), model.to_list())
self.db.commit()
self.db.refresh(model)

def __update_entry(self, model):
self.db.merge(model)
def _update_entry(self, model):
query = self.table.update().filter_by(id=model.id)
self.db.execute(str(query.compile(dialect=sqlite.dialect())), model.to_list() + [model.id])
self.db.commit()

@staticmethod
def __serialize_filter_values(value):
if isinstance(value, (dict, list)):
return json.dumps(value).replace('"', "'")
return value
86 changes: 86 additions & 0 deletions src/taipy/core/_repository/db/_sql_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2023 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import sqlite3
from functools import lru_cache
from sqlite3 import Connection

from sqlalchemy.dialects import sqlite
from sqlalchemy.schema import CreateTable

from taipy.config.config import Config

from ...exceptions import MissingRequiredProperty


def dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d


class _SQLConnection:
_connection = None

@classmethod
def init_db(cls):
if cls._connection:
return cls._connection

cls._connection = _build_connection()
cls._connection.row_factory = dict_factory

from ..._version._version_model import _VersionModel
from ...cycle._cycle_model import _CycleModel
from ...data._data_model import _DataNodeModel
from ...job._job_model import _JobModel
from ...scenario._scenario_model import _ScenarioModel
from ...task._task_model import _TaskModel

cls._connection.execute(
str(CreateTable(_CycleModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)
cls._connection.execute(
str(CreateTable(_DataNodeModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)
cls._connection.execute(
str(CreateTable(_JobModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)
cls._connection.execute(
str(CreateTable(_ScenarioModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)
cls._connection.execute(
str(CreateTable(_TaskModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)
cls._connection.execute(
str(CreateTable(_VersionModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)

return cls._connection


def _build_connection() -> Connection:
# Set SQLite threading mode to Serialized, means that threads may share the module, connections and cursors
sqlite3.threadsafety = 3

properties = Config.core.repository_properties
try:
db_location = properties["db_location"]
except KeyError:
raise MissingRequiredProperty("Missing property db_location")

return __build_connection(db_location)


@lru_cache
def __build_connection(db_location: str):
return sqlite3.connect(db_location, check_same_thread=False)
Loading

0 comments on commit 89b0202

Please sign in to comment.