From 9eb34addc33aa763f6254cb551eed7ba08931523 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Mon, 30 Oct 2023 15:42:53 +0700 Subject: [PATCH 01/16] draft moving to sqlite3 --- src/taipy/core/_repository/_sql_repository.py | 138 ++++++++++++++---- .../core/_version/_version_sql_repository.py | 6 +- src/taipy/core/cycle/_cycle_model.py | 18 ++- src/taipy/core/data/_data_model.py | 22 +++ src/taipy/core/job/_job_model.py | 16 ++ src/taipy/core/task/_task_model.py | 17 +++ tests/core/repository/mocks.py | 16 +- 7 files changed, 192 insertions(+), 41 deletions(-) diff --git a/src/taipy/core/_repository/_sql_repository.py b/src/taipy/core/_repository/_sql_repository.py index 697a31d45..2cd4302c0 100644 --- a/src/taipy/core/_repository/_sql_repository.py +++ b/src/taipy/core/_repository/_sql_repository.py @@ -11,14 +11,59 @@ import json import pathlib +import sqlite3 from typing import Any, Dict, Iterable, List, Optional, Type, Union +from sqlalchemy.dialects import sqlite from sqlalchemy.exc import NoResultFound +from sqlalchemy.schema import CreateTable +from taipy.config.config import Config + +from .._repository._abstract_repository import _AbstractRepository +from .._repository.db._sql_session import _SQLSession from ..common.typing import Converter, Entity, ModelType -from ..exceptions import ModelNotFound -from ._abstract_repository import _AbstractRepository -from .db._sql_session import _SQLSession +from ..exceptions import MissingRequiredProperty, ModelNotFound + +connection = None + + +def dict_factory(cursor, row): + d = {} + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d + + +def init_db(): + properties = Config.core.repository_properties + try: + db_location = properties["db_location"] + except KeyError: + raise MissingRequiredProperty("Missing property db_location") + + # More sql databases can be easily added in the future + sqlite3.threadsafety = 3 + + global connection + connection = connection if connection else sqlite3.connect(db_location, check_same_thread=False) + connection.row_factory = dict_factory + + from .._version._version_model import _VersionModel + from ..cycle._cycle_model import _CycleModel + from ..data._data_model import _DataNodeModel + from ..job._job_model import _JobModel + from ..scenario._scenario_model import _ScenarioModel + from ..task._task_model import _TaskModel + + connection.execute(str(CreateTable(_CycleModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_DataNodeModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_JobModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_ScenarioModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_TaskModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_VersionModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + + return connection class _SQLRepository(_AbstractRepository[ModelType, Entity]): @@ -35,8 +80,7 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter], sess converter: A class that handles conversion to and from a database backend db: An SQLAlchemy session object """ - SessionLocal = _SQLSession.init_db() - self.db = session or SessionLocal() + self.db = init_db() self.model_type = model_type self.converter = converter @@ -45,39 +89,57 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter], sess ############################### def _save(self, entity: Entity): obj = self.converter._entity_to_model(entity) - if self.db.query(self.model_type).filter_by(id=obj.id).first(): + if self._exists(entity.id): self.__update_entry(obj) return self.__insert_model(obj) def _exists(self, entity_id: str): - return bool(self.db.query(self.model_type.id).filter_by(id=entity_id).first()) # type: ignore + return bool( + self.db.execute(str(self.model_type.__table__.select().filter_by(id=entity_id)), [entity_id]).fetchone() + ) def _load(self, entity_id: str) -> Entity: - if entry := self.db.query(self.model_type).filter(self.model_type.id == entity_id).first(): # type: ignore + get_query = str(self.model_type.__table__.select().filter_by(id=entity_id).compile(dialect=sqlite.dialect())) + + if entry := self.db.execute(str(get_query), [entity_id]).fetchone(): # type: ignore + entry = self.model_type.from_dict(entry) return self.converter._model_to_entity(entry) raise ModelNotFound(str(self.model_type.__name__), entity_id) + @staticmethod + def serialize_filter_values(value): + if isinstance(value, (dict, list)): + return json.dumps(value).replace('"', "'") + return value + def _load_all(self, filters: Optional[List[Dict]] = None) -> List[Entity]: - query = self.db.query(self.model_type) + query = self.model_type.__table__.select() entities: List[Entity] = [] for f in filters or [{}]: filtered_query = query.filter_by(**f) try: - entities.extend([self.converter._model_to_entity(m) for m in filtered_query.all()]) + entries = self.db.execute( + str(filtered_query.compile(dialect=sqlite.dialect())), + [self.serialize_filter_values(val) for val in list(f.values())], + ).fetchall() + + entities.extend([self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries]) except NoResultFound: continue return entities def _delete(self, entity_id: str): - number_of_deleted_entries = self.db.query(self.model_type).filter_by(id=entity_id).delete() - if not number_of_deleted_entries: - raise ModelNotFound(str(self.model_type.__name__), entity_id) + delete_query = self.model_type.__table__.delete().filter_by(id=entity_id).compile(dialect=sqlite.dialect()) + cursor = self.db.execute(str(delete_query), [entity_id]) self.db.commit() + if cursor.rowcount == 0: + raise ModelNotFound(str(self.model_type.__name__), entity_id) + def _delete_all(self): - self.db.query(self.model_type).delete() + self.db.execute(str(self.model_type.__table__.delete().compile(dialect=sqlite.dialect()))) self.db.commit() def _delete_many(self, ids: Iterable[str]): @@ -85,16 +147,22 @@ def _delete_many(self, ids: Iterable[str]): self._delete(entity_id) def _delete_by(self, attribute: str, value: str): - self.db.query(self.model_type).filter_by(**{attribute: value}).delete() + delete_by_query = ( + self.model_type.__table__.delete().filter_by(**{attribute: value}).compile(dialect=sqlite.dialect()) + ) + self.db.execute(str(delete_by_query), [value]) self.db.commit() def _search(self, attribute: str, value: Any, filters: Optional[List[Dict]] = None) -> List[Entity]: - query = self.db.query(self.model_type).filter_by(**{attribute: value}) + query = self.model_type.__table__.select().filter_by(**{attribute: value}) entities: List[Entity] = [] for f in filters or [{}]: - filtered_query = query.filter_by(**f) - entities.extend([self.converter._model_to_entity(m) for m in filtered_query.all()]) + entries = self.db.execute( + str(query.filter_by(**f).compile(dialect=sqlite.dialect())), + [value] + [self.serialize_filter_values(val) for val in list(f.values())], + ).fetchall() + entities.extend([self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries]) return entities @@ -110,21 +178,24 @@ def _export(self, entity_id: str, folder_path: Union[str, pathlib.Path]): export_path = export_dir / f"{entity_id}.json" - entry = self.db.query(self.model_type).filter_by(id=entity_id).first() - if entry is None: - raise ModelNotFound(self.model_type, entity_id) # type: ignore + get_query = str(self.model_type.__table__.select().filter_by(id=entity_id).compile(dialect=sqlite.dialect())) - with open(export_path, "w", encoding="utf-8") as export_file: - export_file.write(json.dumps(entry.to_dict())) + if entry := self.db.execute(str(get_query), [entity_id]).fetchone(): # type: ignore + with open(export_path, "w", encoding="utf-8") as export_file: + export_file.write(json.dumps(entry)) + else: + raise ModelNotFound(self.model_type, entity_id) # type: ignore ########################################### # ## Specific or optimized methods ## # ########################################### def _get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]: - return self.db.query(self.model_type).offset(skip).limit(limit).all() + query = str(self.model_type.__table__.select().offset(skip).limit(limit).compile(dialect=sqlite.dialect())) + return self.db.execute(query).fetchall() def _get_by_config(self, config_id: Any) -> Optional[ModelType]: - return self.db.query(self.model_type).filter(self.model_type.config_id == config_id).first() # type: ignore + query = str(self.model_type.__table__.select().filter_by(config_id=config_id).compile(dialect=sqlite.dialect())) + return self.db.execute(query, [config_id]).fetchall() def _get_by_config_and_owner_id( self, config_id: str, owner_id: Optional[str], filters: Optional[List[Dict]] = None @@ -158,22 +229,27 @@ def __get_entities_by_config_and_owner( if not filters: filters = [] versions = [item.get("version") for item in filters if item.get("version")] + + query = self.model_type.__table__.select().filter_by(config_id=config_id) + if owner_id: - query = self.db.query(self.model_type).filter_by(config_id=config_id).filter_by(owner_id=owner_id) + query = query.filter_by(owner_id=owner_id) else: - query = self.db.query(self.model_type).filter_by(config_id=config_id).filter_by(owner_id=None) + query = query.filter_by(owner_id=None) if versions: query = query.filter(self.model_type.version.in_(versions)) # type: ignore - return query.first() + query = str(query.compile(dialect=sqlite.dialect())) + return self.db.execute(query).fetchone() ############################# # ## Private methods ## # ############################# def __insert_model(self, model: ModelType): - self.db.add(model) + query = str(self.model_type.__table__.insert().compile(dialect=sqlite.dialect())) + self.db.execute(query, model.to_list(model)) self.db.commit() - self.db.refresh(model) def __update_entry(self, model): - self.db.merge(model) + query = str(self.model_type.__table__.update().filter_by(id=model.id).compile(dialect=sqlite.dialect())) + self.db.execute(query, model.to_list(model) + [model.id]) self.db.commit() diff --git a/src/taipy/core/_version/_version_sql_repository.py b/src/taipy/core/_version/_version_sql_repository.py index bff04b428..ced16d4a6 100644 --- a/src/taipy/core/_version/_version_sql_repository.py +++ b/src/taipy/core/_version/_version_sql_repository.py @@ -9,6 +9,8 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +from sqlalchemy.dialects import sqlite + from .._repository._sql_repository import _SQLRepository from ..exceptions.exceptions import ModelNotFound, VersionIsNotProductionVersion from ._version_converter import _VersionConverter @@ -30,7 +32,9 @@ def _set_latest_version(self, version_number): self.db.commit() def _get_latest_version(self): - if latest := self.db.query(self.model_type).filter_by(is_latest=True).first(): + if latest := self.db.execute( + str(self.model_type.__table__.select().filter_by(is_latest=True).compile(dialect=sqlite.dialect())) + ).fetchone(): return latest.id return "" diff --git a/src/taipy/core/cycle/_cycle_model.py b/src/taipy/core/cycle/_cycle_model.py index 9cddfa7b7..af6cbe9fe 100644 --- a/src/taipy/core/cycle/_cycle_model.py +++ b/src/taipy/core/cycle/_cycle_model.py @@ -9,6 +9,7 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +import json from dataclasses import dataclass from typing import Any, Dict @@ -45,12 +46,27 @@ class _CycleModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): + if properties := data["properties"]: + if isinstance(properties, str): + properties = json.loads(properties.replace("'", '"')) return _CycleModel( id=data["id"], name=data["name"], frequency=Frequency._from_repr(data["frequency"]), - properties=data["properties"], + properties=properties, creation_date=data["creation_date"], start_date=data["start_date"], end_date=data["end_date"], ) + + @staticmethod + def to_list(model): + return [ + model.id, + model.name, + repr(model.frequency), + json.dumps(model.properties), + model.creation_date, + model.start_date, + model.end_date, + ] diff --git a/src/taipy/core/data/_data_model.py b/src/taipy/core/data/_data_model.py index a271f5127..325e3f70c 100644 --- a/src/taipy/core/data/_data_model.py +++ b/src/taipy/core/data/_data_model.py @@ -9,6 +9,7 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +import json from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -81,3 +82,24 @@ def from_dict(data: Dict[str, Any]): editor_expiration_date=data.get("editor_expiration_date"), data_node_properties=dn_properties, ) + + @staticmethod + def to_list(model): + return [ + model.id, + model.config_id, + repr(model.scope), + model.storage_type, + model.name, + model.owner_id, + json.dumps(model.parent_ids), + model.last_edit_date, + json.dumps(model.edits), + model.version, + model.validity_days, + model.validity_seconds, + model.edit_in_progress, + model.editor_id, + model.editor_expiration_date, + json.dumps(model.data_node_properties), + ] diff --git a/src/taipy/core/job/_job_model.py b/src/taipy/core/job/_job_model.py index a865d1dbe..0af492439 100644 --- a/src/taipy/core/job/_job_model.py +++ b/src/taipy/core/job/_job_model.py @@ -8,6 +8,7 @@ # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +import json from dataclasses import dataclass from typing import Any, Dict, List @@ -61,3 +62,18 @@ def from_dict(data: Dict[str, Any]): stacktrace=data["stacktrace"], version=data["version"], ) + + @staticmethod + def to_list(model): + return [ + model.id, + model.task_id, + repr(model.status), + model.force, + model.submit_id, + model.submit_entity_id, + model.creation_date, + json.dumps(model.subscribers), + json.dumps(model.stacktrace), + model.version, + ] diff --git a/src/taipy/core/task/_task_model.py b/src/taipy/core/task/_task_model.py index e0103684e..97ae35611 100644 --- a/src/taipy/core/task/_task_model.py +++ b/src/taipy/core/task/_task_model.py @@ -9,6 +9,7 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +import json from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -63,3 +64,19 @@ def from_dict(data: Dict[str, Any]): skippable=data["skippable"], properties=data["properties"] if "properties" in data.keys() else {}, ) + + @staticmethod + def to_list(model): + return [ + model.id, + model.owner_id, + json.dumps(model.parent_ids), + model.config_id, + json.dumps(model.input_ids), + model.function_name, + model.function_module, + json.dumps(model.output_ids), + model.version, + model.skippable, + json.dumps(model.properties), + ] diff --git a/tests/core/repository/mocks.py b/tests/core/repository/mocks.py index 491ba4c3f..6cdce5a9f 100644 --- a/tests/core/repository/mocks.py +++ b/tests/core/repository/mocks.py @@ -15,7 +15,9 @@ from typing import Any, Dict, Optional from sqlalchemy import Column, String, Table, create_engine -from sqlalchemy.orm import declarative_base, registry, sessionmaker +from sqlalchemy.dialects import sqlite +from sqlalchemy.orm import declarative_base, registry +from sqlalchemy.schema import CreateTable from src.taipy.core._repository._abstract_converter import _AbstractConverter from src.taipy.core._repository._filesystem_repository import _FileSystemRepository @@ -70,6 +72,10 @@ def _to_entity(self): def _from_entity(cls, entity: MockObj): return MockModel(id=entity.id, name=entity.name, version=entity._version) + @staticmethod + def to_list(model): + return [model.id, model.name, model.version] + class MockConverter(_AbstractConverter): @classmethod @@ -90,13 +96,7 @@ def _storage_folder(self) -> pathlib.Path: return pathlib.Path(Config.core.storage_folder) # type: ignore -def create_database(engine): - MockModel.__table__.create(engine, checkfirst=True) - - class MockSQLRepository(_SQLRepository): def __init__(self, **kwargs): - engine = create_engine("sqlite:///:memory:") - create_database(engine) - kwargs.update({"session": sessionmaker(autocommit=False, autoflush=False, bind=engine)()}) super().__init__(**kwargs) + self.db.execute(str(CreateTable(MockModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) From bdc46c527f8c23f706d8a1fe629f408111c427c8 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Wed, 1 Nov 2023 16:28:42 +0700 Subject: [PATCH 02/16] added to list to models --- src/taipy/core/_repository/_sql_repository.py | 30 ++++++++--- src/taipy/core/_version/_version_model.py | 12 +++++ .../core/_version/_version_sql_repository.py | 20 +++++-- src/taipy/core/data/_data_model.py | 9 +++- src/taipy/core/job/_job_model.py | 12 ++++- src/taipy/core/scenario/_scenario_model.py | 54 ++++++++++++++++--- src/taipy/core/task/_task_model.py | 24 +++++++-- tests/conftest.py | 41 ++++++++------ 8 files changed, 161 insertions(+), 41 deletions(-) diff --git a/src/taipy/core/_repository/_sql_repository.py b/src/taipy/core/_repository/_sql_repository.py index 2cd4302c0..f3eec39da 100644 --- a/src/taipy/core/_repository/_sql_repository.py +++ b/src/taipy/core/_repository/_sql_repository.py @@ -28,6 +28,15 @@ connection = None +from taipy.config.config import Config + +from .._repository._abstract_repository import _AbstractRepository +from .._repository.db._sql_session import _SQLSession +from ..exceptions import MissingRequiredProperty, ModelNotFound + +connection = None + + def dict_factory(cursor, row): d = {} for idx, col in enumerate(cursor.description): @@ -42,7 +51,6 @@ def init_db(): except KeyError: raise MissingRequiredProperty("Missing property db_location") - # More sql databases can be easily added in the future sqlite3.threadsafety = 3 global connection @@ -138,6 +146,9 @@ def _delete(self, entity_id: str): if cursor.rowcount == 0: raise ModelNotFound(str(self.model_type.__name__), entity_id) + if cursor.rowcount == 0: + raise ModelNotFound(str(self.model_type.__name__), entity_id) + def _delete_all(self): self.db.execute(str(self.model_type.__table__.delete().compile(dialect=sqlite.dialect()))) self.db.commit() @@ -224,22 +235,27 @@ def _get_by_configs_and_owner_ids(self, configs_and_owner_ids, filters: Optional return res def __get_entities_by_config_and_owner( - self, config_id: str, owner_id: Optional[str] = "", filters: Optional[List[Dict]] = None + self, config_id: str, owner_id: Optional[str] = None, filters: Optional[List[Dict]] = None ) -> ModelType: if not filters: filters = [] versions = [item.get("version") for item in filters if item.get("version")] query = self.model_type.__table__.select().filter_by(config_id=config_id) + parameters = [config_id] if owner_id: - query = query.filter_by(owner_id=owner_id) - else: - query = query.filter_by(owner_id=None) + parameters.append(owner_id) + query = query.filter_by(owner_id=owner_id) + if versions: - query = query.filter(self.model_type.version.in_(versions)) # type: ignore + query = str(query.filter(self.model_type.version.in_(versions)).compile(dialect=sqlite.dialect())) # type: ignore + return self.db.execute(query) + query = str(query.compile(dialect=sqlite.dialect())) - return self.db.execute(query).fetchone() + if entry := self.db.execute(query, parameters).fetchone(): + return self.model_type.from_dict(entry) + return None ############################# # ## Private methods ## # diff --git a/src/taipy/core/_version/_version_model.py b/src/taipy/core/_version/_version_model.py index 8b732e3fb..c308720d1 100644 --- a/src/taipy/core/_version/_version_model.py +++ b/src/taipy/core/_version/_version_model.py @@ -9,6 +9,7 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +import json from dataclasses import dataclass from typing import Any, Dict @@ -42,3 +43,14 @@ def from_dict(data: Dict[str, Any]): config=data["config"], creation_date=data["creation_date"], ) + + @staticmethod + def to_list(model): + return [ + model.id, + model.config, + model.creation_date, + model.is_production, + model.is_development, + model.is_latest, + ] diff --git a/src/taipy/core/_version/_version_sql_repository.py b/src/taipy/core/_version/_version_sql_repository.py index ced16d4a6..1dcd28600 100644 --- a/src/taipy/core/_version/_version_sql_repository.py +++ b/src/taipy/core/_version/_version_sql_repository.py @@ -23,7 +23,7 @@ def __init__(self): super().__init__(model_type=_VersionModel, converter=_VersionConverter) def _set_latest_version(self, version_number): - if old_latest := self.db.query(self.model_type).filter_by(is_latest=True).first(): + if old_latest := self.db.execute(str(self.model_type.__table__.select().filter_by(is_latest=True))).fetchone(): old_latest.is_latest = False version = self.__get_by_id(version_number) @@ -39,7 +39,9 @@ def _get_latest_version(self): return "" def _set_development_version(self, version_number): - if old_development := self.db.query(self.model_type).filter_by(is_development=True).first(): + if old_development := self.db.execute( + str(self.model_type.__table__.select().filter_by(is_development=True)) + ).fetchone(): old_development.is_development = False version = self.__get_by_id(version_number) @@ -50,7 +52,9 @@ def _set_development_version(self, version_number): self.db.commit() def _get_development_version(self): - if development := self.db.query(self.model_type).filter_by(is_development=True).first(): + if development := self.db.execute( + str(self.model_type.__table__.select().filter_by(is_development=True)) + ).fetchone(): return development.id raise ModelNotFound(self.model_type, "") @@ -63,7 +67,11 @@ def _set_production_version(self, version_number): self.db.commit() def _get_production_versions(self): - if productions := self.db.query(self.model_type).filter_by(is_production=True).all(): + if productions := self.db.execute( + str(self.model_type.__table__.select().filter_by(is_production=True).compile(dialect=sqlite.dialect())), + ).fetchall(): + + # if productions := self.db.query(self.model_type).filter_by(is_production=True).all(): return [p.id for p in productions] return [] @@ -77,4 +85,6 @@ def _delete_production_version(self, version_number): self.db.commit() def __get_by_id(self, version_id): - return self.db.query(self.model_type).filter_by(id=version_id).first() + query = str(self.model_type.__table__.select().filter_by(id=version_id).compile(dialect=sqlite.dialect())) + entry = self.db.execute(query, [version_id]).fetchone() + return self.model_type.from_dict(entry) if entry else None diff --git a/src/taipy/core/data/_data_model.py b/src/taipy/core/data/_data_model.py index 325e3f70c..7f1475df9 100644 --- a/src/taipy/core/data/_data_model.py +++ b/src/taipy/core/data/_data_model.py @@ -65,6 +65,13 @@ class _DataNodeModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): dn_properties = data["data_node_properties"] + if isinstance(dn_properties, str): + dn_properties = json.loads(dn_properties.replace("'", '"')) + + edits = data["edits"] + if isinstance(edits, str): + edits = json.loads(edits.replace("'", '"')) + return _DataNodeModel( id=data["id"], config_id=data["config_id"], @@ -73,7 +80,7 @@ def from_dict(data: Dict[str, Any]): owner_id=data.get("owner_id"), parent_ids=data.get("parent_ids", []), last_edit_date=data.get("last_edit_date"), - edits=data["edits"], + edits=edits, version=data["version"], validity_days=data["validity_days"], validity_seconds=data["validity_seconds"], diff --git a/src/taipy/core/job/_job_model.py b/src/taipy/core/job/_job_model.py index 0af492439..7519b686e 100644 --- a/src/taipy/core/job/_job_model.py +++ b/src/taipy/core/job/_job_model.py @@ -50,6 +50,14 @@ class _JobModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): + subscribers = data["subscribers"] + if isinstance(subscribers, str): + subscribers = json.loads(subscribers.replace("'", '"')) + + stacktrace = data["stacktrace"] + if isinstance(stacktrace, str): + stacktrace = json.loads(stacktrace.replace("'", '"')) + return _JobModel( id=data["id"], task_id=data["task_id"], @@ -58,8 +66,8 @@ def from_dict(data: Dict[str, Any]): submit_id=data["submit_id"], submit_entity_id=data["submit_entity_id"], creation_date=data["creation_date"], - subscribers=data["subscribers"], - stacktrace=data["stacktrace"], + subscribers=subscribers, + stacktrace=stacktrace, version=data["version"], ) diff --git a/src/taipy/core/scenario/_scenario_model.py b/src/taipy/core/scenario/_scenario_model.py index 0cd040d98..0b9644619 100644 --- a/src/taipy/core/scenario/_scenario_model.py +++ b/src/taipy/core/scenario/_scenario_model.py @@ -9,6 +9,7 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +import json from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -56,17 +57,58 @@ class _ScenarioModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): + tasks = data.get("tasks", None) + if isinstance(tasks, str): + tasks = json.loads(tasks.replace("'", '"')) + + additional_data_nodes = data.get("additional_data_nodes", None) + if isinstance(additional_data_nodes, str): + additional_data_nodes = json.loads(additional_data_nodes.replace("'", '"')) + + properties = data["properties"] + if isinstance(properties, str): + properties = json.loads(properties.replace("'", '"')) + + subscribers = data["subscribers"] + if isinstance(subscribers, str): + subscribers = json.loads(subscribers.replace("'", '"')) + + tags = data["tags"] + if isinstance(tags, str): + tags = json.loads(tags.replace("'", '"')) + + sequences = data.get("sequences", None) + if isinstance(sequences, str): + sequences = json.loads(sequences.replace("'", '"')) + return _ScenarioModel( id=data["id"], config_id=data["config_id"], - tasks=data.get("tasks", None), - additional_data_nodes=data.get("additional_data_nodes", None), - properties=data["properties"], + tasks=tasks, + additional_data_nodes=additional_data_nodes, + properties=properties, creation_date=data["creation_date"], primary_scenario=data["primary_scenario"], - subscribers=data["subscribers"], - tags=data["tags"], + subscribers=subscribers, + tags=tags, version=data["version"], - sequences=data.get("sequences", None), + sequences=sequences, cycle=CycleId(data["cycle"]) if "cycle" in data else None, ) + + @staticmethod + def to_list(model): + return [ + model.id, + model.config_id, + json.dumps(model.tasks), + json.dumps(model.additional_data_nodes), + json.dumps(model.properties), + model.creation_date, + model.primary_scenario, + json.dumps(model.subscribers), + json.dumps(model.tags), + model.version, + json.dumps(model.sequences), + model.cycle, + ] diff --git a/src/taipy/core/task/_task_model.py b/src/taipy/core/task/_task_model.py index 97ae35611..cc97d5614 100644 --- a/src/taipy/core/task/_task_model.py +++ b/src/taipy/core/task/_task_model.py @@ -51,18 +51,34 @@ class _TaskModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): + parent_ids = data.get("parent_ids", []) + if isinstance(parent_ids, str): + parent_ids = json.loads(parent_ids.replace("'", '"')) + + input_ids = data["input_ids"] + if isinstance(input_ids, str): + input_ids = json.loads(input_ids.replace("'", '"')) + + output_ids = data["output_ids"] + if isinstance(output_ids, str): + output_ids = json.loads(output_ids.replace("'", '"')) + + properties = data["properties"] if "properties" in data.keys() else {} + if isinstance(properties, str): + properties = json.loads(properties.replace("'", '"')) + return _TaskModel( id=data["id"], owner_id=data.get("owner_id"), - parent_ids=data.get("parent_ids", []), + parent_ids=parent_ids, config_id=data["config_id"], - input_ids=data["input_ids"], + input_ids=input_ids, function_name=data["function_name"], function_module=data["function_module"], - output_ids=data["output_ids"], + output_ids=output_ids, version=data["version"], skippable=data["skippable"], - properties=data["properties"] if "properties" in data.keys() else {}, + properties=properties, ) @staticmethod diff --git a/tests/conftest.py b/tests/conftest.py index 165ba357a..16038a291 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,9 +18,11 @@ import pandas as pd import pytest from sqlalchemy import create_engine, text +from sqlalchemy.dialects import sqlite +from sqlalchemy.schema import CreateTable, DropTable from src.taipy.core._orchestrator._orchestrator_factory import _OrchestratorFactory -from src.taipy.core._repository.db._sql_session import _build_engine +from src.taipy.core._repository._sql_repository import connection from src.taipy.core._version._version import _Version from src.taipy.core._version._version_manager_factory import _VersionManagerFactory from src.taipy.core._version._version_model import _VersionModel @@ -441,20 +443,27 @@ def init_sql_repo(tmp_sqlite): Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite}) # Clean SQLite database - engine = _build_engine() - - _CycleModel.__table__.drop(bind=engine, checkfirst=True) - _DataNodeModel.__table__.drop(bind=engine, checkfirst=True) - _JobModel.__table__.drop(bind=engine, checkfirst=True) - _ScenarioModel.__table__.drop(bind=engine, checkfirst=True) - _TaskModel.__table__.drop(bind=engine, checkfirst=True) - _VersionModel.__table__.drop(bind=engine, checkfirst=True) - - _CycleModel.__table__.create(bind=engine, checkfirst=True) - _DataNodeModel.__table__.create(bind=engine, checkfirst=True) - _JobModel.__table__.create(bind=engine, checkfirst=True) - _ScenarioModel.__table__.create(bind=engine, checkfirst=True) - _TaskModel.__table__.create(bind=engine, checkfirst=True) - _VersionModel.__table__.create(bind=engine, checkfirst=True) + if connection: + connection.execute(str(DropTable(_CycleModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_DataNodeModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_JobModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_ScenarioModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_TaskModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_VersionModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + + connection.execute( + str(CreateTable(_CycleModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) + connection.execute( + str(CreateTable(_DataNodeModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) + connection.execute(str(CreateTable(_JobModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute( + str(CreateTable(_ScenarioModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) + connection.execute(str(CreateTable(_TaskModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute( + str(CreateTable(_VersionModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) return tmp_sqlite From ef8346a842d7373181589e5149dee18cab6aa7b5 Mon Sep 17 00:00:00 2001 From: trgiangdo Date: Fri, 3 Nov 2023 15:50:39 +0700 Subject: [PATCH 03/16] fix: reset Core._is_running before each test --- src/taipy/core/_core.py | 8 ++++---- tests/conftest.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/taipy/core/_core.py b/src/taipy/core/_core.py index d21b518dd..1c742851f 100644 --- a/src/taipy/core/_core.py +++ b/src/taipy/core/_core.py @@ -30,7 +30,7 @@ class Core: Core service """ - __is_running = False + _is_running = False __lock_is_running = Lock() __logger = _TaipyLogger._get_logger() @@ -51,11 +51,11 @@ def run(self, force_restart=False): This function checks the configuration, manages application's version, and starts a dispatcher and lock the Config. """ - if self.__class__.__is_running: + if self.__class__._is_running: raise CoreServiceIsAlreadyRunning with self.__class__.__lock_is_running: - self.__class__.__is_running = True + self.__class__._is_running = True self.__update_core_section() self.__manage_version() @@ -79,7 +79,7 @@ def stop(self): self.__logger.info("Core service has been stopped.") with self.__class__.__lock_is_running: - self.__class__.__is_running = False + self.__class__._is_running = False @staticmethod def __update_core_section(): diff --git a/tests/conftest.py b/tests/conftest.py index 16038a291..55a089f46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ from sqlalchemy.dialects import sqlite from sqlalchemy.schema import CreateTable, DropTable +from src.taipy.core._core import Core from src.taipy.core._orchestrator._orchestrator_factory import _OrchestratorFactory from src.taipy.core._repository._sql_repository import connection from src.taipy.core._version._version import _Version @@ -409,6 +410,7 @@ def init_config(): _Checker.add_checker(_ScenarioConfigChecker) Config.configure_core(read_entity_retry=0) + Core._is_running = False def init_managers(): From b3ab8135c62054a8150a8bbf4f6b658c89592f35 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Fri, 3 Nov 2023 15:55:13 +0700 Subject: [PATCH 04/16] fixed failed version query --- src/taipy/core/_repository/_sql_repository.py | 12 ++++++---- src/taipy/core/_version/_version_model.py | 6 ++++- .../core/_version/_version_sql_repository.py | 24 +++++++++---------- tests/conftest.py | 1 + 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/taipy/core/_repository/_sql_repository.py b/src/taipy/core/_repository/_sql_repository.py index f3eec39da..cce620973 100644 --- a/src/taipy/core/_repository/_sql_repository.py +++ b/src/taipy/core/_repository/_sql_repository.py @@ -98,7 +98,7 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter], sess def _save(self, entity: Entity): obj = self.converter._entity_to_model(entity) if self._exists(entity.id): - self.__update_entry(obj) + self._update_entry(obj) return self.__insert_model(obj) @@ -247,12 +247,14 @@ def __get_entities_by_config_and_owner( if owner_id: parameters.append(owner_id) query = query.filter_by(owner_id=owner_id) + query = str(query.compile(dialect=sqlite.dialect())) if versions: - query = str(query.filter(self.model_type.version.in_(versions)).compile(dialect=sqlite.dialect())) # type: ignore - return self.db.execute(query) - query = str(query.compile(dialect=sqlite.dialect())) + query = query + f" AND {self.model_type.__table__.name}.version IN ({','.join(['?']*len(versions))})" + # query = str(query.filter(self.model_type.version.in_(versions)).compile(dialect=sqlite.dialect())) # type: ignore + parameters.extend(versions) + if entry := self.db.execute(query, parameters).fetchone(): return self.model_type.from_dict(entry) return None @@ -265,7 +267,7 @@ def __insert_model(self, model: ModelType): self.db.execute(query, model.to_list(model)) self.db.commit() - def __update_entry(self, model): + def _update_entry(self, model): query = str(self.model_type.__table__.update().filter_by(id=model.id).compile(dialect=sqlite.dialect())) self.db.execute(query, model.to_list(model) + [model.id]) self.db.commit() diff --git a/src/taipy/core/_version/_version_model.py b/src/taipy/core/_version/_version_model.py index c308720d1..1d9cc37b6 100644 --- a/src/taipy/core/_version/_version_model.py +++ b/src/taipy/core/_version/_version_model.py @@ -38,11 +38,15 @@ class _VersionModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): - return _VersionModel( + model = _VersionModel( id=data["id"], config=data["config"], creation_date=data["creation_date"], ) + model.is_production = data.get("is_production") + model.is_development = data.get("is_development") + model.is_latest = data.get("is_latest") + return model @staticmethod def to_list(model): diff --git a/src/taipy/core/_version/_version_sql_repository.py b/src/taipy/core/_version/_version_sql_repository.py index 1dcd28600..ee173544b 100644 --- a/src/taipy/core/_version/_version_sql_repository.py +++ b/src/taipy/core/_version/_version_sql_repository.py @@ -24,55 +24,54 @@ def __init__(self): def _set_latest_version(self, version_number): if old_latest := self.db.execute(str(self.model_type.__table__.select().filter_by(is_latest=True))).fetchone(): + old_latest = self.model_type.from_dict(old_latest) old_latest.is_latest = False + self._update_entry(old_latest) version = self.__get_by_id(version_number) version.is_latest = True - - self.db.commit() + self._update_entry(version) def _get_latest_version(self): if latest := self.db.execute( str(self.model_type.__table__.select().filter_by(is_latest=True).compile(dialect=sqlite.dialect())) ).fetchone(): - return latest.id + return latest["id"] return "" def _set_development_version(self, version_number): if old_development := self.db.execute( str(self.model_type.__table__.select().filter_by(is_development=True)) ).fetchone(): + old_development = self.model_type.from_dict(old_development) old_development.is_development = False + self._update_entry(old_development) version = self.__get_by_id(version_number) version.is_development = True + self._update_entry(version) self._set_latest_version(version_number) - self.db.commit() - def _get_development_version(self): if development := self.db.execute( str(self.model_type.__table__.select().filter_by(is_development=True)) ).fetchone(): - return development.id + return development["id"] raise ModelNotFound(self.model_type, "") def _set_production_version(self, version_number): version = self.__get_by_id(version_number) version.is_production = True + self._update_entry(version) self._set_latest_version(version_number) - self.db.commit() - def _get_production_versions(self): if productions := self.db.execute( str(self.model_type.__table__.select().filter_by(is_production=True).compile(dialect=sqlite.dialect())), ).fetchall(): - - # if productions := self.db.query(self.model_type).filter_by(is_production=True).all(): - return [p.id for p in productions] + return [p["id"] for p in productions] return [] def _delete_production_version(self, version_number): @@ -81,8 +80,7 @@ def _delete_production_version(self, version_number): if not version or not version.is_production: raise VersionIsNotProductionVersion(f"Version '{version_number}' is not a production version.") version.is_production = False - - self.db.commit() + self._update_entry(version) def __get_by_id(self, version_id): query = str(self.model_type.__table__.select().filter_by(id=version_id).compile(dialect=sqlite.dialect())) diff --git a/tests/conftest.py b/tests/conftest.py index 55a089f46..027790b68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -444,6 +444,7 @@ def sql_engine(): def init_sql_repo(tmp_sqlite): Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite}) + init_managers() # Clean SQLite database if connection: connection.execute(str(DropTable(_CycleModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) From 7d9aa8b6ecd937ce25978ec1fc99e4ac37c6ae19 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Mon, 6 Nov 2023 20:01:20 +0700 Subject: [PATCH 05/16] fixed migration version for sql --- src/taipy/core/_version/_cli/_version_cli.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/taipy/core/_version/_cli/_version_cli.py b/src/taipy/core/_version/_cli/_version_cli.py index ba54420ae..dac3ef436 100644 --- a/src/taipy/core/_version/_cli/_version_cli.py +++ b/src/taipy/core/_version/_cli/_version_cli.py @@ -17,7 +17,7 @@ from taipy.logger._taipy_logger import _TaipyLogger from ...data._data_manager_factory import _DataManagerFactory -from ...exceptions.exceptions import VersionIsNotProductionVersion +from ...exceptions.exceptions import ModelNotFound, VersionIsNotProductionVersion from ...job._job_manager_factory import _JobManagerFactory from ...scenario._scenario_manager_factory import _ScenarioManagerFactory from ...sequence._sequence_manager_factory import _SequenceManagerFactory @@ -199,8 +199,10 @@ def __rename_version(cls, old_version: str, new_version: str): _version_manager._delete_production_version(old_version) except VersionIsNotProductionVersion: pass - version_entity.id = new_version - _version_manager._set(version_entity) + + if not _version_manager._get(new_version): + version_entity.id = new_version + _version_manager._set(version_entity) @classmethod def __compare_version_config(cls, version_1: str, version_2: str): From 5f8e358fe4b6476cb2d9295af0450a20a268f651 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Tue, 7 Nov 2023 13:39:32 +0700 Subject: [PATCH 06/16] added changes to to list dn --- src/taipy/core/data/_data_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/taipy/core/data/_data_model.py b/src/taipy/core/data/_data_model.py index 7f1475df9..c04ff1cd3 100644 --- a/src/taipy/core/data/_data_model.py +++ b/src/taipy/core/data/_data_model.py @@ -97,7 +97,6 @@ def to_list(model): model.config_id, repr(model.scope), model.storage_type, - model.name, model.owner_id, json.dumps(model.parent_ids), model.last_edit_date, From 5cdcc8642f90ec34211c92613a07520d16a77a90 Mon Sep 17 00:00:00 2001 From: trgiangdo Date: Sat, 11 Nov 2023 02:16:41 +0700 Subject: [PATCH 07/16] fix: Model.to_list() should use the proper _Encoder and _Decoder --- .../core/_repository/_base_taipy_model.py | 14 ++++++ src/taipy/core/_repository/_encoder.py | 4 +- src/taipy/core/_version/_version_model.py | 7 ++- src/taipy/core/cycle/_cycle_model.py | 8 +-- src/taipy/core/data/_data_model.py | 19 ++----- src/taipy/core/job/_job_model.py | 18 ++----- src/taipy/core/scenario/_scenario_model.py | 49 +++++-------------- src/taipy/core/task/_task_model.py | 33 +++---------- 8 files changed, 50 insertions(+), 102 deletions(-) diff --git a/src/taipy/core/_repository/_base_taipy_model.py b/src/taipy/core/_repository/_base_taipy_model.py index 3c4b8b8b5..07d36c919 100644 --- a/src/taipy/core/_repository/_base_taipy_model.py +++ b/src/taipy/core/_repository/_base_taipy_model.py @@ -11,8 +11,12 @@ import dataclasses import enum +import json from typing import Any, Dict +from ._decoder import _Decoder +from ._encoder import _Encoder + class _BaseModel: def __iter__(self): @@ -26,3 +30,13 @@ def to_dict(self) -> Dict[str, Any]: if isinstance(v, enum.Enum): model_dict[k] = repr(v) return model_dict + + @staticmethod + def _serialize_attribute(value): + return json.dumps(value, ensure_ascii=False, cls=_Encoder) + + @staticmethod + def _deserialize_attribute(value): + if isinstance(value, str): + return json.loads(value.replace("'", '"'), cls=_Decoder) + return value diff --git a/src/taipy/core/_repository/_encoder.py b/src/taipy/core/_repository/_encoder.py index 38de6372a..ab48870bf 100644 --- a/src/taipy/core/_repository/_encoder.py +++ b/src/taipy/core/_repository/_encoder.py @@ -14,8 +14,6 @@ from enum import Enum from typing import Any -from ..common.typing import Json - class _Encoder(json.JSONEncoder): def _timedelta_to_str(self, obj: timedelta) -> str: @@ -27,7 +25,7 @@ def _timedelta_to_str(self, obj: timedelta) -> str: f"{int(total_seconds % 60)}s" ) - def default(self, o: Any) -> Json: + def default(self, o: Any): if isinstance(o, Enum): result = o.value elif isinstance(o, datetime): diff --git a/src/taipy/core/_version/_version_model.py b/src/taipy/core/_version/_version_model.py index 1d9cc37b6..b4e1a0561 100644 --- a/src/taipy/core/_version/_version_model.py +++ b/src/taipy/core/_version/_version_model.py @@ -9,7 +9,6 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. -import json from dataclasses import dataclass from typing import Any, Dict @@ -43,9 +42,9 @@ def from_dict(data: Dict[str, Any]): config=data["config"], creation_date=data["creation_date"], ) - model.is_production = data.get("is_production") - model.is_development = data.get("is_development") - model.is_latest = data.get("is_latest") + model.is_production = data.get("is_production") # type: ignore + model.is_development = data.get("is_development") # type: ignore + model.is_latest = data.get("is_latest") # type: ignore return model @staticmethod diff --git a/src/taipy/core/cycle/_cycle_model.py b/src/taipy/core/cycle/_cycle_model.py index af6cbe9fe..da8122a89 100644 --- a/src/taipy/core/cycle/_cycle_model.py +++ b/src/taipy/core/cycle/_cycle_model.py @@ -9,7 +9,6 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. -import json from dataclasses import dataclass from typing import Any, Dict @@ -46,14 +45,11 @@ class _CycleModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): - if properties := data["properties"]: - if isinstance(properties, str): - properties = json.loads(properties.replace("'", '"')) return _CycleModel( id=data["id"], name=data["name"], frequency=Frequency._from_repr(data["frequency"]), - properties=properties, + properties=_BaseModel._deserialize_attribute(data["properties"]), creation_date=data["creation_date"], start_date=data["start_date"], end_date=data["end_date"], @@ -65,7 +61,7 @@ def to_list(model): model.id, model.name, repr(model.frequency), - json.dumps(model.properties), + _BaseModel._serialize_attribute(model.properties), model.creation_date, model.start_date, model.end_date, diff --git a/src/taipy/core/data/_data_model.py b/src/taipy/core/data/_data_model.py index c04ff1cd3..d6ccbec02 100644 --- a/src/taipy/core/data/_data_model.py +++ b/src/taipy/core/data/_data_model.py @@ -9,7 +9,6 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. -import json from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -64,14 +63,6 @@ class _DataNodeModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): - dn_properties = data["data_node_properties"] - if isinstance(dn_properties, str): - dn_properties = json.loads(dn_properties.replace("'", '"')) - - edits = data["edits"] - if isinstance(edits, str): - edits = json.loads(edits.replace("'", '"')) - return _DataNodeModel( id=data["id"], config_id=data["config_id"], @@ -80,14 +71,14 @@ def from_dict(data: Dict[str, Any]): owner_id=data.get("owner_id"), parent_ids=data.get("parent_ids", []), last_edit_date=data.get("last_edit_date"), - edits=edits, + edits=_BaseModel._deserialize_attribute(data["edits"]), version=data["version"], validity_days=data["validity_days"], validity_seconds=data["validity_seconds"], edit_in_progress=bool(data.get("edit_in_progress", False)), editor_id=data.get("editor_id", None), editor_expiration_date=data.get("editor_expiration_date"), - data_node_properties=dn_properties, + data_node_properties=_BaseModel._deserialize_attribute(data["data_node_properties"]), ) @staticmethod @@ -98,14 +89,14 @@ def to_list(model): repr(model.scope), model.storage_type, model.owner_id, - json.dumps(model.parent_ids), + _BaseModel._serialize_attribute(model.parent_ids), model.last_edit_date, - json.dumps(model.edits), + _BaseModel._serialize_attribute(model.edits), model.version, model.validity_days, model.validity_seconds, model.edit_in_progress, model.editor_id, model.editor_expiration_date, - json.dumps(model.data_node_properties), + _BaseModel._serialize_attribute(model.data_node_properties), ] diff --git a/src/taipy/core/job/_job_model.py b/src/taipy/core/job/_job_model.py index 7519b686e..e31e38da7 100644 --- a/src/taipy/core/job/_job_model.py +++ b/src/taipy/core/job/_job_model.py @@ -8,7 +8,7 @@ # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. -import json + from dataclasses import dataclass from typing import Any, Dict, List @@ -50,14 +50,6 @@ class _JobModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): - subscribers = data["subscribers"] - if isinstance(subscribers, str): - subscribers = json.loads(subscribers.replace("'", '"')) - - stacktrace = data["stacktrace"] - if isinstance(stacktrace, str): - stacktrace = json.loads(stacktrace.replace("'", '"')) - return _JobModel( id=data["id"], task_id=data["task_id"], @@ -66,8 +58,8 @@ def from_dict(data: Dict[str, Any]): submit_id=data["submit_id"], submit_entity_id=data["submit_entity_id"], creation_date=data["creation_date"], - subscribers=subscribers, - stacktrace=stacktrace, + subscribers=_BaseModel._deserialize_attribute(data["subscribers"]), + stacktrace=data["stacktrace"], version=data["version"], ) @@ -81,7 +73,7 @@ def to_list(model): model.submit_id, model.submit_entity_id, model.creation_date, - json.dumps(model.subscribers), - json.dumps(model.stacktrace), + _BaseModel._serialize_attribute(model.subscribers), + _BaseModel._serialize_attribute(model.stacktrace), model.version, ] diff --git a/src/taipy/core/scenario/_scenario_model.py b/src/taipy/core/scenario/_scenario_model.py index 0b9644619..9eb9993af 100644 --- a/src/taipy/core/scenario/_scenario_model.py +++ b/src/taipy/core/scenario/_scenario_model.py @@ -9,7 +9,6 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. -import json from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -57,42 +56,18 @@ class _ScenarioModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): - tasks = data.get("tasks", None) - if isinstance(tasks, str): - tasks = json.loads(tasks.replace("'", '"')) - - additional_data_nodes = data.get("additional_data_nodes", None) - if isinstance(additional_data_nodes, str): - additional_data_nodes = json.loads(additional_data_nodes.replace("'", '"')) - - properties = data["properties"] - if isinstance(properties, str): - properties = json.loads(properties.replace("'", '"')) - - subscribers = data["subscribers"] - if isinstance(subscribers, str): - subscribers = json.loads(subscribers.replace("'", '"')) - - tags = data["tags"] - if isinstance(tags, str): - tags = json.loads(tags.replace("'", '"')) - - sequences = data.get("sequences", None) - if isinstance(sequences, str): - sequences = json.loads(sequences.replace("'", '"')) - return _ScenarioModel( id=data["id"], config_id=data["config_id"], - tasks=tasks, - additional_data_nodes=additional_data_nodes, - properties=properties, + tasks=_BaseModel._deserialize_attribute(data["tasks"]), + additional_data_nodes=_BaseModel._deserialize_attribute(data["additional_data_nodes"]), + properties=_BaseModel._deserialize_attribute(data["properties"]), creation_date=data["creation_date"], primary_scenario=data["primary_scenario"], - subscribers=subscribers, - tags=tags, + subscribers=_BaseModel._deserialize_attribute(data["subscribers"]), + tags=_BaseModel._deserialize_attribute(data["tags"]), version=data["version"], - sequences=sequences, + sequences=_BaseModel._deserialize_attribute(data["sequences"]), cycle=CycleId(data["cycle"]) if "cycle" in data else None, ) @@ -101,14 +76,14 @@ def to_list(model): return [ model.id, model.config_id, - json.dumps(model.tasks), - json.dumps(model.additional_data_nodes), - json.dumps(model.properties), + _BaseModel._serialize_attribute(model.tasks), + _BaseModel._serialize_attribute(model.additional_data_nodes), + _BaseModel._serialize_attribute(model.properties), model.creation_date, model.primary_scenario, - json.dumps(model.subscribers), - json.dumps(model.tags), + _BaseModel._serialize_attribute(model.subscribers), + _BaseModel._serialize_attribute(model.tags), model.version, - json.dumps(model.sequences), + _BaseModel._serialize_attribute(model.sequences), model.cycle, ] diff --git a/src/taipy/core/task/_task_model.py b/src/taipy/core/task/_task_model.py index cc97d5614..370dc384b 100644 --- a/src/taipy/core/task/_task_model.py +++ b/src/taipy/core/task/_task_model.py @@ -9,7 +9,6 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. -import json from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -51,34 +50,18 @@ class _TaskModel(_BaseModel): @staticmethod def from_dict(data: Dict[str, Any]): - parent_ids = data.get("parent_ids", []) - if isinstance(parent_ids, str): - parent_ids = json.loads(parent_ids.replace("'", '"')) - - input_ids = data["input_ids"] - if isinstance(input_ids, str): - input_ids = json.loads(input_ids.replace("'", '"')) - - output_ids = data["output_ids"] - if isinstance(output_ids, str): - output_ids = json.loads(output_ids.replace("'", '"')) - - properties = data["properties"] if "properties" in data.keys() else {} - if isinstance(properties, str): - properties = json.loads(properties.replace("'", '"')) - return _TaskModel( id=data["id"], owner_id=data.get("owner_id"), - parent_ids=parent_ids, + parent_ids=_BaseModel._deserialize_attribute(data.get("parent_ids", [])), config_id=data["config_id"], - input_ids=input_ids, + input_ids=_BaseModel._deserialize_attribute(data["input_ids"]), function_name=data["function_name"], function_module=data["function_module"], - output_ids=output_ids, + output_ids=_BaseModel._deserialize_attribute(data["output_ids"]), version=data["version"], skippable=data["skippable"], - properties=properties, + properties=_BaseModel._deserialize_attribute(data["properties"] if "properties" in data.keys() else {}), ) @staticmethod @@ -86,13 +69,13 @@ def to_list(model): return [ model.id, model.owner_id, - json.dumps(model.parent_ids), + _BaseModel._serialize_attribute(model.parent_ids), model.config_id, - json.dumps(model.input_ids), + _BaseModel._serialize_attribute(model.input_ids), model.function_name, model.function_module, - json.dumps(model.output_ids), + _BaseModel._serialize_attribute(model.output_ids), model.version, model.skippable, - json.dumps(model.properties), + _BaseModel._serialize_attribute(model.properties), ] From 59e79af9069cf75df333241d5d0e18924b283973 Mon Sep 17 00:00:00 2001 From: trgiangdo Date: Sat, 11 Nov 2023 02:21:06 +0700 Subject: [PATCH 08/16] fix: cache sqlite_3 connection in _SQLConnection class --- src/taipy/core/_repository/_sql_repository.py | 60 +------------- .../core/_repository/db/_sql_connection.py | 83 +++++++++++++++++++ src/taipy/core/_repository/db/_sql_session.py | 70 ---------------- tests/conftest.py | 39 ++++----- .../job/test_job_manager_with_sql_repo.py | 7 +- 5 files changed, 105 insertions(+), 154 deletions(-) create mode 100644 src/taipy/core/_repository/db/_sql_connection.py delete mode 100644 src/taipy/core/_repository/db/_sql_session.py diff --git a/src/taipy/core/_repository/_sql_repository.py b/src/taipy/core/_repository/_sql_repository.py index cce620973..082af73ae 100644 --- a/src/taipy/core/_repository/_sql_repository.py +++ b/src/taipy/core/_repository/_sql_repository.py @@ -11,71 +11,19 @@ import json import pathlib -import sqlite3 from typing import Any, Dict, Iterable, List, Optional, Type, Union from sqlalchemy.dialects import sqlite from sqlalchemy.exc import NoResultFound -from sqlalchemy.schema import CreateTable - -from taipy.config.config import Config from .._repository._abstract_repository import _AbstractRepository -from .._repository.db._sql_session import _SQLSession from ..common.typing import Converter, Entity, ModelType -from ..exceptions import MissingRequiredProperty, ModelNotFound - -connection = None - - -from taipy.config.config import Config - -from .._repository._abstract_repository import _AbstractRepository -from .._repository.db._sql_session import _SQLSession -from ..exceptions import MissingRequiredProperty, ModelNotFound - -connection = None - - -def dict_factory(cursor, row): - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d - - -def init_db(): - properties = Config.core.repository_properties - try: - db_location = properties["db_location"] - except KeyError: - raise MissingRequiredProperty("Missing property db_location") - - sqlite3.threadsafety = 3 - - global connection - connection = connection if connection else sqlite3.connect(db_location, check_same_thread=False) - connection.row_factory = dict_factory - - from .._version._version_model import _VersionModel - from ..cycle._cycle_model import _CycleModel - from ..data._data_model import _DataNodeModel - from ..job._job_model import _JobModel - from ..scenario._scenario_model import _ScenarioModel - from ..task._task_model import _TaskModel - - connection.execute(str(CreateTable(_CycleModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_DataNodeModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_JobModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_ScenarioModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_TaskModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_VersionModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - - return connection +from ..exceptions import ModelNotFound +from .db._sql_connection import _SQLConnection class _SQLRepository(_AbstractRepository[ModelType, Entity]): - def __init__(self, model_type: Type[ModelType], converter: Type[Converter], session=None): + def __init__(self, model_type: Type[ModelType], converter: Type[Converter]): """ Holds common methods to be used and extended when the need for saving dataclasses in a SqlLite database. @@ -88,7 +36,7 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter], sess converter: A class that handles conversion to and from a database backend db: An SQLAlchemy session object """ - self.db = init_db() + self.db = _SQLConnection.init_db() self.model_type = model_type self.converter = converter diff --git a/src/taipy/core/_repository/db/_sql_connection.py b/src/taipy/core/_repository/db/_sql_connection.py new file mode 100644 index 000000000..1cad4c1bf --- /dev/null +++ b/src/taipy/core/_repository/db/_sql_connection.py @@ -0,0 +1,83 @@ +# Copyright 2023 Avaiga Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import sqlite3 +from functools import lru_cache +from sqlite3 import Connection + +from sqlalchemy.dialects import sqlite +from sqlalchemy.schema import CreateTable + +from taipy.config.config import Config + +from ...exceptions import MissingRequiredProperty + + +class _SQLConnection: + _connection = None + + @classmethod + def dict_factory(cls, cursor, row): + d = {} + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d + + @classmethod + def init_db(cls): + if cls._connection: + return cls._connection + + cls._connection = _build_connection() + cls._connection.row_factory = cls.dict_factory + + from ..._version._version_model import _VersionModel + from ...cycle._cycle_model import _CycleModel + from ...data._data_model import _DataNodeModel + from ...job._job_model import _JobModel + from ...scenario._scenario_model import _ScenarioModel + from ...task._task_model import _TaskModel + + cls._connection.execute( + str(CreateTable(_CycleModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) + cls._connection.execute( + str(CreateTable(_DataNodeModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) + cls._connection.execute( + str(CreateTable(_JobModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) + cls._connection.execute( + str(CreateTable(_ScenarioModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) + cls._connection.execute( + str(CreateTable(_TaskModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) + cls._connection.execute( + str(CreateTable(_VersionModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) + ) + + return cls._connection + + +@lru_cache +def _build_connection() -> Connection: + # Set SQLite threading mode to Serialized, means that threads may share the module, connections and cursors + sqlite3.threadsafety = 3 + + properties = Config.core.repository_properties + try: + db_location = properties["db_location"] + except KeyError: + raise MissingRequiredProperty("Missing property db_location") + + connection = sqlite3.connect(db_location, check_same_thread=False) + return connection diff --git a/src/taipy/core/_repository/db/_sql_session.py b/src/taipy/core/_repository/db/_sql_session.py deleted file mode 100644 index 5192ce88f..000000000 --- a/src/taipy/core/_repository/db/_sql_session.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2023 Avaiga Private Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. - -from functools import lru_cache - -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import StaticPool - -from taipy.config.config import Config - -from ...exceptions import MissingRequiredProperty -from .._decoder import loads -from .._encoder import dumps - - -class _SQLSession: - _engine = None - _SessionLocal = None - - @classmethod - def init_db(cls): - if cls._SessionLocal: - return cls._SessionLocal - - cls._engine = _build_engine() - cls._SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=cls._engine) - - from ....core._version._version_model import _VersionModel - from ....core.cycle._cycle_model import _CycleModel - from ....core.data._data_model import _DataNodeModel - from ....core.job._job_model import _JobModel - from ....core.scenario._scenario_model import _ScenarioModel - from ....core.task._task_model import _TaskModel - - _CycleModel.__table__.create(bind=cls._engine, checkfirst=True) - _DataNodeModel.__table__.create(bind=cls._engine, checkfirst=True) - _JobModel.__table__.create(bind=cls._engine, checkfirst=True) - _ScenarioModel.__table__.create(bind=cls._engine, checkfirst=True) - _TaskModel.__table__.create(bind=cls._engine, checkfirst=True) - _VersionModel.__table__.create(bind=cls._engine, checkfirst=True) - - return cls._SessionLocal - - -@lru_cache -def _build_engine() -> Engine: - properties = Config.core.repository_properties - try: - db_location = properties["db_location"] - except KeyError: - raise MissingRequiredProperty("Missing property db_location") - - # More sql databases can be easily added in the future - engine = create_engine( - f"sqlite:///{db_location}?check_same_thread=False", - poolclass=StaticPool, - json_serializer=dumps, - json_deserializer=loads, - ) - return engine diff --git a/tests/conftest.py b/tests/conftest.py index 027790b68..161f06744 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from src.taipy.core._core import Core from src.taipy.core._orchestrator._orchestrator_factory import _OrchestratorFactory -from src.taipy.core._repository._sql_repository import connection +from src.taipy.core._repository.db._sql_connection import _build_connection from src.taipy.core._version._version import _Version from src.taipy.core._version._version_manager_factory import _VersionManagerFactory from src.taipy.core._version._version_model import _VersionModel @@ -444,29 +444,20 @@ def sql_engine(): def init_sql_repo(tmp_sqlite): Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite}) - init_managers() # Clean SQLite database - if connection: - connection.execute(str(DropTable(_CycleModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_DataNodeModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_JobModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_ScenarioModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_TaskModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_VersionModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - - connection.execute( - str(CreateTable(_CycleModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) - ) - connection.execute( - str(CreateTable(_DataNodeModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) - ) - connection.execute(str(CreateTable(_JobModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute( - str(CreateTable(_ScenarioModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) - ) - connection.execute(str(CreateTable(_TaskModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute( - str(CreateTable(_VersionModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())) - ) + connection = _build_connection() + connection.execute(str(DropTable(_CycleModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_DataNodeModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_JobModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_ScenarioModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_TaskModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(DropTable(_VersionModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) + + connection.execute(str(CreateTable(_CycleModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_DataNodeModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_JobModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_ScenarioModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_TaskModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + connection.execute(str(CreateTable(_VersionModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) return tmp_sqlite diff --git a/tests/core/job/test_job_manager_with_sql_repo.py b/tests/core/job/test_job_manager_with_sql_repo.py index 38547b74e..281695a7e 100644 --- a/tests/core/job/test_job_manager_with_sql_repo.py +++ b/tests/core/job/test_job_manager_with_sql_repo.py @@ -20,7 +20,7 @@ from src.taipy.core import Task from src.taipy.core._orchestrator._dispatcher._job_dispatcher import _JobDispatcher from src.taipy.core._orchestrator._orchestrator_factory import _OrchestratorFactory -from src.taipy.core._repository.db._sql_session import _build_engine, _SQLSession +from src.taipy.core._repository.db._sql_connection import _build_connection, _SQLConnection from src.taipy.core.config.job_config import JobConfig from src.taipy.core.data import InMemoryDataNode from src.taipy.core.data._data_manager import _DataManager @@ -53,9 +53,8 @@ def init_managers(): def clear_sql_session(): - _build_engine.cache_clear() - _SQLSession._SessionLocal = None - _SQLSession._engine = None + _build_connection.cache_clear() + _SQLConnection._connection = None def test_create_jobs(init_sql_repo): From 32caa9eb2d813dfc900e4823292a80a78e4992a6 Mon Sep 17 00:00:00 2001 From: trgiangdo Date: Sat, 11 Nov 2023 02:27:11 +0700 Subject: [PATCH 09/16] fix: linter error --- src/taipy/core/_repository/_sql_repository.py | 87 ++++++++++--------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/src/taipy/core/_repository/_sql_repository.py b/src/taipy/core/_repository/_sql_repository.py index 082af73ae..745444524 100644 --- a/src/taipy/core/_repository/_sql_repository.py +++ b/src/taipy/core/_repository/_sql_repository.py @@ -45,32 +45,25 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter]): ############################### def _save(self, entity: Entity): obj = self.converter._entity_to_model(entity) - if self._exists(entity.id): + if self._exists(entity.id): # type: ignore self._update_entry(obj) return self.__insert_model(obj) def _exists(self, entity_id: str): - return bool( - self.db.execute(str(self.model_type.__table__.select().filter_by(id=entity_id)), [entity_id]).fetchone() - ) + query = self.model_type.__table__.select().filter_by(id=entity_id) # type: ignore + return bool(self.db.execute(str(query), [entity_id]).fetchone()) def _load(self, entity_id: str) -> Entity: - get_query = str(self.model_type.__table__.select().filter_by(id=entity_id).compile(dialect=sqlite.dialect())) + query = self.model_type.__table__.select().filter_by(id=entity_id) # type: ignore - if entry := self.db.execute(str(get_query), [entity_id]).fetchone(): # type: ignore - entry = self.model_type.from_dict(entry) + if entry := self.db.execute(str(query.compile(dialect=sqlite.dialect())), [entity_id]).fetchone(): + entry = self.model_type.from_dict(entry) # type: ignore return self.converter._model_to_entity(entry) raise ModelNotFound(str(self.model_type.__name__), entity_id) - @staticmethod - def serialize_filter_values(value): - if isinstance(value, (dict, list)): - return json.dumps(value).replace('"', "'") - return value - def _load_all(self, filters: Optional[List[Dict]] = None) -> List[Entity]: - query = self.model_type.__table__.select() + query = self.model_type.__table__.select() # type: ignore entities: List[Entity] = [] for f in filters or [{}]: @@ -78,17 +71,19 @@ def _load_all(self, filters: Optional[List[Dict]] = None) -> List[Entity]: try: entries = self.db.execute( str(filtered_query.compile(dialect=sqlite.dialect())), - [self.serialize_filter_values(val) for val in list(f.values())], + [self.__serialize_filter_values(val) for val in list(f.values())], ).fetchall() - entities.extend([self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries]) + entities.extend( + [self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries] # type: ignore + ) except NoResultFound: continue return entities def _delete(self, entity_id: str): - delete_query = self.model_type.__table__.delete().filter_by(id=entity_id).compile(dialect=sqlite.dialect()) - cursor = self.db.execute(str(delete_query), [entity_id]) + delete_query = self.model_type.__table__.delete().filter_by(id=entity_id) # type: ignore + cursor = self.db.execute(str(delete_query.compile(dialect=sqlite.dialect())), [entity_id]) self.db.commit() if cursor.rowcount == 0: @@ -98,7 +93,7 @@ def _delete(self, entity_id: str): raise ModelNotFound(str(self.model_type.__name__), entity_id) def _delete_all(self): - self.db.execute(str(self.model_type.__table__.delete().compile(dialect=sqlite.dialect()))) + self.db.execute(str(self.model_type.__table__.delete().compile(dialect=sqlite.dialect()))) # type: ignore self.db.commit() def _delete_many(self, ids: Iterable[str]): @@ -106,22 +101,23 @@ def _delete_many(self, ids: Iterable[str]): self._delete(entity_id) def _delete_by(self, attribute: str, value: str): - delete_by_query = ( - self.model_type.__table__.delete().filter_by(**{attribute: value}).compile(dialect=sqlite.dialect()) - ) - self.db.execute(str(delete_by_query), [value]) + delete_by_query = self.model_type.__table__.delete().filter_by(**{attribute: value}) # type: ignore + + self.db.execute(str(delete_by_query.compile(dialect=sqlite.dialect())), [value]) self.db.commit() def _search(self, attribute: str, value: Any, filters: Optional[List[Dict]] = None) -> List[Entity]: - query = self.model_type.__table__.select().filter_by(**{attribute: value}) + query = self.model_type.__table__.select().filter_by(**{attribute: value}) # type: ignore entities: List[Entity] = [] for f in filters or [{}]: entries = self.db.execute( str(query.filter_by(**f).compile(dialect=sqlite.dialect())), - [value] + [self.serialize_filter_values(val) for val in list(f.values())], + [value] + [self.__serialize_filter_values(val) for val in list(f.values())], ).fetchall() - entities.extend([self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries]) + entities.extend( + [self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries] # type: ignore + ) return entities @@ -137,9 +133,9 @@ def _export(self, entity_id: str, folder_path: Union[str, pathlib.Path]): export_path = export_dir / f"{entity_id}.json" - get_query = str(self.model_type.__table__.select().filter_by(id=entity_id).compile(dialect=sqlite.dialect())) + query = self.model_type.__table__.select().filter_by(id=entity_id) # type: ignore - if entry := self.db.execute(str(get_query), [entity_id]).fetchone(): # type: ignore + if entry := self.db.execute(str(query.compile(dialect=sqlite.dialect())), [entity_id]).fetchone(): with open(export_path, "w", encoding="utf-8") as export_file: export_file.write(json.dumps(entry)) else: @@ -149,12 +145,12 @@ def _export(self, entity_id: str, folder_path: Union[str, pathlib.Path]): # ## Specific or optimized methods ## # ########################################### def _get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]: - query = str(self.model_type.__table__.select().offset(skip).limit(limit).compile(dialect=sqlite.dialect())) - return self.db.execute(query).fetchall() + query = self.model_type.__table__.select().offset(skip).limit(limit) # type: ignore + return self.db.execute(str(query.compile(dialect=sqlite.dialect()))).fetchall() def _get_by_config(self, config_id: Any) -> Optional[ModelType]: - query = str(self.model_type.__table__.select().filter_by(config_id=config_id).compile(dialect=sqlite.dialect())) - return self.db.execute(query, [config_id]).fetchall() + query = self.model_type.__table__.select().filter_by(config_id=config_id) # type: ignore + return self.db.execute(str(query.compile(dialect=sqlite.dialect())), [config_id]).fetchall() def _get_by_config_and_owner_id( self, config_id: str, owner_id: Optional[str], filters: Optional[List[Dict]] = None @@ -184,12 +180,12 @@ def _get_by_configs_and_owner_ids(self, configs_and_owner_ids, filters: Optional def __get_entities_by_config_and_owner( self, config_id: str, owner_id: Optional[str] = None, filters: Optional[List[Dict]] = None - ) -> ModelType: + ) -> Optional[ModelType]: if not filters: filters = [] versions = [item.get("version") for item in filters if item.get("version")] - query = self.model_type.__table__.select().filter_by(config_id=config_id) + query = self.model_type.__table__.select().filter_by(config_id=config_id) # type: ignore parameters = [config_id] if owner_id: @@ -198,24 +194,29 @@ def __get_entities_by_config_and_owner( query = str(query.compile(dialect=sqlite.dialect())) if versions: - - query = query + f" AND {self.model_type.__table__.name}.version IN ({','.join(['?']*len(versions))})" - # query = str(query.filter(self.model_type.version.in_(versions)).compile(dialect=sqlite.dialect())) # type: ignore - parameters.extend(versions) + table_name = self.model_type.__table__.name # type: ignore + query = query + f" AND {table_name}.version IN ({','.join(['?']*len(versions))})" # type: ignore + parameters.extend(versions) # type: ignore if entry := self.db.execute(query, parameters).fetchone(): - return self.model_type.from_dict(entry) + return self.model_type.from_dict(entry) # type: ignore return None ############################# # ## Private methods ## # ############################# def __insert_model(self, model: ModelType): - query = str(self.model_type.__table__.insert().compile(dialect=sqlite.dialect())) - self.db.execute(query, model.to_list(model)) + query = self.model_type.__table__.insert() # type: ignore + self.db.execute(str(query.compile(dialect=sqlite.dialect())), model.to_list(model)) # type: ignore self.db.commit() def _update_entry(self, model): - query = str(self.model_type.__table__.update().filter_by(id=model.id).compile(dialect=sqlite.dialect())) - self.db.execute(query, model.to_list(model) + [model.id]) + query = self.model_type.__table__.update().filter_by(id=model.id) + self.db.execute(str(query.compile(dialect=sqlite.dialect())), model.to_list(model) + [model.id]) self.db.commit() + + @staticmethod + def __serialize_filter_values(value): + if isinstance(value, (dict, list)): + return json.dumps(value).replace('"', "'") + return value From 3e305089c7718effab701a90636b85e7fde1b342 Mon Sep 17 00:00:00 2001 From: trgiangdo Date: Sat, 11 Nov 2023 02:28:46 +0700 Subject: [PATCH 10/16] fix: remove unnecessary import --- src/taipy/core/_version/_cli/_version_cli.py | 2 +- tests/core/repository/mocks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/taipy/core/_version/_cli/_version_cli.py b/src/taipy/core/_version/_cli/_version_cli.py index dac3ef436..0932763bf 100644 --- a/src/taipy/core/_version/_cli/_version_cli.py +++ b/src/taipy/core/_version/_cli/_version_cli.py @@ -17,7 +17,7 @@ from taipy.logger._taipy_logger import _TaipyLogger from ...data._data_manager_factory import _DataManagerFactory -from ...exceptions.exceptions import ModelNotFound, VersionIsNotProductionVersion +from ...exceptions.exceptions import VersionIsNotProductionVersion from ...job._job_manager_factory import _JobManagerFactory from ...scenario._scenario_manager_factory import _ScenarioManagerFactory from ...sequence._sequence_manager_factory import _SequenceManagerFactory diff --git a/tests/core/repository/mocks.py b/tests/core/repository/mocks.py index 6cdce5a9f..a6e749921 100644 --- a/tests/core/repository/mocks.py +++ b/tests/core/repository/mocks.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional -from sqlalchemy import Column, String, Table, create_engine +from sqlalchemy import Column, String, Table from sqlalchemy.dialects import sqlite from sqlalchemy.orm import declarative_base, registry from sqlalchemy.schema import CreateTable From 147d706db57101eb04bbc60a0a9a1a4b7abee182 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Mon, 13 Nov 2023 10:45:54 +0700 Subject: [PATCH 11/16] remove type:ignore and clean code --- .../core/_repository/_base_taipy_model.py | 11 ++++ src/taipy/core/_repository/_sql_repository.py | 53 +++++++++---------- src/taipy/core/_version/_version_model.py | 15 +++--- .../core/_version/_version_sql_repository.py | 16 +++--- src/taipy/core/cycle/_cycle_model.py | 17 +++--- src/taipy/core/data/_data_model.py | 33 ++++++------ src/taipy/core/job/_job_model.py | 23 ++++---- src/taipy/core/scenario/_scenario_model.py | 27 +++++----- src/taipy/core/task/_task_model.py | 25 +++++---- tests/core/repository/mocks.py | 5 +- 10 files changed, 111 insertions(+), 114 deletions(-) diff --git a/src/taipy/core/_repository/_base_taipy_model.py b/src/taipy/core/_repository/_base_taipy_model.py index 07d36c919..68fff281e 100644 --- a/src/taipy/core/_repository/_base_taipy_model.py +++ b/src/taipy/core/_repository/_base_taipy_model.py @@ -14,11 +14,15 @@ import json from typing import Any, Dict +from sqlalchemy import Table + from ._decoder import _Decoder from ._encoder import _Encoder class _BaseModel: + __table__: Table + def __iter__(self): for attr, value in self.__dict__.items(): yield attr, value @@ -40,3 +44,10 @@ def _deserialize_attribute(value): if isinstance(value, str): return json.loads(value.replace("'", '"'), cls=_Decoder) return value + + @staticmethod + def from_dict(data: Dict[str, Any]): + pass + + def to_list(self): + pass diff --git a/src/taipy/core/_repository/_sql_repository.py b/src/taipy/core/_repository/_sql_repository.py index 745444524..28b4820e4 100644 --- a/src/taipy/core/_repository/_sql_repository.py +++ b/src/taipy/core/_repository/_sql_repository.py @@ -39,6 +39,7 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter]): self.db = _SQLConnection.init_db() self.model_type = model_type self.converter = converter + self.table = self.model_type.__table__ ############################### # ## Inherited methods ## # @@ -51,19 +52,19 @@ def _save(self, entity: Entity): self.__insert_model(obj) def _exists(self, entity_id: str): - query = self.model_type.__table__.select().filter_by(id=entity_id) # type: ignore + query = self.table.select().filter_by(id=entity_id) return bool(self.db.execute(str(query), [entity_id]).fetchone()) def _load(self, entity_id: str) -> Entity: - query = self.model_type.__table__.select().filter_by(id=entity_id) # type: ignore + query = self.table.select().filter_by(id=entity_id) if entry := self.db.execute(str(query.compile(dialect=sqlite.dialect())), [entity_id]).fetchone(): - entry = self.model_type.from_dict(entry) # type: ignore + entry = self.model_type.from_dict(entry) return self.converter._model_to_entity(entry) raise ModelNotFound(str(self.model_type.__name__), entity_id) def _load_all(self, filters: Optional[List[Dict]] = None) -> List[Entity]: - query = self.model_type.__table__.select() # type: ignore + query = self.table.select() entities: List[Entity] = [] for f in filters or [{}]: @@ -74,15 +75,13 @@ def _load_all(self, filters: Optional[List[Dict]] = None) -> List[Entity]: [self.__serialize_filter_values(val) for val in list(f.values())], ).fetchall() - entities.extend( - [self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries] # type: ignore - ) + entities.extend([self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries]) except NoResultFound: continue return entities def _delete(self, entity_id: str): - delete_query = self.model_type.__table__.delete().filter_by(id=entity_id) # type: ignore + delete_query = self.table.delete().filter_by(id=entity_id) cursor = self.db.execute(str(delete_query.compile(dialect=sqlite.dialect())), [entity_id]) self.db.commit() @@ -93,7 +92,7 @@ def _delete(self, entity_id: str): raise ModelNotFound(str(self.model_type.__name__), entity_id) def _delete_all(self): - self.db.execute(str(self.model_type.__table__.delete().compile(dialect=sqlite.dialect()))) # type: ignore + self.db.execute(str(self.table.delete().compile(dialect=sqlite.dialect()))) self.db.commit() def _delete_many(self, ids: Iterable[str]): @@ -101,13 +100,13 @@ def _delete_many(self, ids: Iterable[str]): self._delete(entity_id) def _delete_by(self, attribute: str, value: str): - delete_by_query = self.model_type.__table__.delete().filter_by(**{attribute: value}) # type: ignore + delete_by_query = self.table.delete().filter_by(**{attribute: value}) self.db.execute(str(delete_by_query.compile(dialect=sqlite.dialect())), [value]) self.db.commit() def _search(self, attribute: str, value: Any, filters: Optional[List[Dict]] = None) -> List[Entity]: - query = self.model_type.__table__.select().filter_by(**{attribute: value}) # type: ignore + query = self.table.select().filter_by(**{attribute: value}) entities: List[Entity] = [] for f in filters or [{}]: @@ -115,9 +114,7 @@ def _search(self, attribute: str, value: Any, filters: Optional[List[Dict]] = No str(query.filter_by(**f).compile(dialect=sqlite.dialect())), [value] + [self.__serialize_filter_values(val) for val in list(f.values())], ).fetchall() - entities.extend( - [self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries] # type: ignore - ) + entities.extend([self.converter._model_to_entity(self.model_type.from_dict(m)) for m in entries]) return entities @@ -127,13 +124,13 @@ def _export(self, entity_id: str, folder_path: Union[str, pathlib.Path]): else: folder = folder_path - export_dir = folder / self.model_type.__table__.name # type: ignore + export_dir = folder / self.table.name if not export_dir.exists(): export_dir.mkdir(parents=True) export_path = export_dir / f"{entity_id}.json" - query = self.model_type.__table__.select().filter_by(id=entity_id) # type: ignore + query = self.table.select().filter_by(id=entity_id) if entry := self.db.execute(str(query.compile(dialect=sqlite.dialect())), [entity_id]).fetchone(): with open(export_path, "w", encoding="utf-8") as export_file: @@ -145,11 +142,11 @@ def _export(self, entity_id: str, folder_path: Union[str, pathlib.Path]): # ## Specific or optimized methods ## # ########################################### def _get_multi(self, *, skip: int = 0, limit: int = 100) -> List[ModelType]: - query = self.model_type.__table__.select().offset(skip).limit(limit) # type: ignore + query = self.table.select().offset(skip).limit(limit) return self.db.execute(str(query.compile(dialect=sqlite.dialect()))).fetchall() def _get_by_config(self, config_id: Any) -> Optional[ModelType]: - query = self.model_type.__table__.select().filter_by(config_id=config_id) # type: ignore + query = self.table.select().filter_by(config_id=config_id) return self.db.execute(str(query.compile(dialect=sqlite.dialect())), [config_id]).fetchall() def _get_by_config_and_owner_id( @@ -185,8 +182,8 @@ def __get_entities_by_config_and_owner( filters = [] versions = [item.get("version") for item in filters if item.get("version")] - query = self.model_type.__table__.select().filter_by(config_id=config_id) # type: ignore - parameters = [config_id] + query = self.table.select().filter_by(config_id=config_id) + parameters: List = [config_id] if owner_id: parameters.append(owner_id) @@ -194,25 +191,25 @@ def __get_entities_by_config_and_owner( query = str(query.compile(dialect=sqlite.dialect())) if versions: - table_name = self.model_type.__table__.name # type: ignore - query = query + f" AND {table_name}.version IN ({','.join(['?']*len(versions))})" # type: ignore - parameters.extend(versions) # type: ignore + table_name = self.table.name + query = query + f" AND {table_name}.version IN ({','.join(['?']*len(versions))})" + parameters.extend(versions) if entry := self.db.execute(query, parameters).fetchone(): - return self.model_type.from_dict(entry) # type: ignore + return self.model_type.from_dict(entry) return None ############################# # ## Private methods ## # ############################# def __insert_model(self, model: ModelType): - query = self.model_type.__table__.insert() # type: ignore - self.db.execute(str(query.compile(dialect=sqlite.dialect())), model.to_list(model)) # type: ignore + query = self.table.insert() + self.db.execute(str(query.compile(dialect=sqlite.dialect())), model.to_list()) self.db.commit() def _update_entry(self, model): - query = self.model_type.__table__.update().filter_by(id=model.id) - self.db.execute(str(query.compile(dialect=sqlite.dialect())), model.to_list(model) + [model.id]) + query = self.table.update().filter_by(id=model.id) + self.db.execute(str(query.compile(dialect=sqlite.dialect())), model.to_list() + [model.id]) self.db.commit() @staticmethod diff --git a/src/taipy/core/_version/_version_model.py b/src/taipy/core/_version/_version_model.py index b4e1a0561..7105590cb 100644 --- a/src/taipy/core/_version/_version_model.py +++ b/src/taipy/core/_version/_version_model.py @@ -47,13 +47,12 @@ def from_dict(data: Dict[str, Any]): model.is_latest = data.get("is_latest") # type: ignore return model - @staticmethod - def to_list(model): + def to_list(self): return [ - model.id, - model.config, - model.creation_date, - model.is_production, - model.is_development, - model.is_latest, + self.id, + self.config, + self.creation_date, + self.is_production, + self.is_development, + self.is_latest, ] diff --git a/src/taipy/core/_version/_version_sql_repository.py b/src/taipy/core/_version/_version_sql_repository.py index ee173544b..535133e90 100644 --- a/src/taipy/core/_version/_version_sql_repository.py +++ b/src/taipy/core/_version/_version_sql_repository.py @@ -23,7 +23,7 @@ def __init__(self): super().__init__(model_type=_VersionModel, converter=_VersionConverter) def _set_latest_version(self, version_number): - if old_latest := self.db.execute(str(self.model_type.__table__.select().filter_by(is_latest=True))).fetchone(): + if old_latest := self.db.execute(str(self.table.select().filter_by(is_latest=True))).fetchone(): old_latest = self.model_type.from_dict(old_latest) old_latest.is_latest = False self._update_entry(old_latest) @@ -34,15 +34,13 @@ def _set_latest_version(self, version_number): def _get_latest_version(self): if latest := self.db.execute( - str(self.model_type.__table__.select().filter_by(is_latest=True).compile(dialect=sqlite.dialect())) + str(self.table.select().filter_by(is_latest=True).compile(dialect=sqlite.dialect())) ).fetchone(): return latest["id"] return "" def _set_development_version(self, version_number): - if old_development := self.db.execute( - str(self.model_type.__table__.select().filter_by(is_development=True)) - ).fetchone(): + if old_development := self.db.execute(str(self.table.select().filter_by(is_development=True))).fetchone(): old_development = self.model_type.from_dict(old_development) old_development.is_development = False self._update_entry(old_development) @@ -54,9 +52,7 @@ def _set_development_version(self, version_number): self._set_latest_version(version_number) def _get_development_version(self): - if development := self.db.execute( - str(self.model_type.__table__.select().filter_by(is_development=True)) - ).fetchone(): + if development := self.db.execute(str(self.table.select().filter_by(is_development=True))).fetchone(): return development["id"] raise ModelNotFound(self.model_type, "") @@ -69,7 +65,7 @@ def _set_production_version(self, version_number): def _get_production_versions(self): if productions := self.db.execute( - str(self.model_type.__table__.select().filter_by(is_production=True).compile(dialect=sqlite.dialect())), + str(self.table.select().filter_by(is_production=True).compile(dialect=sqlite.dialect())), ).fetchall(): return [p["id"] for p in productions] return [] @@ -83,6 +79,6 @@ def _delete_production_version(self, version_number): self._update_entry(version) def __get_by_id(self, version_id): - query = str(self.model_type.__table__.select().filter_by(id=version_id).compile(dialect=sqlite.dialect())) + query = str(self.table.select().filter_by(id=version_id).compile(dialect=sqlite.dialect())) entry = self.db.execute(query, [version_id]).fetchone() return self.model_type.from_dict(entry) if entry else None diff --git a/src/taipy/core/cycle/_cycle_model.py b/src/taipy/core/cycle/_cycle_model.py index da8122a89..bd2f1a1f4 100644 --- a/src/taipy/core/cycle/_cycle_model.py +++ b/src/taipy/core/cycle/_cycle_model.py @@ -55,14 +55,13 @@ def from_dict(data: Dict[str, Any]): end_date=data["end_date"], ) - @staticmethod - def to_list(model): + def to_list(self): return [ - model.id, - model.name, - repr(model.frequency), - _BaseModel._serialize_attribute(model.properties), - model.creation_date, - model.start_date, - model.end_date, + self.id, + self.name, + repr(self.frequency), + _BaseModel._serialize_attribute(self.properties), + self.creation_date, + self.start_date, + self.end_date, ] diff --git a/src/taipy/core/data/_data_model.py b/src/taipy/core/data/_data_model.py index d6ccbec02..dc8636ff8 100644 --- a/src/taipy/core/data/_data_model.py +++ b/src/taipy/core/data/_data_model.py @@ -81,22 +81,21 @@ def from_dict(data: Dict[str, Any]): data_node_properties=_BaseModel._deserialize_attribute(data["data_node_properties"]), ) - @staticmethod - def to_list(model): + def to_list(self): return [ - model.id, - model.config_id, - repr(model.scope), - model.storage_type, - model.owner_id, - _BaseModel._serialize_attribute(model.parent_ids), - model.last_edit_date, - _BaseModel._serialize_attribute(model.edits), - model.version, - model.validity_days, - model.validity_seconds, - model.edit_in_progress, - model.editor_id, - model.editor_expiration_date, - _BaseModel._serialize_attribute(model.data_node_properties), + self.id, + self.config_id, + repr(self.scope), + self.storage_type, + self.owner_id, + _BaseModel._serialize_attribute(self.parent_ids), + self.last_edit_date, + _BaseModel._serialize_attribute(self.edits), + self.version, + self.validity_days, + self.validity_seconds, + self.edit_in_progress, + self.editor_id, + self.editor_expiration_date, + _BaseModel._serialize_attribute(self.data_node_properties), ] diff --git a/src/taipy/core/job/_job_model.py b/src/taipy/core/job/_job_model.py index e31e38da7..eab9bfe94 100644 --- a/src/taipy/core/job/_job_model.py +++ b/src/taipy/core/job/_job_model.py @@ -63,17 +63,16 @@ def from_dict(data: Dict[str, Any]): version=data["version"], ) - @staticmethod - def to_list(model): + def to_list(self): return [ - model.id, - model.task_id, - repr(model.status), - model.force, - model.submit_id, - model.submit_entity_id, - model.creation_date, - _BaseModel._serialize_attribute(model.subscribers), - _BaseModel._serialize_attribute(model.stacktrace), - model.version, + self.id, + self.task_id, + repr(self.status), + self.force, + self.submit_id, + self.submit_entity_id, + self.creation_date, + _BaseModel._serialize_attribute(self.subscribers), + _BaseModel._serialize_attribute(self.stacktrace), + self.version, ] diff --git a/src/taipy/core/scenario/_scenario_model.py b/src/taipy/core/scenario/_scenario_model.py index 9eb9993af..54757b95e 100644 --- a/src/taipy/core/scenario/_scenario_model.py +++ b/src/taipy/core/scenario/_scenario_model.py @@ -71,19 +71,18 @@ def from_dict(data: Dict[str, Any]): cycle=CycleId(data["cycle"]) if "cycle" in data else None, ) - @staticmethod - def to_list(model): + def to_list(self): return [ - model.id, - model.config_id, - _BaseModel._serialize_attribute(model.tasks), - _BaseModel._serialize_attribute(model.additional_data_nodes), - _BaseModel._serialize_attribute(model.properties), - model.creation_date, - model.primary_scenario, - _BaseModel._serialize_attribute(model.subscribers), - _BaseModel._serialize_attribute(model.tags), - model.version, - _BaseModel._serialize_attribute(model.sequences), - model.cycle, + self.id, + self.config_id, + _BaseModel._serialize_attribute(self.tasks), + _BaseModel._serialize_attribute(self.additional_data_nodes), + _BaseModel._serialize_attribute(self.properties), + self.creation_date, + self.primary_scenario, + _BaseModel._serialize_attribute(self.subscribers), + _BaseModel._serialize_attribute(self.tags), + self.version, + _BaseModel._serialize_attribute(self.sequences), + self.cycle, ] diff --git a/src/taipy/core/task/_task_model.py b/src/taipy/core/task/_task_model.py index 370dc384b..2c671c1ee 100644 --- a/src/taipy/core/task/_task_model.py +++ b/src/taipy/core/task/_task_model.py @@ -64,18 +64,17 @@ def from_dict(data: Dict[str, Any]): properties=_BaseModel._deserialize_attribute(data["properties"] if "properties" in data.keys() else {}), ) - @staticmethod - def to_list(model): + def to_list(self): return [ - model.id, - model.owner_id, - _BaseModel._serialize_attribute(model.parent_ids), - model.config_id, - _BaseModel._serialize_attribute(model.input_ids), - model.function_name, - model.function_module, - _BaseModel._serialize_attribute(model.output_ids), - model.version, - model.skippable, - _BaseModel._serialize_attribute(model.properties), + self.id, + self.owner_id, + _BaseModel._serialize_attribute(self.parent_ids), + self.config_id, + _BaseModel._serialize_attribute(self.input_ids), + self.function_name, + self.function_module, + _BaseModel._serialize_attribute(self.output_ids), + self.version, + self.skippable, + _BaseModel._serialize_attribute(self.properties), ] diff --git a/tests/core/repository/mocks.py b/tests/core/repository/mocks.py index a6e749921..c5b1978aa 100644 --- a/tests/core/repository/mocks.py +++ b/tests/core/repository/mocks.py @@ -72,9 +72,8 @@ def _to_entity(self): def _from_entity(cls, entity: MockObj): return MockModel(id=entity.id, name=entity.name, version=entity._version) - @staticmethod - def to_list(model): - return [model.id, model.name, model.version] + def to_list(self): + return [self.id, self.name, self.version] class MockConverter(_AbstractConverter): From aed66778f95b90f308ea3aa060c5996646fa7d97 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Mon, 13 Nov 2023 11:06:41 +0700 Subject: [PATCH 12/16] minor improvement --- src/taipy/core/_repository/_sql_repository.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/taipy/core/_repository/_sql_repository.py b/src/taipy/core/_repository/_sql_repository.py index 28b4820e4..e0d8e79d9 100644 --- a/src/taipy/core/_repository/_sql_repository.py +++ b/src/taipy/core/_repository/_sql_repository.py @@ -34,7 +34,7 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter]): Attributes: model_type: Generic dataclass. converter: A class that handles conversion to and from a database backend - db: An SQLAlchemy session object + db: An sqlite3 session object """ self.db = _SQLConnection.init_db() self.model_type = model_type @@ -83,13 +83,11 @@ def _load_all(self, filters: Optional[List[Dict]] = None) -> List[Entity]: def _delete(self, entity_id: str): delete_query = self.table.delete().filter_by(id=entity_id) cursor = self.db.execute(str(delete_query.compile(dialect=sqlite.dialect())), [entity_id]) - self.db.commit() if cursor.rowcount == 0: raise ModelNotFound(str(self.model_type.__name__), entity_id) - if cursor.rowcount == 0: - raise ModelNotFound(str(self.model_type.__name__), entity_id) + self.db.commit() def _delete_all(self): self.db.execute(str(self.table.delete().compile(dialect=sqlite.dialect()))) From 422bd7f401787808211b5169e3cef79c5e737dd9 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Tue, 14 Nov 2023 18:01:54 +0700 Subject: [PATCH 13/16] fixed caching issue in sql connection --- .../core/_repository/db/_sql_connection.py | 17 ++++++++--------- tests/conftest.py | 5 +++-- .../core/job/test_job_manager_with_sql_repo.py | 9 --------- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/src/taipy/core/_repository/db/_sql_connection.py b/src/taipy/core/_repository/db/_sql_connection.py index 1cad4c1bf..1e19e1371 100644 --- a/src/taipy/core/_repository/db/_sql_connection.py +++ b/src/taipy/core/_repository/db/_sql_connection.py @@ -21,23 +21,23 @@ from ...exceptions import MissingRequiredProperty +def dict_factory(cursor, row): + d = {} + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d + + class _SQLConnection: _connection = None - @classmethod - def dict_factory(cls, cursor, row): - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d - @classmethod def init_db(cls): if cls._connection: return cls._connection cls._connection = _build_connection() - cls._connection.row_factory = cls.dict_factory + cls._connection.row_factory = dict_factory from ..._version._version_model import _VersionModel from ...cycle._cycle_model import _CycleModel @@ -68,7 +68,6 @@ def init_db(cls): return cls._connection -@lru_cache def _build_connection() -> Connection: # Set SQLite threading mode to Serialized, means that threads may share the module, connections and cursors sqlite3.threadsafety = 3 diff --git a/tests/conftest.py b/tests/conftest.py index 161f06744..3107d405d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from src.taipy.core._core import Core from src.taipy.core._orchestrator._orchestrator_factory import _OrchestratorFactory -from src.taipy.core._repository.db._sql_connection import _build_connection +from src.taipy.core._repository.db._sql_connection import _build_connection, _SQLConnection from src.taipy.core._version._version import _Version from src.taipy.core._version._version_manager_factory import _VersionManagerFactory from src.taipy.core._version._version_model import _VersionModel @@ -445,7 +445,8 @@ def init_sql_repo(tmp_sqlite): Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite}) # Clean SQLite database - connection = _build_connection() + _SQLConnection._connection = None + connection = _SQLConnection.init_db() connection.execute(str(DropTable(_CycleModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) connection.execute(str(DropTable(_DataNodeModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) connection.execute(str(DropTable(_JobModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) diff --git a/tests/core/job/test_job_manager_with_sql_repo.py b/tests/core/job/test_job_manager_with_sql_repo.py index 281695a7e..9fdafe525 100644 --- a/tests/core/job/test_job_manager_with_sql_repo.py +++ b/tests/core/job/test_job_manager_with_sql_repo.py @@ -52,11 +52,6 @@ def init_managers(): _JobManagerFactory._build_manager()._delete_all() -def clear_sql_session(): - _build_connection.cache_clear() - _SQLConnection._connection = None - - def test_create_jobs(init_sql_repo): Config.configure_job_executions(mode=JobConfig._DEVELOPMENT_MODE) init_managers() @@ -165,8 +160,6 @@ def test_delete_job(init_sql_repo): def test_raise_when_trying_to_delete_unfinished_job(init_sql_repo): - clear_sql_session() - Config.configure_job_executions(mode=JobConfig._STANDALONE_MODE, max_nb_of_workers=2) init_managers() @@ -196,8 +189,6 @@ def test_raise_when_trying_to_delete_unfinished_job(init_sql_repo): def test_force_deleting_unfinished_job(init_sql_repo): - clear_sql_session() - Config.configure_job_executions(mode=JobConfig._STANDALONE_MODE, max_nb_of_workers=2) init_managers() From 54cc59780aebd9fd3dac401f8f88cd5f1c893ef8 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Tue, 14 Nov 2023 18:08:33 +0700 Subject: [PATCH 14/16] added deserializable for stacktrace in job model --- src/taipy/core/job/_job_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/taipy/core/job/_job_model.py b/src/taipy/core/job/_job_model.py index eab9bfe94..98eb98ba5 100644 --- a/src/taipy/core/job/_job_model.py +++ b/src/taipy/core/job/_job_model.py @@ -59,7 +59,7 @@ def from_dict(data: Dict[str, Any]): submit_entity_id=data["submit_entity_id"], creation_date=data["creation_date"], subscribers=_BaseModel._deserialize_attribute(data["subscribers"]), - stacktrace=data["stacktrace"], + stacktrace=_BaseModel._deserialize_attribute(data["stacktrace"]), version=data["version"], ) From ab6280318e38245898a6d1401d35048732230438 Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Wed, 15 Nov 2023 10:26:16 +0700 Subject: [PATCH 15/16] added caching to build sqlite3 connection --- src/taipy/core/_repository/db/_sql_connection.py | 8 ++++++-- tests/conftest.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/taipy/core/_repository/db/_sql_connection.py b/src/taipy/core/_repository/db/_sql_connection.py index 1e19e1371..f01c7e040 100644 --- a/src/taipy/core/_repository/db/_sql_connection.py +++ b/src/taipy/core/_repository/db/_sql_connection.py @@ -78,5 +78,9 @@ def _build_connection() -> Connection: except KeyError: raise MissingRequiredProperty("Missing property db_location") - connection = sqlite3.connect(db_location, check_same_thread=False) - return connection + return __build_connection(db_location) + + +@lru_cache +def __build_connection(db_location: str): + return sqlite3.connect(db_location, check_same_thread=False) diff --git a/tests/conftest.py b/tests/conftest.py index 3107d405d..075c8f300 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -445,7 +445,9 @@ def init_sql_repo(tmp_sqlite): Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite}) # Clean SQLite database - _SQLConnection._connection = None + if _SQLConnection._connection: + _SQLConnection._connection.close() + _SQLConnection._connection = None connection = _SQLConnection.init_db() connection.execute(str(DropTable(_CycleModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) connection.execute(str(DropTable(_DataNodeModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) From 9b4ddf251b0bba157eb903d702fb576258af651c Mon Sep 17 00:00:00 2001 From: Toan Quach Date: Wed, 15 Nov 2023 21:11:50 +0700 Subject: [PATCH 16/16] clean conftest --- tests/conftest.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 075c8f300..0a3cb07cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,15 +18,12 @@ import pandas as pd import pytest from sqlalchemy import create_engine, text -from sqlalchemy.dialects import sqlite -from sqlalchemy.schema import CreateTable, DropTable from src.taipy.core._core import Core from src.taipy.core._orchestrator._orchestrator_factory import _OrchestratorFactory -from src.taipy.core._repository.db._sql_connection import _build_connection, _SQLConnection +from src.taipy.core._repository.db._sql_connection import _SQLConnection from src.taipy.core._version._version import _Version from src.taipy.core._version._version_manager_factory import _VersionManagerFactory -from src.taipy.core._version._version_model import _VersionModel from src.taipy.core.config import ( CoreSection, DataNodeConfig, @@ -49,7 +46,6 @@ from src.taipy.core.data._data_model import _DataNodeModel from src.taipy.core.data.in_memory import InMemoryDataNode from src.taipy.core.job._job_manager_factory import _JobManagerFactory -from src.taipy.core.job._job_model import _JobModel from src.taipy.core.job.job import Job from src.taipy.core.job.job_id import JobId from src.taipy.core.notification.notifier import Notifier @@ -61,7 +57,6 @@ from src.taipy.core.sequence.sequence import Sequence from src.taipy.core.sequence.sequence_id import SequenceId from src.taipy.core.task._task_manager_factory import _TaskManagerFactory -from src.taipy.core.task._task_model import _TaskModel from src.taipy.core.task.task import Task from taipy.config import _inject_section from taipy.config._config import _Config @@ -448,19 +443,6 @@ def init_sql_repo(tmp_sqlite): if _SQLConnection._connection: _SQLConnection._connection.close() _SQLConnection._connection = None - connection = _SQLConnection.init_db() - connection.execute(str(DropTable(_CycleModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_DataNodeModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_JobModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_ScenarioModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_TaskModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(DropTable(_VersionModel.__table__, if_exists=True).compile(dialect=sqlite.dialect()))) - - connection.execute(str(CreateTable(_CycleModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_DataNodeModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_JobModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_ScenarioModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_TaskModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) - connection.execute(str(CreateTable(_VersionModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))) + _SQLConnection.init_db() return tmp_sqlite