From 788e4623ec094e3947ff46aa850c94760a2898db Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 28 Nov 2023 10:15:50 -0600 Subject: [PATCH] Add types to descriptions --- singlestoredb/fusion/cache/sqlite.py | 176 ++++++++++++++++++++++++++- singlestoredb/fusion/result.py | 2 + 2 files changed, 173 insertions(+), 5 deletions(-) diff --git a/singlestoredb/fusion/cache/sqlite.py b/singlestoredb/fusion/cache/sqlite.py index 7a4f9b30..4a1f1856 100644 --- a/singlestoredb/fusion/cache/sqlite.py +++ b/singlestoredb/fusion/cache/sqlite.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import datetime import json import sqlite3 @@ -18,9 +20,13 @@ from typing import Any from typing import Callable from typing import Dict +from typing import Iterator from typing import List from typing import Optional from typing import Tuple +from typing import Union + +from .. import result CACHE_NAME = 'file:fusion?mode=memory&cache=shared' SCHEMA = r''' @@ -148,9 +154,165 @@ def convert_json(val: Optional[bytes]) -> Optional[Any]: sqlite3.register_converter('json', convert_json) -def dict_factory(cursor: Any, row: Tuple[Any, ...]) -> Dict[str, Any]: +def dict_factory( + desc: Any, + row: Union[Tuple[Any, ...], Dict[str, Any]], +) -> Dict[str, Any]: """Return row as a dictionary.""" - return {k[0]: v for k, v in zip(cursor.description, row)} + if isinstance(row, dict): + return row + return {k[0]: v for k, v in zip(desc, row)} + + +class Cursor(sqlite3.Cursor): + + def __init__(self, connection: Connection): + super().__init__(connection) + self._results: List[Union[Tuple[Any, ...], Dict[str, Any]]] = [] + self._results_idx: int = 0 + self._description: List[result.Description] = [] + + def _set_description(self, description: Any) -> None: + if description is None: + self._description = [] + return + + desc = [list(x) + [None, None] for x in description] + out: Dict[int, result.Description] = {} + for row in self._results: + for i, item in enumerate(row): + if isinstance(item, float): + fields = list(desc[i]) + fields[1] = result.DOUBLE + fields[6] = True + out[i] = result.Description(*fields) + + elif isinstance(item, int): + fields = list(desc[i]) + fields[1] = result.INTEGER + fields[6] = True + out[i] = result.Description(*fields) + + elif isinstance(item, str): + fields = list(desc[i]) + fields[1] = result.STRING + fields[6] = True + out[i] = result.Description(*fields) + + elif isinstance(item, bytes): + fields = list(desc[i]) + fields[1] = result.BLOB + fields[8] = 63 + fields[6] = True + out[i] = result.Description(*fields) + + elif isinstance(item, datetime.datetime): + fields = list(desc[i]) + fields[1] = result.DATETIME + fields[6] = True + out[i] = result.Description(*fields) + + elif isinstance(item, datetime.date): + fields = list(desc[i]) + fields[1] = result.DATE + fields[6] = True + out[i] = result.Description(*fields) + + elif isinstance(item, datetime.timedelta): + fields = list(desc[i]) + fields[1] = result.TIME + fields[6] = True + out[i] = result.Description(*fields) + + elif isinstance(item, (list, dict)): + fields = list(desc[i]) + fields[1] = result.JSON + fields[6] = True + out[i] = result.Description(*fields) + + elif item is None: + if desc[i][1] is None: + fields = list(desc[i]) + fields[1] = result.NULL + fields[6] = True + out[i] = result.Description(*fields) + + else: + raise TypeError(f'unrecognized data type: {item}') + + if len(out) == len(desc): + break + + else: + for i, d in enumerate(desc): + fields = list(d) + fields[1] = result.NULL + fields[6] = True + out[i] = result.Description(*fields) + + self._description = [result.Description(*v) for k, v in sorted(out.items())] + + @property + def description(self) -> List[result.Description]: + return self._description + + def __iter__(self) -> Iterator[Any]: # type: ignore + return iter(self._results[self._results_idx:]) + + def fetchall(self) -> List[Any]: + if self._results_idx >= len(self._results): + return [] + out = self._results[self._results_idx:] + self._results_idx += len(self._results) + return out + + def fetchone(self) -> Any: + if self._results_idx >= len(self._results): + return None + out = self._results[self._results_idx] + self._results_idx += 1 + return out + + def fetchmany(self, batchsize: Optional[int] = 1) -> List[Any]: + batchsize = batchsize or 1 + if self._results_idx >= len(self._results): + return [] + out = self._results[self._results_idx:self._results_idx+batchsize] + self._results_idx += batchsize + return out + + def execute(self, query: str, params: Optional[Any] = None) -> Cursor: + self._results_idx = 0 + if params is None: + out = super().execute(query) + else: + out = super().execute(query, params) + desc = super().description + self._results = super().fetchall() + self._set_description(desc) + if self.connection.results_type == 'dicts': # type: ignore + self._results = [dict_factory(self._description, x) for x in self._results] + return out + + def executemany(self, query: str, params: Optional[Any] = None) -> Cursor: + self._results_idx = 0 + out = super().executemany(query, params or []) + desc = super().description + self._results = super().fetchall() + self._set_description(desc) + if self.connection.results_type == 'dicts': # type: ignore + self._results = [dict_factory(self._description, x) for x in self._results] + return out + + +class Connection(sqlite3.Connection): + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.results_type: str = 'tuples' + + def cursor(self) -> Cursor: # type: ignore + return Cursor(self) class Cache(object): @@ -170,7 +332,10 @@ def __init__(self) -> None: self.table_fields[table[0]] = info def connect(self) -> sqlite3.Connection: - conn = sqlite3.connect(':memory:', uri=True, detect_types=sqlite3.PARSE_DECLTYPES) + conn = sqlite3.connect( + ':memory:', uri=True, + detect_types=sqlite3.PARSE_DECLTYPES, factory=Connection, + ) conn.cursor().execute(f'ATTACH "{CACHE_NAME}" AS fusion') return conn @@ -191,8 +356,8 @@ def update(self, table: str, func: Callable[[], Any]) -> List[Dict[str, Any]]: # If we are within the cache timeout, return the current results if stats and (datetime.datetime.now() - stats[0][1]) < CACHE_TIMEOUTS[table]: - conn.row_factory = dict_factory - return list(conn.cursor().execute(f'SELECT * FROM fusion.{table}')) + conn.results_type = 'dicts' # type: ignore + return list(cur.execute(f'SELECT * FROM fusion.{table}')) # Build query components columns = [x[1] for x in self.table_fields[table]] @@ -228,5 +393,6 @@ def update(self, table: str, func: Callable[[], Any]) -> List[Dict[str, Any]]: except Exception: cur.execute('ROLLBACK') + raise return values diff --git a/singlestoredb/fusion/result.py b/singlestoredb/fusion/result.py index c48317b8..884e136e 100644 --- a/singlestoredb/fusion/result.py +++ b/singlestoredb/fusion/result.py @@ -17,7 +17,9 @@ from ..mysql.constants.FIELD_TYPE import DOUBLE # noqa: F401 from ..mysql.constants.FIELD_TYPE import JSON # noqa: F401 from ..mysql.constants.FIELD_TYPE import LONGLONG as INTEGER # noqa: F401 +from ..mysql.constants.FIELD_TYPE import NULL # noqa: F401 from ..mysql.constants.FIELD_TYPE import STRING # noqa: F401 +from ..mysql.constants.FIELD_TYPE import TIME # noqa: F401 from ..utils.results import Description from ..utils.results import format_results