Skip to content

Commit

Permalink
feat(sdk): datastore support flush (#1348)
Browse files Browse the repository at this point in the history
add flush for datastore & add ut
  • Loading branch information
goldenxinxing authored Oct 17, 2022
1 parent 0eb3ed3 commit 71d6704
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 2 deletions.
15 changes: 13 additions & 2 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,7 @@ def __init__(
self._cond = threading.Condition()
self._stopped = False
self._records: List[Dict[str, Any]] = []
self._updating_records: List[Dict[str, Any]] = []
self._queue_run_exceptions: List[Exception] = []
self._run_exceptions_limits = max(run_exceptions_limits, 0)

Expand Down Expand Up @@ -1310,20 +1311,30 @@ def _insert(self, record: Dict[str, Any]) -> None:
self._records.append(record)
self._cond.notify()

def flush(self) -> None:
while True:
with self._cond:
if len(self._records) == 0 and len(self._updating_records) == 0:
break

def run(self) -> None:
while True:
with self._cond:
while not self._stopped and len(self._records) == 0:
self._cond.wait()
if len(self._records) == 0:
break
records = self._records
self._updating_records = self._records
self._records = []

try:
self.data_store.update_table(self.table_name, self.schema, records)
self.data_store.update_table(
self.table_name, self.schema, self._updating_records
)
except Exception as e:
logger.warning(f"{self} run-update-table raise exception: {e}")
self._queue_run_exceptions.append(e)
if len(self._queue_run_exceptions) > self._run_exceptions_limits:
break
finally:
self._updating_records = []
2 changes: 2 additions & 0 deletions client/starwhale/api/_impl/dataset/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def make_swds(self) -> DatasetSummary:

increased_rows += 1

self.tabular_dataset.flush()

try:
empty = dwriter.tell() == 0
dwriter.close()
Expand Down
19 changes: 19 additions & 0 deletions client/starwhale/api/_impl/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def _log(self, table_name: str, record: Dict[str, Any]) -> None:

writer.insert(record)

def _flush(self, table_name: str) -> None:
with self._lock:
writer = self._writers.get(table_name)
if writer is None:
return
writer.flush()


def _serialize(data: Any) -> Any:
return dill.dumps(data)
Expand Down Expand Up @@ -132,6 +139,15 @@ def get(self, table_name: str) -> Iterator[Dict[str, Any]]:
[data_store.TableDesc(self._get_datastore_table_name(table_name))]
)

def flush_result(self) -> None:
self._flush(self._results_table_name)

def flush_metrics(self) -> None:
self._flush(self._summary_table_name)

def flush(self, table_name: str) -> None:
self._flush(table_name)


class Dataset(Logger):
def __init__(self, dataset_id: str, project: str) -> None:
Expand All @@ -158,6 +174,9 @@ def scan(self, start: Any, end: Any) -> Iterator[Dict[str, Any]]:
[data_store.TableDesc(self._meta_table_name)], start=start, end=end
)

def flush(self) -> None:
self._flush(self._meta_table_name)

def __str__(self) -> str:
return f"Dataset Wrapper, table:{self._meta_table_name}"

Expand Down
3 changes: 3 additions & 0 deletions client/starwhale/core/dataset/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ def update(self, row_id: int, **kw: t.Union[int, str, bytes]) -> None:
def put(self, row: TabularDatasetRow) -> None:
self._ds_wrapper.put(row.id, **row.asdict())

def flush(self) -> None:
self._ds_wrapper.flush()

def scan(
self, start: int = 0, end: int = sys.maxsize
) -> t.Generator[TabularDatasetRow, None, None]:
Expand Down
48 changes: 48 additions & 0 deletions client/tests/sdk/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from starwhale.consts import HTTPMethod
from starwhale.api._impl import data_store
from starwhale.api._impl.data_store import TableWriterException

from .test_base import BaseTestCase

Expand Down Expand Up @@ -1829,6 +1830,53 @@ def tearDown(self) -> None:
self.writer.close()
super().tearDown()

def test_writer(self):
_writer = data_store.TableWriter("p/test_flush", "id")
for i in range(0, 10):
_writer.insert({"id": i, "result": f"data-{i}"})
with self.assertRaises(RuntimeError):
list(_writer.data_store.scan_tables([data_store.TableDesc("p/test_flush")]))
_writer.close()

_writer2 = data_store.TableWriter("p/test_flush2", "id")
for i in range(0, 10):
_writer2.insert({"id": i, "result": f"data-{i}"})
_writer2.flush()
self.assertEqual(
len(
list(
_writer.data_store.scan_tables(
[data_store.TableDesc("p/test_flush2")]
)
)
),
10,
)
_writer2.close()

_writer3 = data_store.TableWriter("p/test_flush3", "id")
_writer3.insert({"id": 0, "result": "data-0"})
_writer3.flush()
with patch(
"starwhale.api._impl.data_store.LocalDataStore.update_table"
) as update_table:
update_table.side_effect = RuntimeError()
for i in range(1, 11):
_writer3.insert({"id": i, "result": f"data-{i}"})
_writer3.flush()
self.assertEqual(
len(
list(
_writer.data_store.scan_tables(
[data_store.TableDesc("p/test_flush3")]
)
)
),
1,
)
with self.assertRaises(TableWriterException):
_writer3.close()

def test_insert_and_delete(self) -> None:
with self.assertRaises(RuntimeError, msg="no key"):
self.writer.insert({"a": "0"})
Expand Down

0 comments on commit 71d6704

Please sign in to comment.