Skip to content

Commit

Permalink
Add types to descriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
kesmit13 committed Nov 28, 2023
1 parent 690016c commit 788e462
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 5 deletions.
176 changes: 171 additions & 5 deletions singlestoredb/fusion/cache/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
from __future__ import annotations

import datetime
import json
import sqlite3
Expand All @@ -18,9 +20,13 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from .. import result

CACHE_NAME = 'file:fusion?mode=memory&cache=shared'
SCHEMA = r'''
Expand Down Expand Up @@ -148,9 +154,165 @@ def convert_json(val: Optional[bytes]) -> Optional[Any]:
sqlite3.register_converter('json', convert_json)


def dict_factory(cursor: Any, row: Tuple[Any, ...]) -> Dict[str, Any]:
def dict_factory(
desc: Any,
row: Union[Tuple[Any, ...], Dict[str, Any]],
) -> Dict[str, Any]:
"""Return row as a dictionary."""
return {k[0]: v for k, v in zip(cursor.description, row)}
if isinstance(row, dict):
return row
return {k[0]: v for k, v in zip(desc, row)}


class Cursor(sqlite3.Cursor):

def __init__(self, connection: Connection):
super().__init__(connection)
self._results: List[Union[Tuple[Any, ...], Dict[str, Any]]] = []
self._results_idx: int = 0
self._description: List[result.Description] = []

def _set_description(self, description: Any) -> None:
if description is None:
self._description = []
return

desc = [list(x) + [None, None] for x in description]
out: Dict[int, result.Description] = {}
for row in self._results:
for i, item in enumerate(row):
if isinstance(item, float):
fields = list(desc[i])
fields[1] = result.DOUBLE
fields[6] = True
out[i] = result.Description(*fields)

elif isinstance(item, int):
fields = list(desc[i])
fields[1] = result.INTEGER
fields[6] = True
out[i] = result.Description(*fields)

elif isinstance(item, str):
fields = list(desc[i])
fields[1] = result.STRING
fields[6] = True
out[i] = result.Description(*fields)

elif isinstance(item, bytes):
fields = list(desc[i])
fields[1] = result.BLOB
fields[8] = 63
fields[6] = True
out[i] = result.Description(*fields)

elif isinstance(item, datetime.datetime):
fields = list(desc[i])
fields[1] = result.DATETIME
fields[6] = True
out[i] = result.Description(*fields)

elif isinstance(item, datetime.date):
fields = list(desc[i])
fields[1] = result.DATE
fields[6] = True
out[i] = result.Description(*fields)

elif isinstance(item, datetime.timedelta):
fields = list(desc[i])
fields[1] = result.TIME
fields[6] = True
out[i] = result.Description(*fields)

elif isinstance(item, (list, dict)):
fields = list(desc[i])
fields[1] = result.JSON
fields[6] = True
out[i] = result.Description(*fields)

elif item is None:
if desc[i][1] is None:
fields = list(desc[i])
fields[1] = result.NULL
fields[6] = True
out[i] = result.Description(*fields)

else:
raise TypeError(f'unrecognized data type: {item}')

if len(out) == len(desc):
break

else:
for i, d in enumerate(desc):
fields = list(d)
fields[1] = result.NULL
fields[6] = True
out[i] = result.Description(*fields)

self._description = [result.Description(*v) for k, v in sorted(out.items())]

@property
def description(self) -> List[result.Description]:
return self._description

def __iter__(self) -> Iterator[Any]: # type: ignore
return iter(self._results[self._results_idx:])

def fetchall(self) -> List[Any]:
if self._results_idx >= len(self._results):
return []
out = self._results[self._results_idx:]
self._results_idx += len(self._results)
return out

def fetchone(self) -> Any:
if self._results_idx >= len(self._results):
return None
out = self._results[self._results_idx]
self._results_idx += 1
return out

def fetchmany(self, batchsize: Optional[int] = 1) -> List[Any]:
batchsize = batchsize or 1
if self._results_idx >= len(self._results):
return []
out = self._results[self._results_idx:self._results_idx+batchsize]
self._results_idx += batchsize
return out

def execute(self, query: str, params: Optional[Any] = None) -> Cursor:
self._results_idx = 0
if params is None:
out = super().execute(query)
else:
out = super().execute(query, params)
desc = super().description
self._results = super().fetchall()
self._set_description(desc)
if self.connection.results_type == 'dicts': # type: ignore
self._results = [dict_factory(self._description, x) for x in self._results]
return out

def executemany(self, query: str, params: Optional[Any] = None) -> Cursor:
self._results_idx = 0
out = super().executemany(query, params or [])
desc = super().description
self._results = super().fetchall()
self._set_description(desc)
if self.connection.results_type == 'dicts': # type: ignore
self._results = [dict_factory(self._description, x) for x in self._results]
return out


class Connection(sqlite3.Connection):

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.results_type: str = 'tuples'

def cursor(self) -> Cursor: # type: ignore
return Cursor(self)


class Cache(object):
Expand All @@ -170,7 +332,10 @@ def __init__(self) -> None:
self.table_fields[table[0]] = info

def connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(':memory:', uri=True, detect_types=sqlite3.PARSE_DECLTYPES)
conn = sqlite3.connect(
':memory:', uri=True,
detect_types=sqlite3.PARSE_DECLTYPES, factory=Connection,
)
conn.cursor().execute(f'ATTACH "{CACHE_NAME}" AS fusion')
return conn

Expand All @@ -191,8 +356,8 @@ def update(self, table: str, func: Callable[[], Any]) -> List[Dict[str, Any]]:

# If we are within the cache timeout, return the current results
if stats and (datetime.datetime.now() - stats[0][1]) < CACHE_TIMEOUTS[table]:
conn.row_factory = dict_factory
return list(conn.cursor().execute(f'SELECT * FROM fusion.{table}'))
conn.results_type = 'dicts' # type: ignore
return list(cur.execute(f'SELECT * FROM fusion.{table}'))

# Build query components
columns = [x[1] for x in self.table_fields[table]]
Expand Down Expand Up @@ -228,5 +393,6 @@ def update(self, table: str, func: Callable[[], Any]) -> List[Dict[str, Any]]:

except Exception:
cur.execute('ROLLBACK')
raise

return values
2 changes: 2 additions & 0 deletions singlestoredb/fusion/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from ..mysql.constants.FIELD_TYPE import DOUBLE # noqa: F401
from ..mysql.constants.FIELD_TYPE import JSON # noqa: F401
from ..mysql.constants.FIELD_TYPE import LONGLONG as INTEGER # noqa: F401
from ..mysql.constants.FIELD_TYPE import NULL # noqa: F401
from ..mysql.constants.FIELD_TYPE import STRING # noqa: F401
from ..mysql.constants.FIELD_TYPE import TIME # noqa: F401
from ..utils.results import Description
from ..utils.results import format_results

Expand Down

0 comments on commit 788e462

Please sign in to comment.