From a9f4772970101f208610dc65d66cde69cae35465 Mon Sep 17 00:00:00 2001 From: tianweidut Date: Mon, 5 Sep 2022 19:05:49 +0800 Subject: [PATCH] refactor dataset build --- .gitignore | 2 + client/starwhale/api/_impl/data_store.py | 2 - .../starwhale/api/_impl/dataset/__init__.py | 20 +- client/starwhale/api/_impl/dataset/builder.py | 172 ++++------ client/starwhale/api/_impl/dataset/loader.py | 27 +- client/starwhale/api/_impl/dataset/mnist.py | 33 -- client/starwhale/api/_impl/metric.py | 51 ++- client/starwhale/api/_impl/model.py | 184 +++-------- client/starwhale/api/_impl/wrapper.py | 2 +- client/starwhale/api/dataset.py | 18 +- client/starwhale/base/mixin.py | 5 + client/starwhale/core/dataset/model.py | 10 +- client/starwhale/core/dataset/store.py | 4 +- client/starwhale/core/dataset/tabular.py | 95 ++++-- client/starwhale/core/dataset/type.py | 296 ++++++++++++++++-- client/starwhale/core/job/model.py | 2 - client/starwhale/utils/flatten.py | 3 +- client/tests/core/test_dataset.py | 1 - client/tests/core/test_model.py | 4 +- client/tests/sdk/test_dataset.py | 103 +++--- client/tests/sdk/test_loader.py | 108 ++++--- client/tests/sdk/test_model.py | 30 +- docs/docs/tutorials/pfp.md | 75 ----- docs/docs/tutorials/speech.md | 4 +- example/PennFudanPed/code/data_slicer.py | 141 +++++---- example/PennFudanPed/code/ds.py | 3 +- example/PennFudanPed/code/ppl.py | 60 +--- example/PennFudanPed/dataset.yaml | 6 +- example/cifar10/code/data_slicer.py | 49 +-- example/cifar10/code/ppl.py | 27 +- example/cifar10/dataset.yaml | 6 +- example/mnist/dataset.yaml | 4 +- example/mnist/mnist/ppl.py | 51 ++- example/mnist/mnist/process.py | 127 ++++---- example/nmt/code/dataset.py | 70 +---- example/nmt/code/helper.py | 52 +-- example/nmt/code/ppl.py | 19 +- example/nmt/dataset.yaml | 6 +- example/speech_command/code/data_slicer.py | 87 ++--- example/speech_command/code/ppl.py | 27 +- example/speech_command/dataset.yaml | 6 +- example/text_cls_AG_NEWS/code/data_slicer.py | 35 +-- example/text_cls_AG_NEWS/code/ppl.py | 34 +- example/text_cls_AG_NEWS/dataset.yaml | 9 +- scripts/run_demo.sh | 4 +- 45 files changed, 996 insertions(+), 1078 deletions(-) delete mode 100644 client/starwhale/api/_impl/dataset/mnist.py diff --git a/.gitignore b/.gitignore index ef9776e7ab..2c63360c9d 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,8 @@ example/PennFudanPed/data/ example/PennFudanPed/models/ example/nmt/models/ example/nmt/data/ +example/speech_command/data/ +example/speech_command/models/ # UI node_modules diff --git a/client/starwhale/api/_impl/data_store.py b/client/starwhale/api/_impl/data_store.py index 8761c91e3f..33300f009c 100644 --- a/client/starwhale/api/_impl/data_store.py +++ b/client/starwhale/api/_impl/data_store.py @@ -896,9 +896,7 @@ def __init__( yield r def dump(self) -> None: - logger.debug(f"start dump, tables size:{len(self.tables.values())}") for table in list(self.tables.values()): - logger.debug(f"dump {table.table_name} to {self.root_path}") table.dump(self.root_path) diff --git a/client/starwhale/api/_impl/dataset/__init__.py b/client/starwhale/api/_impl/dataset/__init__.py index c38eb0e391..74e690b771 100644 --- a/client/starwhale/api/_impl/dataset/__init__.py +++ b/client/starwhale/api/_impl/dataset/__init__.py @@ -1,14 +1,20 @@ from starwhale.core.dataset.type import ( Link, + Text, + Audio, + Image, + Binary, LinkType, MIMEType, - DataField, + ClassLabel, S3LinkAuth, + BoundingBox, + GrayscaleImage, LocalFSLinkAuth, DefaultS3LinkAuth, + COCOObjectAnnotation, ) -from .mnist import 
MNISTBuildExecutor from .loader import get_data_loader, SWDSBinDataLoader, UserRawDataLoader from .builder import BuildExecutor, SWDSBinBuildExecutor, UserRawBuildExecutor @@ -20,11 +26,17 @@ "S3LinkAuth", "MIMEType", "LinkType", - "DataField", "BuildExecutor", # SWDSBinBuildExecutor alias "UserRawBuildExecutor", "SWDSBinBuildExecutor", - "MNISTBuildExecutor", "SWDSBinDataLoader", "UserRawDataLoader", + "Binary", + "Text", + "Audio", + "Image", + "ClassLabel", + "BoundingBox", + "GrayscaleImage", + "COCOObjectAnnotation", ] diff --git a/client/starwhale/api/_impl/dataset/builder.py b/client/starwhale/api/_impl/dataset/builder.py index 1268e74959..5dbcd96f82 100644 --- a/client/starwhale/api/_impl/dataset/builder.py +++ b/client/starwhale/api/_impl/dataset/builder.py @@ -1,6 +1,3 @@ -from __future__ import annotations - -import sys import struct import typing as t import tempfile @@ -19,8 +16,10 @@ from starwhale.core.dataset import model from starwhale.core.dataset.type import ( Link, + Binary, LinkAuth, MIMEType, + BaseArtifact, DatasetSummary, D_ALIGNMENT_SIZE, D_FILE_VOLUME_SIZE, @@ -45,10 +44,7 @@ def __init__( dataset_name: str, dataset_version: str, project_name: str, - data_dir: Path = Path("."), workdir: Path = Path("./sw_output"), - data_filter: str = "*", - label_filter: str = "*", alignment_bytes_size: int = D_ALIGNMENT_SIZE, volume_bytes_size: int = D_FILE_VOLUME_SIZE, append: bool = False, @@ -58,10 +54,6 @@ def __init__( ) -> None: # TODO: add more docstring for args # TODO: validate group upper and lower? - self.data_dir = data_dir - self.data_filter = data_filter - self.label_filter = label_filter - self.workdir = workdir self.data_output_dir = workdir / "data" ensure_dir(self.data_output_dir) @@ -118,6 +110,10 @@ def __exit__( print("cleanup done.") + @abstractmethod + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + raise NotImplementedError + @abstractmethod def make_swds(self) -> DatasetSummary: raise NotImplementedError @@ -128,59 +124,16 @@ def _merge_forked_summary(self, s: DatasetSummary) -> DatasetSummary: s.rows += _fs.rows s.unchanged_rows += _fs.rows s.data_byte_size += _fs.data_byte_size - s.label_byte_size += _fs.label_byte_size + s.annotations = list(set(s.annotations) | set(_fs.annotations)) s.include_link |= _fs.include_link s.include_user_raw |= _fs.include_user_raw return s - def _iter_files( - self, filter: str, sort_key: t.Optional[t.Any] = None - ) -> t.Generator[Path, None, None]: - _key = sort_key - if _key is not None and not callable(_key): - raise Exception(f"data_sort_func({_key}) is not callable.") - - _files = sorted(self.data_dir.rglob(filter), key=_key) - for p in _files: - if not p.is_file(): - continue - yield p - - def iter_data_files(self) -> t.Generator[Path, None, None]: - return self._iter_files(self.data_filter, self.data_sort_func()) - - def iter_label_files(self) -> t.Generator[Path, None, None]: - return self._iter_files(self.label_filter, self.label_sort_func()) - - def iter_all_dataset_slice(self) -> t.Generator[t.Any, None, None]: - for p in self.iter_data_files(): - for d in self.iter_data_slice(str(p.absolute())): - yield p, d - - def iter_all_label_slice(self) -> t.Generator[t.Any, None, None]: - for p in self.iter_label_files(): - for d in self.iter_label_slice(str(p.absolute())): - yield p, d - - @abstractmethod - def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]: - raise NotImplementedError - - @abstractmethod - def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, 
None]: - raise NotImplementedError - @property def data_format_type(self) -> DataFormatType: raise NotImplementedError - def data_sort_func(self) -> t.Any: - return None - - def label_sort_func(self) -> t.Any: - return None - class SWDSBinBuildExecutor(BaseBuildExecutor): """ @@ -244,26 +197,32 @@ def make_swds(self) -> DatasetSummary: ds_copy_candidates[fno] = dwriter_path increased_rows = 0 - total_label_size, total_data_size = 0, 0 + total_data_size = 0 + dataset_annotations: t.Dict[str, t.Any] = {} - for idx, ((_, data), (_, label)) in enumerate( - zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()), - start=self._forked_last_idx + 1, + for idx, (row_data, row_annotations) in enumerate( + self.iter_item(), start=self._forked_last_idx + 1 ): - if isinstance(data, (tuple, list)): - _data_content, _data_mime_type = data + if not isinstance(row_annotations, dict): + raise FormatError(f"annotations({row_annotations}) must be dict type") + + _artifact: BaseArtifact + if isinstance(row_data, bytes): + _artifact = Binary(row_data, self.default_data_mime_type) + elif isinstance(row_data, BaseArtifact): + _artifact = row_data else: - _data_content, _data_mime_type = data, self.default_data_mime_type + raise NoSupportError(f"data type {type(row_data)}") - if not isinstance(_data_content, bytes): - raise FormatError("data content must be bytes type") + if not dataset_annotations: + # TODO: check annotations type and name + dataset_annotations = row_annotations - _bin_section = self._write(dwriter, _data_content) + _bin_section = self._write(dwriter, _artifact.to_bytes()) self.tabular_dataset.put( TabularDatasetRow( id=idx, data_uri=str(fno), - label=label, data_format=self.data_format_type, object_store_type=ObjectStoreType.LOCAL, data_offset=_bin_section.raw_data_offset, @@ -271,12 +230,12 @@ def make_swds(self) -> DatasetSummary: _swds_bin_offset=_bin_section.offset, _swds_bin_size=_bin_section.size, data_origin=DataOriginType.NEW, - data_mime_type=_data_mime_type or self.default_data_mime_type, + data_type=_artifact.astype(), + annotations=row_annotations, ) ) total_data_size += _bin_section.size - total_label_size += sys.getsizeof(label) wrote_size += _bin_section.size if wrote_size > self.volume_bytes_size: @@ -303,10 +262,10 @@ def make_swds(self) -> DatasetSummary: summary = DatasetSummary( rows=increased_rows, increased_rows=increased_rows, - label_byte_size=total_label_size, data_byte_size=total_data_size, include_user_raw=False, include_link=False, + annotations=list(dataset_annotations.keys()), ) return self._merge_forked_summary(summary) @@ -335,18 +294,12 @@ def _copy_files( _dest_path.symlink_to(_obj_path) + # TODO: tune performance scan after put in a second for row in self.tabular_dataset.scan(*row_pos): self.tabular_dataset.update( row_id=row.id, data_uri=map_fno_sign[int(row.data_uri)] ) - def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]: - with Path(path).open() as f: - yield f.read() - - def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, None]: - yield Path(path).name - BuildExecutor = SWDSBinBuildExecutor @@ -354,20 +307,40 @@ def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, None]: class UserRawBuildExecutor(BaseBuildExecutor): def make_swds(self) -> DatasetSummary: increased_rows = 0 - total_label_size, total_data_size = 0, 0 + total_data_size = 0 auth_candidates = {} include_link = False map_path_sign: t.Dict[str, t.Tuple[str, Path]] = {} + dataset_annotations: t.Dict[str, t.Any] = {} - for idx, (data, (_, 
label)) in enumerate( - zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()), + for idx, (row_data, row_annotations) in enumerate( + self.iter_item(), start=self._forked_last_idx + 1, ): - if isinstance(data, Link): - _remote_link = data + if not isinstance(row_annotations, dict): + raise FormatError(f"annotations({row_annotations}) must be dict type") + + if not dataset_annotations: + # TODO: check annotations type and name + dataset_annotations = row_annotations + + if not isinstance(row_data, Link): + raise FormatError(f"data({row_data}) must be Link type") + + if row_data.with_local_fs_data: + _local_link = row_data + _data_fpath = _local_link.uri + if _data_fpath not in map_path_sign: + map_path_sign[_data_fpath] = DatasetStorage.save_data_file( + _data_fpath + ) + data_uri, _ = map_path_sign[_data_fpath] + auth = "" + object_store_type = ObjectStoreType.LOCAL + else: + _remote_link = row_data data_uri = _remote_link.uri - data_offset, data_size = _remote_link.offset, _remote_link.size if _remote_link.auth: auth = _remote_link.auth.name auth_candidates[ @@ -377,42 +350,23 @@ def make_swds(self) -> DatasetSummary: auth = "" object_store_type = ObjectStoreType.REMOTE include_link = True - data_mime_type = _remote_link.mime_type - elif isinstance(data, (tuple, list)): - _data_fpath, _local_link = data - if _data_fpath not in map_path_sign: - map_path_sign[_data_fpath] = DatasetStorage.save_data_file( - _data_fpath - ) - - if not isinstance(_local_link, Link): - raise NoSupportError("data only support Link type") - - data_mime_type = _local_link.mime_type - data_offset, data_size = _local_link.offset, _local_link.size - data_uri, _ = map_path_sign[_data_fpath] - auth = "" - object_store_type = ObjectStoreType.LOCAL - else: - raise FormatError(f"data({data}) type error, no list, tuple or Link") self.tabular_dataset.put( TabularDatasetRow( id=idx, data_uri=data_uri, - label=label, data_format=self.data_format_type, object_store_type=object_store_type, - data_offset=data_offset, - data_size=data_size, + data_offset=row_data.offset, + data_size=row_data.size, data_origin=DataOriginType.NEW, auth_name=auth, - data_mime_type=data_mime_type, + data_type=row_data.astype(), + annotations=row_annotations, ) ) - total_data_size += data_size - total_label_size += sys.getsizeof(label) + total_data_size += row_data.size increased_rows += 1 self._copy_files(map_path_sign) @@ -422,10 +376,10 @@ def make_swds(self) -> DatasetSummary: summary = DatasetSummary( rows=increased_rows, increased_rows=increased_rows, - label_byte_size=total_label_size, data_byte_size=total_data_size, include_link=include_link, include_user_raw=True, + annotations=list(dataset_annotations.keys()), ) return self._merge_forked_summary(summary) @@ -444,12 +398,6 @@ def _copy_auth(self, auth_candidates: t.Dict[str, LinkAuth]) -> None: for auth in auth_candidates.values(): f.write("\n".join(auth.dump_env())) - def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]: - yield 0, Path(path).stat().st_size - - def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, None]: - yield Path(path).name - @property def data_format_type(self) -> DataFormatType: return DataFormatType.USER_RAW diff --git a/client/starwhale/api/_impl/dataset/loader.py b/client/starwhale/api/_impl/dataset/loader.py index f50b8c5063..5491dae898 100644 --- a/client/starwhale/api/_impl/dataset/loader.py +++ b/client/starwhale/api/_impl/dataset/loader.py @@ -13,7 +13,7 @@ from starwhale.consts import AUTH_ENV_FNAME from 
starwhale.base.uri import URI from starwhale.base.type import InstanceType, DataFormatType, ObjectStoreType -from starwhale.core.dataset.type import DataField +from starwhale.core.dataset.type import BaseArtifact from starwhale.core.dataset.store import FileLikeObj, ObjectStore, DatasetStorage from starwhale.core.dataset.tabular import TabularDataset, TabularDatasetRow @@ -80,31 +80,18 @@ def _get_key_compose(self, row: TabularDatasetRow, store: ObjectStore) -> str: _key_compose = f"{data_uri}:{offset}:{offset + size - 1}" return _key_compose - def __iter__(self) -> t.Generator[t.Tuple[DataField, DataField], None, None]: - _attr = { - "ds_name": self.tabular_dataset.name, - "ds_version": self.tabular_dataset.version, - } + def __iter__(self) -> t.Generator[t.Tuple[int, t.Any, t.Dict], None, None]: for row in self.tabular_dataset.scan(): # TODO: tune performance by fetch in batch - # TODO: remove ext_attr field _store = self._get_store(row) _key_compose = self._get_key_compose(row, _store) - self.logger.info(f"@{_store.bucket}/{_key_compose}") + self.logger.info(f"[{row.id}] @{_store.bucket}/{_key_compose}") _file = _store.backend._make_file(_store.bucket, _key_compose) - for data_content, data_size in self._do_iter(_file, row): - label = DataField( - idx=row.id, - data_size=sys.getsizeof(row.label), - data=row.label, - ext_attr=_attr, - ) - data = DataField( - idx=row.id, data_size=data_size, data=data_content, ext_attr=_attr - ) - - yield data, label + for data_content, _ in self._do_iter(_file, row): + data = BaseArtifact.reflect(data_content, row.data_type) + # TODO: refactor annotation origin type + yield row.id, data, row.annotations @abstractmethod def _do_iter( diff --git a/client/starwhale/api/_impl/dataset/mnist.py b/client/starwhale/api/_impl/dataset/mnist.py deleted file mode 100644 index 00953b6d55..0000000000 --- a/client/starwhale/api/_impl/dataset/mnist.py +++ /dev/null @@ -1,33 +0,0 @@ -import struct -import typing as t -from pathlib import Path - -from .builder import SWDSBinBuildExecutor - - -class MNISTBuildExecutor(SWDSBinBuildExecutor): - def iter_data_slice(self, path: str) -> t.Generator[bytes, None, None]: - fpath = Path(path) - - with fpath.open("rb") as f: - _, number, height, width = struct.unpack(">IIII", f.read(16)) - print(f">data({fpath.name}) split {number} group") - - while True: - content = f.read(height * width) - if not content: - break - yield content - - def iter_label_slice(self, path: str) -> t.Generator[bytes, None, None]: - fpath = Path(path) - - with fpath.open("rb") as f: - _, number = struct.unpack(">II", f.read(8)) - print(f">label({fpath.name}) split {number} group") - - while True: - content = f.read(1) - if not content: - break - yield struct.unpack(">B", content)[0] diff --git a/client/starwhale/api/_impl/metric.py b/client/starwhale/api/_impl/metric.py index 110c06d173..17c41a07fc 100644 --- a/client/starwhale/api/_impl/metric.py +++ b/client/starwhale/api/_impl/metric.py @@ -12,6 +12,10 @@ classification_report, ) +from starwhale.utils.flatten import do_flatten_dict + +from .model import PipelineHandler + class MetricKind: MultiClassification = "multi_classification" @@ -28,6 +32,7 @@ def _decorator(func: t.Any) -> t.Any: @wraps(func) def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]: y_pr: t.Any = None + handler: PipelineHandler = args[0] _rt = func(*args, **kwargs) if show_roc_auc: @@ -41,16 +46,33 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]: ) _summary_m = ["accuracy", "macro avg", "weighted 
avg"] _r["summary"] = {k: cr.get(k) for k in _summary_m} - _r["labels"] = {k: v for k, v in cr.items() if k not in _summary_m} - # TODO: tune performace, use intermediated result + _record_summary = do_flatten_dict(_r["summary"]) + _record_summary["kind"] = _r["kind"] + handler.evaluation.log_metrics(_record_summary) + + _r["labels"] = {} + for k, v in cr.items(): + if k in _summary_m: + continue + _r["labels"][k] = v + handler.evaluation.log("labels", id=k, **v) + + # TODO: tune performance, use intermediated result cm = confusion_matrix( y_true, y_pred, labels=all_labels, normalize=confusion_matrix_normalize ) - _r["confusion_matrix"] = { - "binarylabel": cm.tolist(), - } + _cm_list = cm.tolist() + _r["confusion_matrix"] = {"binarylabel": _cm_list} + + for idx, _pa in enumerate(_cm_list): + handler.evaluation.log( + "confusion_matrix/binarylabel", + id=idx, + **{str(_id): _v for _id, _v in enumerate(_pa)}, + ) + if show_hamming_loss: _r["summary"]["hamming_loss"] = hamming_loss(y_true, y_pred) if show_cohen_kappa_score: @@ -59,9 +81,22 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]: if show_roc_auc and all_labels is not None and y_true and y_pr: _r["roc_auc"] = {} for _idx, _label in enumerate(all_labels): - _r["roc_auc"][_label] = _calculate_roc_auc( - y_true, y_pr, _label, _idx - ) + _ra_value = _calculate_roc_auc(y_true, y_pr, _label, _idx) + _r["roc_auc"][_label] = _ra_value + + for _fpr, _tpr, _threshold in zip( + _ra_value["fpr"], _ra_value["tpr"], _ra_value["thresholds"] + ): + handler.evaluation.log( + f"roc_auc/{_label}", + id=_idx, + fpr=_fpr, + tpr=_tpr, + threshold=_threshold, + ) + handler.evaluation.log( + "roc_auc/summary", id=_label, auc=_ra_value["auc"] + ) return _r return _wrapper diff --git a/client/starwhale/api/_impl/model.py b/client/starwhale/api/_impl/model.py index 73ac5f0e27..ffbc399754 100644 --- a/client/starwhale/api/_impl/model.py +++ b/client/starwhale/api/_impl/model.py @@ -3,7 +3,6 @@ import io import os import sys -import json import math import base64 import typing as t @@ -26,10 +25,9 @@ from starwhale.consts.env import SWEnv from starwhale.utils.error import FieldTypeOrValueError from starwhale.api._impl.job import Context -from starwhale.utils.flatten import do_flatten_dict from starwhale.core.job.model import STATUS from starwhale.core.eval.store import EvaluationStorage -from starwhale.api._impl.dataset import DataField, get_data_loader +from starwhale.api._impl.dataset import get_data_loader from starwhale.api._impl.wrapper import Evaluation from starwhale.core.dataset.model import Dataset @@ -55,36 +53,36 @@ def calculate_index( return _start_index, _end_index -class ResultLoader: +class PPLResultIterator: def __init__( self, - data: t.List[t.Any], + data: t.Iterator[t.Dict[str, t.Any]], deserializer: t.Optional[t.Callable] = None, ) -> None: self.data = data self.deserializer = deserializer - def __iter__(self) -> t.Any: - for _data in self.data: + def __iter__(self) -> t.Iterator[t.Dict[str, t.Any]]: + # TODO: use class to refactor data + for d in self.data: if self.deserializer: - yield self.deserializer(_data) - continue - yield _data + yield self.deserializer(d) + else: + yield d class PipelineHandler(metaclass=ABCMeta): def __init__( self, context: Context, - merge_label: bool = True, + ignore_annotations: bool = False, ignore_error: bool = False, ) -> None: self.context = context self._init_dir() # TODO: add args for compare result and label directly - self.merge_label = merge_label - + self.ignore_annotations = 
ignore_annotations self.ignore_error = ignore_error self.logger, self._sw_logger = self._init_logger() @@ -95,10 +93,7 @@ def __init__( # TODO: split status/result files self._timeline_writer = _jl_writer(self.status_dir / "timeline") - self._ppl_data_field = "result" - self._label_field = "label" self.evaluation = self._init_datastore() - self._monkey_patch() def _init_dir(self) -> None: @@ -133,8 +128,8 @@ def _init_logger(self) -> t.Tuple[loguru.Logger, loguru.Logger]: ) _logger.bind( type=_LogType.USER, - task_id=self.context.index, # os.environ.get("SW_TASK_ID", ""), - job_id=self.context.version, # os.environ.get("SW_JOB_ID", ""), + task_id=self.context.index, + job_id=self.context.version, ) _sw_logger = _logger.bind(type=_LogType.SW) return _logger, _sw_logger @@ -155,7 +150,7 @@ def _monkey_patch(self) -> None: def __str__(self) -> str: return f"PipelineHandler status@{self.status_dir}, " f"log@{self.log_dir}" - def __enter__(self) -> PipelineHandler: + def __enter__(self) -> "PipelineHandler": return self def __exit__( @@ -178,46 +173,34 @@ def __exit__( # self._sw_logger.remove() @abstractmethod - def ppl(self, data: bytes, **kw: t.Any) -> t.Any: + def ppl(self, data: t.Any, **kw: t.Any) -> t.Any: # TODO: how to handle each element is not equal. raise NotImplementedError @abstractmethod - def cmp(self, _data_loader: ResultLoader) -> t.Any: + def cmp(self, ppl_result: PPLResultIterator) -> t.Any: raise NotImplementedError def _builtin_serialize(self, *data: t.Any) -> bytes: return dill.dumps(data) # type: ignore - def ppl_data_serialize(self, *data: t.Any) -> bytes: + def ppl_result_serialize(self, *data: t.Any) -> bytes: return self._builtin_serialize(*data) - def ppl_data_deserialize(self, data: bytes) -> t.Any: + def ppl_result_deserialize(self, data: bytes) -> t.Any: return dill.loads(base64.b64decode(data)) - def label_data_serialize(self, data: t.Any) -> bytes: + def annotations_serialize(self, data: t.Any) -> bytes: return self._builtin_serialize(data) - def label_data_deserialize(self, data: bytes) -> bytes: + def annotations_deserialize(self, data: bytes) -> bytes: return dill.loads(base64.b64decode(data))[0] # type: ignore - # todoļ¼š waiting remove it - def deserialize(self, data: t.Union[str, bytes]) -> t.Any: - ret = json.loads(data) - ret[self._ppl_data_field] = self.ppl_data_deserialize(ret[self._ppl_data_field]) - ret[self._label_field] = self.label_data_deserialize(ret[self._label_field]) - return ret - - def deserialize_fields(self, data: t.Dict[str, t.Any]) -> t.Any: - data[self._ppl_data_field] = self.ppl_data_deserialize( - data[self._ppl_data_field] - ) - data[self._label_field] = self.label_data_deserialize(data[self._label_field]) + def deserialize(self, data: t.Dict[str, t.Any]) -> t.Any: + data["result"] = self.ppl_result_deserialize(data["result"]) + data["annotations"] = self.annotations_deserialize(data["annotations"]) return data - def handle_label(self, label: t.Any, **kw: t.Any) -> t.Any: - return label - def _record_status(func): # type: ignore @wraps(func) # type: ignore def _wrapper(*args: t.Any, **kwargs: t.Any) -> None: @@ -238,65 +221,23 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> None: @_record_status # type: ignore def _starwhale_internal_run_cmp(self) -> None: - self._sw_logger.debug("enter cmp func...") self._update_status(STATUS.START) now = now_str() try: - _ppl_results = list(self.evaluation.get_results()) - self._sw_logger.debug("cmp input data size:{}", len(_ppl_results)) - _data_loader = ResultLoader( - data=_ppl_results, 
deserializer=self.deserialize_fields + _iter = PPLResultIterator( + data=self.evaluation.get_results(), deserializer=self.deserialize ) - output = self.cmp(_data_loader) + output = self.cmp(_iter) except Exception as e: self._sw_logger.exception(f"cmp exception: {e}") - self._timeline_writer.write({"time": now, "status": False, "exception": e}) + self._timeline_writer.write( + {"time": now, "status": False, "exception": str(e)} + ) raise else: self._timeline_writer.write({"time": now, "status": True, "exception": ""}) self._sw_logger.debug(f"cmp result:{output}") - if not output: - self._sw_logger.warning("cmp results is None!") - return - if isinstance(output, dict): - if "summary" in output: - self.evaluation.log_metrics(do_flatten_dict(output["summary"])) - self.evaluation.log_metrics({"kind": output["kind"]}) - - if "labels" in output: - for i, label in output["labels"].items(): - self.evaluation.log("labels", id=i, **label) - - if ( - "confusion_matrix" in output - and "binarylabel" in output["confusion_matrix"] - ): - _binary_label = output["confusion_matrix"]["binarylabel"] - for _label, _probability in enumerate(_binary_label): - self.evaluation.log( - "confusion_matrix/binarylabel", - id=str(_label), - **{str(k): v for k, v in enumerate(_probability)}, - ) - if "roc_auc" in output: - for _label, _roc_auc in output["roc_auc"].items(): - _id = 0 - for _fpr, _tpr, _threshold in zip( - _roc_auc["fpr"], _roc_auc["tpr"], _roc_auc["thresholds"] - ): - self.evaluation.log( - f"roc_auc/{_label}", - id=str(_id), - fpr=_fpr, - tpr=_tpr, - threshold=_threshold, - ) - _id += 1 - self.evaluation.log( - "roc_auc/summary", id=_label, auc=_roc_auc["auc"] - ) - @_record_status # type: ignore def _starwhale_internal_run_ppl(self) -> None: self._update_status(STATUS.START) @@ -316,86 +257,47 @@ def _starwhale_internal_run_ppl(self) -> None: _data_loader = get_data_loader( dataset_uri=_dataset_uri, start=dataset_row_start, - end=dataset_row_end, + end=dataset_row_end + 1, logger=self._sw_logger, ) - for data, label in _data_loader: - if data.idx != label.idx: - msg = ( - f"data index[{data.idx}] is not equal label index [{label.idx}], " - f"{'ignore error' if self.ignore_error else ''}" - ) - self._sw_logger.error(msg) - if not self.ignore_error: - raise Exception(msg) - + for _idx, _data, _annotations in _data_loader: pred: t.Any = b"" exception = None try: # TODO: inspect profiling - pred = self.ppl( - data.data.encode() if isinstance(data.data, str) else data.data, - data_index=data.idx, - data_size=data.data_size, - label_content=label.data, - label_size=label.data_size, - label_index=label.idx, - ds_name=data.ext_attr.get("ds_name", ""), - ds_version=data.ext_attr.get("ds_version", ""), - ) + pred = self.ppl(_data, annotations=_annotations, index=_idx) except Exception as e: exception = e - self._sw_logger.exception(f"[{data.idx}] data handle -> failed") + self._sw_logger.exception(f"[{_idx}] data handle -> failed") if not self.ignore_error: self._update_status(STATUS.FAILED) raise else: exception = None - self._do_record(data, label, exception, *pred) + self._do_record(_idx, _annotations, exception, *pred) def _do_record( self, - data: DataField, - label: DataField, - exception: t.Union[None, Exception], + idx: int, + annotations: t.Dict, + exception: t.Optional[Exception], *args: t.Any, ) -> None: _timeline = { "time": now_str(), "status": exception is None, "exception": str(exception), - "index": data.idx, - self._ppl_data_field: base64.b64encode( - self.ppl_data_serialize(*args) - 
).decode("ascii"), + "index": idx, } self._timeline_writer.write(_timeline) - _label = "" - if self.merge_label: - try: - label = self.handle_label( - label.data, - index=label.idx, - size=label.data_size, - ) - _label = base64.b64encode(self.label_data_serialize(label)).decode( - "ascii" - ) - except Exception as e: - self._sw_logger.exception(f"{label.data!r} label handle exception:{e}") - if not self.ignore_error: - self._update_status(STATUS.FAILED) - raise - else: - _label = "" - # self._sw_logger.debug(f"record ppl result:{data.idx}") + annotations = {} if self.ignore_annotations else annotations + _b64: t.Callable[[bytes], str] = lambda x: base64.b64encode(x).decode("ascii") self.evaluation.log_result( - data_id=str(data.idx), - result=base64.b64encode(self.ppl_data_serialize(*args)).decode("ascii"), - data_size=data.data_size, - label=_label, + data_id=idx, + result=_b64(self.ppl_result_serialize(*args)), + annotations=_b64(self.annotations_serialize(annotations)), ) self._update_status(STATUS.RUNNING) diff --git a/client/starwhale/api/_impl/wrapper.py b/client/starwhale/api/_impl/wrapper.py index 3c0d0c48fa..285fa362c9 100644 --- a/client/starwhale/api/_impl/wrapper.py +++ b/client/starwhale/api/_impl/wrapper.py @@ -68,7 +68,7 @@ def __init__(self, eval_id: Optional[str] = None): def _get_datastore_table_name(self, table_name: str) -> str: return f"project/{self.project}/eval/{self.eval_id[:VERSION_PREFIX_CNT]}/{self.eval_id}/{table_name}" - def log_result(self, data_id: str, result: Any, **kwargs: Any) -> None: + def log_result(self, data_id: Union[int, str], result: Any, **kwargs: Any) -> None: record = {"id": data_id, "result": result} for k, v in kwargs.items(): record[k.lower()] = v diff --git a/client/starwhale/api/dataset.py b/client/starwhale/api/dataset.py index 07b095562c..6b7e242b31 100644 --- a/client/starwhale/api/dataset.py +++ b/client/starwhale/api/dataset.py @@ -1,15 +1,22 @@ from ._impl.dataset import ( Link, + Text, + Audio, + Image, + Binary, LinkType, MIMEType, + ClassLabel, S3LinkAuth, + BoundingBox, BuildExecutor, + GrayscaleImage, get_data_loader, LocalFSLinkAuth, DefaultS3LinkAuth, SWDSBinDataLoader, UserRawDataLoader, - MNISTBuildExecutor, + COCOObjectAnnotation, SWDSBinBuildExecutor, UserRawBuildExecutor, ) @@ -28,7 +35,14 @@ "BuildExecutor", # SWDSBinBuildExecutor alias "UserRawBuildExecutor", "SWDSBinBuildExecutor", - "MNISTBuildExecutor", "SWDSBinDataLoader", "UserRawDataLoader", + "Binary", + "Text", + "Audio", + "Image", + "ClassLabel", + "BoundingBox", + "GrayscaleImage", + "COCOObjectAnnotation", ] diff --git a/client/starwhale/base/mixin.py b/client/starwhale/base/mixin.py index 8445849ada..696775e8ff 100644 --- a/client/starwhale/base/mixin.py +++ b/client/starwhale/base/mixin.py @@ -1,3 +1,4 @@ +import json import typing as t from copy import deepcopy from enum import Enum @@ -18,6 +19,10 @@ def asdict(self, ignore_keys: t.Optional[t.List[str]] = None) -> t.Dict: else: raise FormatError(f"{self} cannot be formatted as a dict") + def jsonify(self, ignore_keys: t.Optional[t.List[str]] = None) -> str: + r = self.asdict(ignore_keys) + return json.dumps(r, separators=(",", ":")) + def _do_asdict_convert(obj: t.Any) -> t.Any: if isinstance(obj, dict): diff --git a/client/starwhale/core/dataset/model.py b/client/starwhale/core/dataset/model.py index 51b768a568..6c994789a3 100644 --- a/client/starwhale/core/dataset/model.py +++ b/client/starwhale/core/dataset/model.py @@ -366,10 +366,7 @@ def _call_make_swds( dataset_name=self.uri.object.name, 
dataset_version=self._version, project_name=self.uri.project, - data_dir=workdir / swds_config.data_dir, workdir=self.store.snapshot_workdir, - data_filter=swds_config.data_filter, - label_filter=swds_config.label_filter, alignment_bytes_size=swds_config.attr.alignment_size, volume_bytes_size=swds_config.attr.volume_size, append=append, @@ -452,12 +449,7 @@ def _copy_src( def _load_dataset_config(self, yaml_path: Path) -> DatasetConfig: self._do_validate_yaml(yaml_path) - _config = DatasetConfig.create_by_yaml(yaml_path) - - if not (yaml_path.parent / _config.data_dir).exists(): - raise FileNotFoundError(f"dataset datadir:{_config.data_dir}") - - return _config + return DatasetConfig.create_by_yaml(yaml_path) class CloudDataset(CloudBundleModelMixin, Dataset): diff --git a/client/starwhale/core/dataset/store.py b/client/starwhale/core/dataset/store.py index 70509cf7d3..dd98a2e149 100644 --- a/client/starwhale/core/dataset/store.py +++ b/client/starwhale/core/dataset/store.py @@ -93,8 +93,10 @@ def dataset_rootdir(self) -> Path: @classmethod def save_data_file( - cls, src: Path, force: bool = False, remove_src: bool = False + cls, src: t.Union[Path, str], force: bool = False, remove_src: bool = False ) -> t.Tuple[str, Path]: + src = Path(src) + if not src.exists(): raise NotFoundError(f"data origin file: {src}") diff --git a/client/starwhale/core/dataset/tabular.py b/client/starwhale/core/dataset/tabular.py index 67f67076c3..7f6d5073f8 100644 --- a/client/starwhale/core/dataset/tabular.py +++ b/client/starwhale/core/dataset/tabular.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import sys import json import typing as t @@ -33,61 +31,89 @@ from starwhale.api._impl.wrapper import Dataset as DatastoreWrapperDataset from starwhale.core.dataset.store import DatasetStorage -from .type import MIMEType - class TabularDatasetRow(ASDictMixin): + + ANNOTATION_PREFIX = "_annotation_" + def __init__( self, id: int, data_uri: str, - label: t.Any, data_format: DataFormatType = DataFormatType.SWDS_BIN, object_store_type: ObjectStoreType = ObjectStoreType.LOCAL, data_offset: int = 0, data_size: int = 0, data_origin: DataOriginType = DataOriginType.NEW, - data_mime_type: MIMEType = MIMEType.UNDEFINED, + data_type: t.Optional[t.Dict[str, t.Any]] = None, auth_name: str = "", + annotations: t.Optional[t.Dict[str, t.Any]] = None, **kw: t.Union[str, int, float], ) -> None: self.id = id self.data_uri = data_uri.strip() - self.data_format = DataFormatType(data_format) + self.data_format = data_format self.data_offset = data_offset self.data_size = data_size - self.data_origin = DataOriginType(data_origin) - self.object_store_type = ObjectStoreType(object_store_type) - self.data_mime_type = MIMEType(data_mime_type) + self.data_origin = data_origin + self.object_store_type = object_store_type self.auth_name = auth_name - self.label = self._parse_label(label) + self.data_type = data_type or {} + self.annotations = annotations or {} self.extra_kw = kw - # TODO: add non-starwhale object store related fields, such as address, authority # TODO: add data uri crc for versioning - # TODO: support user custom annotations + + @classmethod + def from_datastore( + cls, + id: int, + data_uri: str, + data_format: str = DataFormatType.SWDS_BIN.value, + object_store_type: str = ObjectStoreType.LOCAL.value, + data_offset: int = 0, + data_size: int = 0, + data_origin: str = DataOriginType.NEW.value, + data_type: str = "", + auth_name: str = "", + **kw: t.Any, + ) -> "TabularDatasetRow": + annotations = {} + extra_kw = {} + 
for k, v in kw.items(): + if not k.startswith(cls.ANNOTATION_PREFIX): + extra_kw[k] = v + continue + _, name = k.split(cls.ANNOTATION_PREFIX, 1) + annotations[name] = json.loads(v) + + return cls( + id=id, + data_uri=data_uri, + data_format=DataFormatType(data_format), + object_store_type=ObjectStoreType(object_store_type), + data_offset=data_offset, + data_size=data_size, + data_origin=DataOriginType(data_origin), + auth_name=auth_name, + # TODO: use protobuf format to store and reflect annotation + data_type=json.loads(data_type), + annotations=annotations, + **extra_kw, + ) def __eq__(self, o: object) -> bool: return self.__dict__ == o.__dict__ - def _parse_label(self, label: t.Any) -> str: - # TODO: add more label type-parse - # TODO: support user-defined label type - if isinstance(label, bytes): - return label.decode() - elif isinstance(label, (int, float, complex)): - return str(label) - elif isinstance(label, str): - return label - else: - raise NoSupportError(f"label type:{type(label)} {label}") - def _do_validate(self) -> None: if self.id < 0: raise FieldTypeOrValueError( f"id need to be greater than or equal to zero, but current id is {self.id}" ) + if not isinstance(self.annotations, dict) or not self.annotations: + raise FieldTypeOrValueError("no annotations field") + if not self.data_uri: raise FieldTypeOrValueError("no raw_data_uri field") @@ -106,13 +132,26 @@ def __str__(self) -> str: def __repr__(self) -> str: return ( f"row-{self.id}, data-{self.data_uri}(offset:{self.data_offset}, size:{self.data_size}," - f"format:{self.data_format}, mime type:{self.data_mime_type}), " + f"format:{self.data_format}, meta type:{self.data_type}), " f"origin-[{self.data_origin}], object store-{self.object_store_type}" ) def asdict(self, ignore_keys: t.Optional[t.List[str]] = None) -> t.Dict: - d = super().asdict(ignore_keys=ignore_keys or ["extra_kw"]) + d = super().asdict( + ignore_keys=ignore_keys or ["annotations", "extra_kw", "data_type"] + ) d.update(_do_asdict_convert(self.extra_kw)) + # TODO: use protobuf format to store and reflect annotation + for k, v in self.annotations.items(): + v = _do_asdict_convert(v) + if getattr(v, "jsonify", None): + v = v.jsonify() + else: + v = json.dumps(v, separators=(",", ":")) + d[f"{self.ANNOTATION_PREFIX}{k}"] = v + d["data_type"] = json.dumps( + _do_asdict_convert(self.data_type), separators=(",", ":") + ) return d @@ -174,7 +213,7 @@ def scan( if k not in _d: continue _d[k] = v(_d[k]) - yield TabularDatasetRow(**_d) + yield TabularDatasetRow.from_datastore(**_d) def close(self) -> None: self._ds_wrapper.close() diff --git a/client/starwhale/core/dataset/type.py b/client/starwhale/core/dataset/type.py index 00d7eef7db..7a1b80bc3c 100644 --- a/client/starwhale/core/dataset/type.py +++ b/client/starwhale/core/dataset/type.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io import os import typing as t from abc import ABCMeta, abstractmethod @@ -8,7 +9,7 @@ from functools import partial from starwhale.utils import load_yaml, convert_to_bytes -from starwhale.consts import DEFAULT_STARWHALE_API_VERSION +from starwhale.consts import SHORT_VERSION_CNT, DEFAULT_STARWHALE_API_VERSION from starwhale.utils.fs import FilePosition from starwhale.base.mixin import ASDictMixin from starwhale.utils.error import NoSupportError, FieldTypeOrValueError @@ -17,13 +18,6 @@ D_ALIGNMENT_SIZE = 4 * 1024 # 4k for page cache -class DataField(t.NamedTuple): - idx: int - data_size: int - data: t.Union[bytes, str] - ext_attr: t.Dict[str, t.Any] - - @unique class 
LinkType(Enum): LocalFS = "local_fs" @@ -157,24 +151,271 @@ def create_by_file_suffix(cls, name: str) -> MIMEType: DefaultS3LinkAuth = S3LinkAuth() -class Link: +_T = t.TypeVar("_T") +_TupleOrList = t.Union[t.Tuple[_T, ...], t.List[_T]] +_TShape = _TupleOrList[t.Optional[int]] +_TArtifactFP = t.Union[str, bytes, Path, io.IOBase] + + +@unique +class ArtifactType(Enum): + Binary = "binary" + Image = "image" + Video = "video" + Audio = "audio" + Text = "text" + + +class BaseArtifact(ASDictMixin, metaclass=ABCMeta): + def __init__( + self, + fp: _TArtifactFP, + type: ArtifactType, + display_name: str = "", + shape: t.Optional[_TShape] = None, + mime_type: t.Optional[MIMEType] = None, + encoding: str = "", + ) -> None: + self.fp = fp + self.type = ArtifactType(type) + + _fpath = str(fp) if isinstance(fp, (Path, str)) and fp else "" + self.display_name = display_name or os.path.basename(_fpath) + self.mime_type = mime_type or MIMEType.create_by_file_suffix(_fpath) + self.shape = shape + self.encoding = encoding + self._do_validate() + + def _do_validate(self) -> None: + ... + + @classmethod + def reflect(cls, raw_data: bytes, data_type: t.Dict[str, t.Any]) -> BaseArtifact: + if not isinstance(raw_data, bytes): + raise NoSupportError(f"raw data type({type(raw_data)}) is not bytes") + + # TODO: support data_type reflect + dtype = data_type.get("type") + mime_type = MIMEType(data_type.get("mime_type", MIMEType.UNDEFINED)) + shape = data_type.get("shape", []) + encoding = data_type.get("encoding", "") + + if dtype == ArtifactType.Text.value: + _encoding = encoding or Text.DEFAULT_ENCODING + return Text(content=raw_data.decode(_encoding), encoding=_encoding) + elif dtype == ArtifactType.Image.value: + return Image(raw_data, mime_type=mime_type, shape=shape) + elif dtype == ArtifactType.Audio.value: + return Audio(raw_data, mime_type=mime_type, shape=shape) + elif not dtype or dtype == ArtifactType.Binary.value: + return Binary(raw_data) + else: + raise NoSupportError(f"Artifact reflect error: {data_type}") + + # TODO: add to_tensor, to_numpy method + def to_bytes(self) -> bytes: + if isinstance(self.fp, bytes): + return self.fp + elif isinstance(self.fp, (str, Path)): + return Path(self.fp).read_bytes() + elif isinstance(self.fp, io.IOBase): + # TODO: strict to binary io? 
+ return self.fp.read() # type: ignore + else: + raise NoSupportError(f"read raw for type:{type(self.fp)}") + + def astype(self) -> t.Dict[str, t.Any]: + return { + "type": self.type, + "mime_type": self.mime_type, + "shape": self.shape, + "encoding": self.encoding, + } + + def asdict(self, ignore_keys: t.Optional[t.List[str]] = None) -> t.Dict[str, t.Any]: + return super().asdict(ignore_keys or ["fp"]) + + def __str__(self) -> str: + return f"{self.type}, display:{self.display_name}, mime_type:{self.mime_type}, shape:{self.shape}, encoding: {self.encoding}" + + __repr__ = __str__ + + +class Binary(BaseArtifact): + def __init__( + self, + fp: _TArtifactFP, + mime_type: MIMEType = MIMEType.UNDEFINED, + ) -> None: + super().__init__(fp, ArtifactType.Binary, "", (1,), mime_type) + + +class Image(BaseArtifact): + def __init__( + self, + fp: _TArtifactFP = "", + display_name: str = "", + shape: t.Optional[_TShape] = None, + mime_type: t.Optional[MIMEType] = None, + ) -> None: + super().__init__( + fp, + ArtifactType.Image, + display_name=display_name, + shape=shape or (None, None, 3), + mime_type=mime_type, + ) + + def _do_validate(self) -> None: + if self.mime_type not in ( + MIMEType.PNG, + MIMEType.JPEG, + MIMEType.WEBP, + MIMEType.SVG, + MIMEType.GIF, + MIMEType.APNG, + MIMEType.GRAYSCALE, + MIMEType.UNDEFINED, + ): + raise NoSupportError(f"Image type: {self.mime_type}") + + +class GrayscaleImage(Image): + def __init__( + self, + fp: _TArtifactFP = "", + display_name: str = "", + shape: t.Optional[_TShape] = None, + ) -> None: + shape = shape or (None, None) + super().__init__( + fp, display_name, (shape[0], shape[1], 1), mime_type=MIMEType.GRAYSCALE + ) + + +# TODO: support Video type + + +class Audio(BaseArtifact): + def __init__( + self, + fp: _TArtifactFP = "", + display_name: str = "", + shape: t.Optional[_TShape] = None, + mime_type: t.Optional[MIMEType] = None, + ) -> None: + shape = shape or (None,) + super().__init__(fp, ArtifactType.Audio, display_name, shape, mime_type) + + def _do_validate(self) -> None: + if self.mime_type not in ( + MIMEType.MP3, + MIMEType.WAV, + MIMEType.UNDEFINED, + ): + raise NoSupportError(f"Audio type: {self.mime_type}") + + +class ClassLabel(ASDictMixin): + def __init__(self, names: t.List[t.Union[int, float, str]]) -> None: + self.type = "class_label" + self.names = names + + @classmethod + def from_num_classes(cls, num: int) -> ClassLabel: + if num < 1: + raise FieldTypeOrValueError(f"num:{num} less than 1") + return cls(list(range(0, num))) + + def __str__(self) -> str: + return f"ClassLabel: {len(self.names)} classes" + + def __repr__(self) -> str: + return f"ClassLabel: {self.names}" + + +# TODO: support other bounding box format +class BoundingBox(ASDictMixin): + def __init__(self, x: float, y: float, width: float, height: float) -> None: + self.type = "bounding_box" + self.x = x + self.y = y + self.width = width + self.height = height + + def to_list(self) -> t.List[float]: + return [self.x, self.y, self.width, self.height] + + def __str__(self) -> str: + return f"BoundingBox: point:({self.x}, {self.y}), width: {self.width}, height: {self.height})" + + __repr__ = __str__ + + +class Text(BaseArtifact): + DEFAULT_ENCODING = "utf-8" + + def __init__(self, content: str, encoding: str = DEFAULT_ENCODING) -> None: + # TODO: add encoding validate + self.content = content + super().__init__( + fp=b"", + type=ArtifactType.Text, + display_name=f"{content[:SHORT_VERSION_CNT]}...", + shape=(1,), + mime_type=MIMEType.PLAIN, + encoding=encoding, + ) + + def 
to_bytes(self) -> bytes: + return self.content.encode(self.encoding) + + +# https://cocodataset.org/#format-data +class COCOObjectAnnotation(ASDictMixin): def __init__( self, - uri: str = "", + id: int, + image_id: int, + category_id: int, + segmentation: t.Union[t.List, t.Dict], + area: float, + bbox: t.Union[BoundingBox, t.List[float]], + iscrowd: int, + ) -> None: + self.type = "coco_object_annotation" + self.id = id + self.image_id = image_id + self.category_id = category_id + self.bbox = bbox.to_list() if isinstance(bbox, BoundingBox) else bbox + self.segmentation = segmentation + self.area = area + self.iscrowd = iscrowd + + def do_validate(self) -> None: + if self.iscrowd not in (0, 1): + raise FieldTypeOrValueError(f"iscrowd({self.iscrowd}) only accepts 0 or 1") + + # TODO: iscrowd=0 -> polygons, iscrowd=1 -> RLE validate + + +class Link(ASDictMixin): + def __init__( + self, + uri: str, auth: t.Optional[LinkAuth] = DefaultS3LinkAuth, offset: int = FilePosition.START, size: int = -1, - mime_type: MIMEType = MIMEType.UNDEFINED, + data_type: t.Optional[BaseArtifact] = None, + with_local_fs_data: bool = False, ) -> None: + self.type = "link" self.uri = uri.strip() self.offset = offset self.size = size self.auth = auth - - if mime_type == MIMEType.UNDEFINED or mime_type not in MIMEType: - self.mime_type = MIMEType.create_by_file_suffix(self.uri) - else: - self.mime_type = mime_type + self.data_type = data_type + self.with_local_fs_data = with_local_fs_data self.do_validate() @@ -185,11 +426,17 @@ def do_validate(self) -> None: if self.size < -1: raise FieldTypeOrValueError(f"size({self.size}) must be non-negative or -1") + def astype(self) -> t.Dict[str, t.Any]: + return { + "type": self.type, + "data_type": self.data_type.astype() if self.data_type else {}, + } + def __str__(self) -> str: return f"Link {self.uri}" def __repr__(self) -> str: - return f"Link uri:{self.uri}, offset:{self.offset}, size:{self.size}, mime type:{self.mime_type}" + return f"Link uri:{self.uri}, offset:{self.offset}, size:{self.size}, data type:{self.data_type}, with localFS data:{self.with_local_fs_data}" class DatasetSummary(ASDictMixin): @@ -197,19 +444,19 @@ def __init__( self, rows: int = 0, increased_rows: int = 0, - label_byte_size: int = 0, data_byte_size: int = 0, include_link: bool = False, include_user_raw: bool = False, + annotations: t.Optional[t.List[str]] = None, **kw: t.Any, ) -> None: self.rows = rows self.increased_rows = increased_rows self.unchanged_rows = rows - increased_rows - self.label_byte_size = label_byte_size self.data_byte_size = data_byte_size self.include_link = include_link self.include_user_raw = include_user_raw + self.annotations = annotations or [] def __str__(self) -> str: return f"Dataset Summary: rows({self.rows}), include user-raw({self.include_user_raw}), include link({self.include_link})" @@ -218,7 +465,7 @@ def __repr__(self) -> str: return ( f"Dataset Summary: rows({self.rows}, increased: {self.increased_rows}), " f"include user-raw({self.include_user_raw}), include link({self.include_link})," - f"size(data:{self.data_byte_size}, label: {self.label_byte_size})" + f"size(data:{self.data_byte_size}, annotations: {self.annotations})" ) @@ -246,10 +493,7 @@ class DatasetConfig(ASDictMixin): def __init__( self, name: str, - data_dir: str, process: str, - data_filter: str = "", - label_filter: str = "", runtime: str = "", pkg_data: t.List[str] = [], exclude_pkg_data: t.List[str] = [], @@ -260,9 +504,6 @@ def __init__( **kw: t.Any, ) -> None: self.name = name - 
self.data_dir = str(data_dir) - self.data_filter = data_filter - self.label_filter = label_filter self.process = process self.tag = tag self.desc = desc @@ -286,8 +527,7 @@ def _validator(self) -> None: def __str__(self) -> str: return f"DataSet Config {self.name}" - def __repr__(self) -> str: - return f"DataSet Config {self.name}, data:{self.data_dir}" + __repr__ = __str__ @classmethod def create_by_yaml(cls, fpath: t.Union[str, Path]) -> DatasetConfig: diff --git a/client/starwhale/core/job/model.py b/client/starwhale/core/job/model.py index 9571549812..5c64559ab1 100644 --- a/client/starwhale/core/job/model.py +++ b/client/starwhale/core/job/model.py @@ -156,12 +156,10 @@ def execute(self) -> TaskResult: self.status = STATUS.RUNNING # instance method if not self.cls_name: - logger.debug("hi, use func") func = get_func_from_module(_module, self.func) # The standard implementation does not return results func(context=self.context) else: - logger.debug("hi, use class") _cls = load_cls(_module, self.cls_name) # need an instance with _cls() as obj: diff --git a/client/starwhale/utils/flatten.py b/client/starwhale/utils/flatten.py index 87cf3ec9c3..1fcc956c0e 100644 --- a/client/starwhale/utils/flatten.py +++ b/client/starwhale/utils/flatten.py @@ -1,4 +1,5 @@ import typing as t +from copy import deepcopy def do_flatten_dict(origin: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: @@ -27,5 +28,5 @@ def _f_list( rt[f"{_prefix}{index}"] = _d index += 1 - _f_dict(origin) + _f_dict(deepcopy(origin)) return rt diff --git a/client/tests/core/test_dataset.py b/client/tests/core/test_dataset.py index 42f9a798d6..ed48ce6165 100644 --- a/client/tests/core/test_dataset.py +++ b/client/tests/core/test_dataset.py @@ -77,7 +77,6 @@ def test_build_workflow(self, m_import: MagicMock, m_copy_fs: MagicMock) -> None assert m_import.call_args[0][1] == "mnist.process:DataSetProcessExecutor" assert m_cls.call_count == 1 - assert m_cls.call_args[1]["data_dir"] == (Path(workdir) / "data").resolve() assert m_cls.call_args[1]["alignment_bytes_size"] == 4096 assert snapshot_workdir.exists() diff --git a/client/tests/core/test_model.py b/client/tests/core/test_model.py index e1c914222a..36d7d731bb 100644 --- a/client/tests/core/test_model.py +++ b/client/tests/core/test_model.py @@ -20,7 +20,7 @@ from starwhale.base.type import URIType, BundleType from starwhale.utils.config import SWCliConfigMixed from starwhale.api._impl.job import Context -from starwhale.api._impl.model import ResultLoader, PipelineHandler +from starwhale.api._impl.model import PipelineHandler, PPLResultIterator from starwhale.core.model.view import ModelTermView from starwhale.core.model.model import StandaloneModel from starwhale.core.instance.view import InstanceTermView @@ -196,7 +196,7 @@ class SimpleHandler(PipelineHandler): def ppl(self, data: bytes, **kw: t.Any) -> t.Any: pass - def cmp(self, _data_loader: ResultLoader) -> t.Any: + def cmp(self, _iter: PPLResultIterator) -> t.Any: pass def some(self): diff --git a/client/tests/sdk/test_dataset.py b/client/tests/sdk/test_dataset.py index 325d57a5f7..9ecc17c0e4 100644 --- a/client/tests/sdk/test_dataset.py +++ b/client/tests/sdk/test_dataset.py @@ -1,15 +1,12 @@ import os import json import struct +import typing as t from pathlib import Path -from starwhale.utils.fs import ensure_dir, blake2b_file -from starwhale.api.dataset import ( - Link, - MIMEType, - MNISTBuildExecutor, - UserRawBuildExecutor, -) +from starwhale.utils.fs import blake2b_file +from starwhale.api.dataset import Link, MIMEType, 
GrayscaleImage, UserRawBuildExecutor +from starwhale.core.dataset.type import ArtifactType from starwhale.core.dataset.store import DatasetStorage from starwhale.core.dataset.tabular import TabularDataset from starwhale.api._impl.dataset.builder import ( @@ -17,35 +14,63 @@ _header_size, _header_magic, _header_struct, + SWDSBinBuildExecutor, ) from .. import ROOT_DIR from .test_base import BaseTestCase -_mnist_dir = f"{ROOT_DIR}/data/dataset/mnist" -_mnist_data = open(f"{_mnist_dir}/data", "rb").read() -_mnist_label = open(f"{_mnist_dir}/label", "rb").read() +_mnist_dir = Path(f"{ROOT_DIR}/data/dataset/mnist") +_mnist_data_path = _mnist_dir / "data" +_mnist_label_path = _mnist_dir / "label" + + +class MNISTBuildExecutor(SWDSBinBuildExecutor): + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + with _mnist_data_path.open("rb") as data_file, _mnist_label_path.open( + "rb" + ) as label_file: + _, data_number, height, width = struct.unpack(">IIII", data_file.read(16)) + _, label_number = struct.unpack(">II", label_file.read(8)) + print( + f">data({data_file.name}) split data:{data_number}, label:{label_number} group" + ) + image_size = height * width + + for i in range(0, min(data_number, label_number)): + _data = data_file.read(image_size) + _label = struct.unpack(">B", label_file.read(1))[0] + yield GrayscaleImage( + _data, + display_name=f"{i}", + shape=(height, width, 1), + ), {"label": _label} class _UserRawMNIST(UserRawBuildExecutor): - def iter_data_slice(self, path: str): - size = 28 * 28 - file_size = Path(path).stat().st_size - offset = 16 - while offset < file_size: - yield Link(offset=offset, size=size, mime_type=MIMEType.GRAYSCALE) - offset += size - - def iter_label_slice(self, path: str): - fpath = Path(path) - - with fpath.open("rb") as f: - f.seek(8) - while True: - content = f.read(1) - if not content: - break - yield struct.unpack(">B", content)[0] + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + with _mnist_data_path.open("rb") as data_file, _mnist_label_path.open( + "rb" + ) as label_file: + _, data_number, height, width = struct.unpack(">IIII", data_file.read(16)) + _, label_number = struct.unpack(">II", label_file.read(8)) + + image_size = height * width + offset = 16 + + for i in range(0, min(data_number, label_number)): + _data = Link( + uri=str(_mnist_data_path.absolute()), + offset=offset, + size=image_size, + data_type=GrayscaleImage( + display_name=f"{i}", shape=(height, width, 1) + ), + with_local_fs_data=True, + ) + _label = struct.unpack(">B", label_file.read(1))[0] + yield _data, {"label": _label} + offset += image_size class TestDatasetBuildExecutor(BaseTestCase): @@ -57,26 +82,14 @@ def setUp(self) -> None: ) self.raw_data = os.path.join(self.local_storage, ".user", "data") self.workdir = os.path.join(self.local_storage, ".user", "workdir") - - self.data_fpath = os.path.join(self.raw_data, "mnist-data-0") - ensure_dir(self.raw_data) - with open(self.data_fpath, "wb") as f: - f.write(_mnist_data) - - self.data_file_sign = blake2b_file(self.data_fpath) - - with open(os.path.join(self.raw_data, "mnist-label-0"), "wb") as f: - f.write(_mnist_label) + self.data_file_sign = blake2b_file(_mnist_data_path) def test_user_raw_workflow(self) -> None: with _UserRawMNIST( dataset_name="mnist", dataset_version="332211", project_name="self", - data_dir=Path(self.raw_data), workdir=Path(self.workdir), - data_filter="mnist-data-*", - label_filter="mnist-data-*", alignment_bytes_size=64, volume_bytes_size=100, ) as e: @@ -111,13 
+124,9 @@ def test_swds_bin_workflow(self) -> None: dataset_name="mnist", dataset_version="112233", project_name="self", - data_dir=Path(self.raw_data), workdir=Path(self.workdir), - data_filter="mnist-data-*", - label_filter="mnist-data-*", alignment_bytes_size=64, volume_bytes_size=100, - data_mime_type=MIMEType.GRAYSCALE, ) as e: assert e.data_tmpdir.exists() summary = e.make_swds() @@ -167,4 +176,6 @@ def test_swds_bin_workflow(self) -> None: assert meta.data_offset == 32 assert meta.extra_kw["_swds_bin_offset"] == 0 assert meta.data_uri in data_files_sign - assert meta.data_mime_type == MIMEType.GRAYSCALE + assert meta.data_type["type"] == ArtifactType.Image.value + assert meta.data_type["mime_type"] == MIMEType.GRAYSCALE.value + assert meta.data_type["shape"] == [28, 28, 1] diff --git a/client/tests/sdk/test_loader.py b/client/tests/sdk/test_loader.py index 8020fd564a..2a3a61c237 100644 --- a/client/tests/sdk/test_loader.py +++ b/client/tests/sdk/test_loader.py @@ -15,7 +15,7 @@ SWDSBinDataLoader, UserRawDataLoader, ) -from starwhale.core.dataset.type import DatasetSummary +from starwhale.core.dataset.type import Image, ArtifactType, DatasetSummary from starwhale.core.dataset.store import ( DatasetStorage, S3StorageBackend, @@ -53,10 +53,13 @@ def test_user_raw_local_store( data_uri=fname, data_offset=16, data_size=784, - label=0, + annotations={"label": 0}, data_origin=DataOriginType.NEW, data_format=DataFormatType.UNDEFINED, - data_mime_type=MIMEType.UNDEFINED, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="", ) ] @@ -72,13 +75,12 @@ def test_user_raw_local_store( rows = list(loader) assert len(rows) == 1 - _data, _label = rows[0] + _idx, _data, _annotations = rows[0] + assert _idx == 0 + assert _annotations["label"] == 0 - assert _label.idx == 0 - assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"} - assert _data.data_size == len(_data.data) - assert len(_data.data) == 28 * 28 - assert isinstance(_data.data, bytes) + assert len(_data.to_bytes()) == 28 * 28 + assert isinstance(_data, Image) assert loader.kind == DataFormatType.USER_RAW assert list(loader._stores.keys()) == ["local."] @@ -136,10 +138,13 @@ def test_user_raw_remote_store( data_uri=f"s3://127.0.0.1:9000@starwhale/project/2/dataset/11/{version}", data_offset=16, data_size=784, - label=0, + annotations={"label": 0}, data_origin=DataOriginType.NEW, data_format=DataFormatType.USER_RAW, - data_mime_type=MIMEType.GRAYSCALE, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="server1", ), TabularDatasetRow( @@ -148,10 +153,13 @@ def test_user_raw_remote_store( data_uri=f"s3://127.0.0.1:19000@starwhale/project/2/dataset/11/{version}", data_offset=16, data_size=784, - label=1, + annotations={"label": 1}, data_origin=DataOriginType.NEW, data_format=DataFormatType.USER_RAW, - data_mime_type=MIMEType.GRAYSCALE, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="server2", ), TabularDatasetRow( @@ -160,10 +168,13 @@ def test_user_raw_remote_store( data_uri=f"s3://starwhale/project/2/dataset/11/{version}", data_offset=16, data_size=784, - label=1, + annotations={"label": 1}, data_origin=DataOriginType.NEW, data_format=DataFormatType.USER_RAW, - data_mime_type=MIMEType.GRAYSCALE, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="server2", ), TabularDatasetRow( @@ -172,10 +183,13 @@ 
def test_user_raw_remote_store( data_uri=f"s3://username:password@127.0.0.1:29000@starwhale/project/2/dataset/11/{version}", data_offset=16, data_size=784, - label=1, + annotations={"label": 1}, data_origin=DataOriginType.NEW, data_format=DataFormatType.USER_RAW, - data_mime_type=MIMEType.GRAYSCALE, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="server3", ), ] @@ -204,12 +218,13 @@ def test_user_raw_remote_store( rows = list(loader) assert len(rows) == 4 - _data, _label = rows[0] - assert _label.idx == 0 - assert _label.data == "0" - assert len(_data.data) == 28 * 28 - assert len(_data.data) == _data.data_size - assert isinstance(_data.data, bytes) + _idx, _data, _annotations = rows[0] + assert _idx == 0 + assert _annotations["label"] == 0 + assert isinstance(_data, Image) + + assert len(_data.to_bytes()) == 28 * 28 + assert isinstance(_data.to_bytes(), bytes) assert len(loader._stores) == 3 assert loader._stores["remote.server1"].backend.kind == SWDSBackendType.S3 assert loader._stores["remote.server1"].bucket == "starwhale" @@ -244,10 +259,13 @@ def test_swds_bin_s3( data_size=784, _swds_bin_offset=0, _swds_bin_size=8160, - label=b"0", + annotations={"label": 0}, data_origin=DataOriginType.NEW, data_format=DataFormatType.SWDS_BIN, - data_mime_type=MIMEType.UNDEFINED, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="", ) ] @@ -280,14 +298,12 @@ def test_swds_bin_s3( rows = list(loader) assert len(rows) == 1 - _data, _label = rows[0] - assert _label.idx == 0 - assert _label.data == "0" + _idx, _data, _annotations = rows[0] + assert _idx == 0 + assert _annotations["label"] == 0 - assert len(_data.data) == _data.data_size - assert _data.data_size == 10 * 28 * 28 - assert _data.ext_attr == {"ds_name": "mnist", "ds_version": version} - assert isinstance(_data.data, bytes) + assert len(_data.to_bytes()) == 10 * 28 * 28 + assert isinstance(_data, Image) assert list(loader._stores.keys()) == ["local."] backend = loader._stores["local."].backend @@ -328,10 +344,13 @@ def test_swds_bin_local_fs(self, m_scan: MagicMock, m_summary: MagicMock) -> Non data_size=784, _swds_bin_offset=0, _swds_bin_size=8160, - label=b"0", + annotations={"label": 0}, data_origin=DataOriginType.NEW, data_format=DataFormatType.SWDS_BIN, - data_mime_type=MIMEType.UNDEFINED, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="", ), TabularDatasetRow( @@ -342,10 +361,13 @@ def test_swds_bin_local_fs(self, m_scan: MagicMock, m_summary: MagicMock) -> Non data_size=784, _swds_bin_offset=0, _swds_bin_size=8160, - label=b"1", + annotations={"label": 1}, data_origin=DataOriginType.NEW, data_format=DataFormatType.SWDS_BIN, - data_mime_type=MIMEType.UNDEFINED, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="", ), ] @@ -358,14 +380,14 @@ def test_swds_bin_local_fs(self, m_scan: MagicMock, m_summary: MagicMock) -> Non rows = list(loader) assert len(rows) == 2 - _data, _label = rows[0] - assert _label.idx == 0 - assert _label.data == "0" + _idx, _data, _annotations = rows[0] + + assert _idx == 0 + assert _annotations["label"] == 0 - assert len(_data.data) == _data.data_size - assert _data.data_size == 10 * 28 * 28 - assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"} - assert isinstance(_data.data, bytes) + assert isinstance(_data, Image) + assert len(_data.to_bytes()) == 7840 + 
assert isinstance(_data.to_bytes(), bytes) assert list(loader._stores.keys()) == ["local."] backend = loader._stores["local."].backend diff --git a/client/tests/sdk/test_model.py b/client/tests/sdk/test_model.py index 55ce8bd8b6..e0677ef6ad 100644 --- a/client/tests/sdk/test_model.py +++ b/client/tests/sdk/test_model.py @@ -21,7 +21,7 @@ from starwhale.api.dataset import get_data_loader, UserRawDataLoader from starwhale.api._impl.job import Context from starwhale.core.eval.store import EvaluationStorage -from starwhale.core.dataset.type import MIMEType, DatasetSummary +from starwhale.core.dataset.type import MIMEType, ArtifactType, DatasetSummary from starwhale.core.dataset.store import DatasetStorage from starwhale.core.dataset.tabular import TabularDatasetRow @@ -34,7 +34,8 @@ def ppl(self, data: bytes, **kw: t.Any) -> t.Any: def cmp(self, _data_loader: t.Any) -> t.Any: for _data in _data_loader: - print(_data) + assert "result" in _data + assert "annotations" in _data return { "summary": {"a": 1}, "kind": "test", @@ -119,13 +120,13 @@ def test_cmp( "result": "gASVaQAAAAAAAABdlEsHYV2UXZQoRz4mBBuTAu5hRz4bF5vyEiX+Rz479hi1FqrRRz5MqGToQCdARz3WYwL267cBRz3TzJIFVM1PRz1u4heY2/90Rz/wAAAAAAAARz3Kj1Gg+FBvRz5s1fMUlZZ8ZWGGlC4=", "data_size": "784", "id": "0", - "label": "gASVBQAAAAAAAABLB4WULg==", + "annotations": "gASVBQAAAAAAAABLB4WULg==", }, { "result": "gASVaQAAAAAAAABdlEsCYV2UXZQoRz7HJD9vpfz2Rz7nuBHd45K7Rz/v/95AI4woRz54jeSOtfKhRz4ydvSYTUVCRz4C6uB7EvDbRz66RdBlHOhyRz4yZGRfv61uRz6WGg/Jbfu6Rz3Qy/2xeB34ZWGGlC4=", "data_size": "784", "id": "1", - "label": "gASVBQAAAAAAAABLAoWULg==", + "annotations": "gASVBQAAAAAAAABLAoWULg==", }, ] @@ -140,8 +141,6 @@ def test_cmp( ) ) as _handler: _handler._starwhale_internal_run_cmp() - m_eval_log_metrics.assert_called() - m_eval_log.assert_called() status_file_path = os.path.join(_status_dir, "current") assert os.path.exists(status_file_path) @@ -182,10 +181,13 @@ def test_ppl( data_size=784, _swds_bin_offset=0, _swds_bin_size=8160, - label=b"0", + annotations={"label": 0}, data_origin=DataOriginType.NEW, data_format=DataFormatType.SWDS_BIN, - data_mime_type=MIMEType.UNDEFINED, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="", ), ] @@ -230,9 +232,6 @@ class Dummy(PipelineHandler): def ppl(self, data: bytes, **kw: t.Any) -> t.Any: return builtin_data, np_data, tensor_data - def handle_label(self, label: bytes, **kw: t.Any) -> t.Any: - return label_data - def cmp(self, _data_loader: t.Any) -> t.Any: data = [i for i in _data_loader] assert len(data) == 1 @@ -241,7 +240,7 @@ def cmp(self, _data_loader: t.Any) -> t.Any: assert np.array_equal(y, np_data) assert torch.equal(z, tensor_data) - assert label_data == data[0]["label"] + assert label_data == data[0]["annotations"]["label"] # mock dataset m_summary.return_value = DatasetSummary( @@ -261,10 +260,13 @@ def cmp(self, _data_loader: t.Any) -> t.Any: data_size=784, _swds_bin_offset=0, _swds_bin_size=8160, - label=b"0", + annotations={"label": label_data}, data_origin=DataOriginType.NEW, data_format=DataFormatType.SWDS_BIN, - data_mime_type=MIMEType.UNDEFINED, + data_type={ + "type": ArtifactType.Image.value, + "mime_type": MIMEType.GRAYSCALE.value, + }, auth_name="", ), ] diff --git a/docs/docs/tutorials/pfp.md b/docs/docs/tutorials/pfp.md index b1efcb9ef3..bc0386f15b 100644 --- a/docs/docs/tutorials/pfp.md +++ b/docs/docs/tutorials/pfp.md @@ -82,81 +82,6 @@ In the training section, we use a dataset called [PennFudanPed](https://www.cis. 
PennFudanPed PennFudanPed.zip ``` -Before version `0.2.x`, Starwhale sliced the dataset into chunks where the batched texts and labels reside. You must tell Starwhale how to yield batches of byte arrays from each dataset file. - -To package images and labels in batch and convert them into byte arrays, we overwrite the `iter_all_dataset_slice` and `iter_all_label_slice` methods of the parent class `BuildExecutor` in the Starwhale SDK. -We package paths of images and labels into `FileBytes` so that it is easier to debug. - -```python -class FileBytes: - def __init__(self, p, byte_array): - self.file_path = p - self.content_bytes = byte_array - - -def _pickle_data(image_file_paths): - all_bytes = [FileBytes(image_f, _image_to_bytes(image_f)) for image_f in - image_file_paths] - return pickle.dumps(all_bytes) - - -def _pickle_label(label_file_paths): - all_bytes = [FileBytes(label_f, _label_to_bytes(label_f)) for label_f in - label_file_paths] - return pickle.dumps(all_bytes) - - -def _label_to_bytes(label_file_path): - img = Image.open(label_file_path) - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format='PNG') - return img_byte_arr.getvalue() - - -def _image_to_bytes(image_file_path): - img = Image.open(image_file_path).convert("RGB") - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format='PNG') - return img_byte_arr.getvalue() - - -class PennFudanPedSlicer(BuildExecutor): - - def iter_data_slice(self, path: str): - pass - - def iter_label_slice(self, path: str): - pass - - def iter_all_dataset_slice(self) -> t.Generator[t.Any, None, None]: - datafiles = [p for p in self.iter_data_files()] - idx = 0 - data_size = len(datafiles) - while True: - last_idx = idx - idx += 1 - if idx > data_size: - break - yield _pickle_data(datafiles[last_idx:idx]) - - def iter_all_label_slice(self) -> t.Generator[t.Any, None, None]: - labelfiles = [p for p in self.iter_label_files()] - idx = 0 - data_size = len(labelfiles) - while True: - last_idx = idx - idx += 1 - if idx > data_size: - break - yield _pickle_label(labelfiles[last_idx:idx]) - - def iter_data_slice(self, path: str): - pass - - def iter_label_slice(self, path: str): - pass -``` - You need to extend the abstract class `BuildExecutor`, so Starwhale can use it. ## Implement the inference method and evaluation metrics computing method diff --git a/docs/docs/tutorials/speech.md b/docs/docs/tutorials/speech.md index 3fd07a8e65..15b01a72ad 100644 --- a/docs/docs/tutorials/speech.md +++ b/docs/docs/tutorials/speech.md @@ -92,7 +92,7 @@ SpeechCommands speech_commands_v0.02.tar.gz Before version `0.2.x`, Starwhale sliced the dataset into chunks where the batched audios and labels reside. You must tell Starwhale how to yield batches of byte arrays from each dataset file. To read all test files in this dataset, we overwrite `load_list` method of the parent class `BuildExecutor` in Starwhale SDK. -To package audios and labels in batches and convert them into byte arrays, we overwrite `iter_all_dataset_slice` and `iter_all_label_slice` methods of the parent class `BuildExecutor` in Starwhale SDK. We package paths of audios into `FileBytes` so that it is easier to debug. +To package audios and labels in batches and convert them into byte arrays, we overwrite `iter_all_data_slice` and `iter_all_label_slice` methods of the parent class `BuildExecutor` in Starwhale SDK. We package paths of audios into `FileBytes` so that it is easier to debug. 
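For orientation while reading the rest of this patch: the slice-based hooks described above are superseded by a single `iter_item` generator that yields `(data, annotations)` pairs. A minimal sketch of a new-style executor follows; the class name and the `data` directory layout are illustrative assumptions, not code shipped by this patch.

```python
import typing as t
from pathlib import Path

from starwhale.api.dataset import Audio, MIMEType, BuildExecutor


class AudioFolderSlicer(BuildExecutor):
    def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
        # Yield one (artifact, annotations) pair per sample instead of
        # maintaining separate data/label slice generators.
        root_dir = Path("data")  # assumed layout: data/<label>/<clip>.wav
        for fpath in sorted(root_dir.rglob("*.wav")):
            data = Audio(
                fpath, display_name=fpath.name, shape=(1,), mime_type=MIMEType.WAV
            )
            yield data, {"label": fpath.parent.name}
```

For user-raw datasets the pattern is the same, except that `iter_item` yields `Link` objects pointing into the original files, as the MNIST `RawDataSetProcessExecutor` later in this patch does.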
```python class FileBytes: @@ -138,7 +138,7 @@ class SpeechCommandsSlicer(BuildExecutor): def iter_label_slice(self, path: str): pass - def iter_all_dataset_slice(self) -> t.Generator[t.Any, None, None]: + def iter_all_data_slicelf) -> t.Generator[t.Any, None, None]: datafiles = [p for p in self.iter_data_files()] idx = 0 data_size = len(datafiles) diff --git a/example/PennFudanPed/code/data_slicer.py b/example/PennFudanPed/code/data_slicer.py index 7408d6e1c0..5c1e2d62f7 100644 --- a/example/PennFudanPed/code/data_slicer.py +++ b/example/PennFudanPed/code/data_slicer.py @@ -1,77 +1,76 @@ -import io -import pickle import typing as t +from pathlib import Path -from PIL import Image +import numpy as np +import torch +from PIL import Image as PILImage +from pycocotools import mask as coco_mask -from starwhale.api.dataset import BuildExecutor - - -class FileBytes: - def __init__(self, p, byte_array): - self.file_path = p - self.content_bytes = byte_array - - -def _pickle_data(image_file_paths): - all_bytes = [ - FileBytes(image_f, _image_to_bytes(image_f)) for image_f in image_file_paths - ] - return pickle.dumps(all_bytes) - - -def _pickle_label(label_file_paths): - all_bytes = [ - FileBytes(label_f, _label_to_bytes(label_f)) for label_f in label_file_paths - ] - return pickle.dumps(all_bytes) - - -def _label_to_bytes(label_file_path): - img = Image.open(label_file_path) - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format="PNG") - return img_byte_arr.getvalue() - - -def _image_to_bytes(image_file_path): - img = Image.open(image_file_path).convert("RGB") - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format="PNG") - return img_byte_arr.getvalue() +from starwhale.api.dataset import ( + Image, + MIMEType, + BoundingBox, + BuildExecutor, + COCOObjectAnnotation, +) class PennFudanPedSlicer(BuildExecutor): - def iter_data_slice(self, path: str): - pass - - def iter_label_slice(self, path: str): - pass - - def iter_all_dataset_slice(self) -> t.Generator[t.Any, None, None]: - datafiles = [p for p in self.iter_data_files()] - idx = 0 - data_size = len(datafiles) - while True: - last_idx = idx - idx += 1 - if idx > data_size: - break - yield _pickle_data(datafiles[last_idx:idx]) - - def iter_all_label_slice(self) -> t.Generator[t.Any, None, None]: - labelfiles = [p for p in self.iter_label_files()] - idx = 0 - data_size = len(labelfiles) - while True: - last_idx = idx - idx += 1 - if idx > data_size: - break - yield _pickle_label(labelfiles[last_idx:idx]) - - def iter_data_slice(self, path: str): - pass - - def iter_label_slice(self, path: str): - pass + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + root_dir = Path(__file__).parent.parent / "data" / "PennFudanPed" + names = [p.stem for p in (root_dir / "PNGImages").iterdir()] + for idx, name in enumerate(names): + data_fpath = root_dir / "PNGImages" / f"{name}.png" + mask_fpath = root_dir / "PedMasks" / f"{name}_mask.png" + height, width = self._get_image_shape(data_fpath) + coco_annotations = self._make_coco_annotations(mask_fpath, idx) + annotations = { + "mask": Image(mask_fpath, display_name=name, mime_type=MIMEType.PNG), + "image": {"id": idx, "height": height, "width": width}, + "object_nums": len(coco_annotations), + "annotations": coco_annotations, + } + data = Image(data_fpath, display_name=name, mime_type=MIMEType.PNG) + yield data, annotations + + def _get_image_shape(self, fpath: Path) -> t.Tuple[int, int]: + with PILImage.open(str(fpath)) as f: + return f.height, f.width + + def 
_make_coco_annotations( + self, mask_fpath: Path, image_id: int + ) -> t.List[COCOObjectAnnotation]: + mask_img = PILImage.open(str(mask_fpath)) + + mask = np.array(mask_img) + object_ids = np.unique(mask)[1:] + binary_mask = mask == object_ids[:, None, None] + objects_num = len(object_ids) + # TODO: tune permute without pytorch + binary_mask_tensor = torch.as_tensor(binary_mask, dtype=torch.uint8) + binary_mask_tensor = ( + binary_mask_tensor.permute(0, 2, 1).contiguous().permute(0, 2, 1) + ) + + coco_annotations = [] + for i in range(0, len(object_ids)): + _pos = np.where(binary_mask[i]) + _xmin, _ymin = np.min(_pos[1]), np.min(_pos[0]) + _xmax, _ymax = np.max(_pos[1]), np.max(_pos[0]) + _bbox = BoundingBox( + x=_xmin, y=_ymin, width=_xmax - _xmin, height=_ymax - _ymin + ) + + coco_annotations.append( + COCOObjectAnnotation( + id=i, + image_id=image_id, + category_id=objects_num, + segmentation=coco_mask.encode(binary_mask_tensor[i].numpy()), # type: ignore + area=_bbox.width * _bbox.height, + bbox=_bbox, + iscrowd=0 if objects_num == 1 else 1, + ) + ) + + return coco_annotations diff --git a/example/PennFudanPed/code/ds.py b/example/PennFudanPed/code/ds.py index 544b12d2fa..5c4791dcce 100644 --- a/example/PennFudanPed/code/ds.py +++ b/example/PennFudanPed/code/ds.py @@ -1,4 +1,5 @@ import os + import numpy as np import torch from PIL import Image @@ -70,7 +71,5 @@ def __getitem__(self, idx): return img, target - - def __len__(self): return len(self.imgs) diff --git a/example/PennFudanPed/code/ppl.py b/example/PennFudanPed/code/ppl.py index 86ada0d5bb..8b171e95c1 100644 --- a/example/PennFudanPed/code/ppl.py +++ b/example/PennFudanPed/code/ppl.py @@ -1,40 +1,24 @@ import io import os import pickle -from pathlib import Path import torch from PIL import Image from torchvision.transforms import functional as F +from starwhale.api.job import Context from starwhale.api.model import PipelineHandler -from . import ds as penn_fudan_ped_ds from . import model as mask_rcnn_model from . 
import coco_eval, coco_utils _ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) -_DTYPE_DICT_OUTPUT = { - "boxes": torch.float32, - "labels": torch.int64, - "scores": torch.float32, - "masks": torch.uint8, -} -_DTYPE_DICT_LABEL = { - "iscrowd": torch.int64, - "image_id": torch.int64, - "area": torch.float32, - "boxes": torch.float32, - "labels": torch.int64, - "scores": torch.float32, - "masks": torch.uint8, -} class MARSKRCNN(PipelineHandler): - def __init__(self, device="cuda") -> None: - super().__init__(merge_label=True, ignore_error=True) - self.device = torch.device(device) + def __init__(self, context: Context) -> None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + super().__init__(context=context) @torch.no_grad() def ppl(self, data, **kw): @@ -54,26 +38,16 @@ def ppl(self, data, **kw): _result.append(output) return _result - def handle_label(self, label, **kw): - files_bytes = pickle.loads(label) - _result = [] - for idx, file_bytes in enumerate(files_bytes): - image = Image.open(io.BytesIO(file_bytes.content_bytes)) - target = penn_fudan_ped_ds.mask_to_coco_target(image, kw["index"] + idx) - _result.append(target) - return _result - - def cmp(self, _data_loader): - _result, _label = [], [] - for _data in _data_loader: - # _label.extend([self.list_dict_to_tensor_dict(l, True) for l in _data[self._label_field]]) - _label.extend([l for l in _data[self._label_field]]) - (result) = _data[self._ppl_data_field] - _result.extend(result) - ds = zip(_result, _label) + def cmp(self, ppl_result): + result, label = [], [] + for _data in ppl_result: + label.append(_data["annotations"]) + (result) = _data["result"] + result.extend(result) + ds = zip(result, label) coco_ds = coco_utils.convert_to_coco_api(ds) coco_evaluator = coco_eval.CocoEvaluator(coco_ds, ["bbox", "segm"]) - for outputs, targets in zip(_result, _label): + for outputs, targets in zip(result, label): res = {targets["image_id"].item(): outputs} coco_evaluator.update(res) @@ -84,12 +58,10 @@ def cmp(self, _data_loader): coco_evaluator.accumulate() coco_evaluator.summarize() - return [ - { - iou_type: coco_eval.stats.tolist() - for iou_type, coco_eval in coco_evaluator.coco_eval.items() - } - ] + return { + iou_type: coco_eval.stats.tolist() + for iou_type, coco_eval in coco_evaluator.coco_eval.items() + } def _pre(self, input: bytes): image = Image.open(io.BytesIO(input)) diff --git a/example/PennFudanPed/dataset.yaml b/example/PennFudanPed/dataset.yaml index 8b7ad1c8be..308449f6f9 100644 --- a/example/PennFudanPed/dataset.yaml +++ b/example/PennFudanPed/dataset.yaml @@ -1,14 +1,10 @@ name: penn_fudan_ped -data_dir: data -data_filter: "PNGImages/*6.png" -label_filter: "PedMasks/*6_mask.png" - process: code.data_slicer:PennFudanPedSlicer desc: PennFudanPed data and label test dataset tag: - - bin + - bin attr: alignment_size: 4k diff --git a/example/cifar10/code/data_slicer.py b/example/cifar10/code/data_slicer.py index 6c9b4c867c..2e5f310919 100644 --- a/example/cifar10/code/data_slicer.py +++ b/example/cifar10/code/data_slicer.py @@ -1,47 +1,16 @@ import pickle +import typing as t from pathlib import Path from starwhale.api.dataset import BuildExecutor -def unpickle(file): - with open(file, "rb") as fo: - content_dict = pickle.load(fo, encoding="bytes") - return content_dict - - class CIFAR10Slicer(BuildExecutor): - def iter_data_slice(self, path: str): - content_dict = unpickle(path) - data_numpy = content_dict.get(b"data") - idx = 0 - data_size = len(data_numpy) - while True: - last_idx = 
idx - idx += 1 - if idx > data_size: - break - yield data_numpy[last_idx:idx].tobytes() - - def iter_label_slice(self, path: str): - content_dict = unpickle(path) - labels_list = content_dict.get(b"labels") - idx = 0 - data_size = len(labels_list) - while True: - last_idx = idx - idx += 1 - if idx > data_size: - break - yield bytes(labels_list[last_idx:idx]) - - -if __name__ == "__main__": - executor = CIFAR10Slicer( - data_dir=Path("../data"), - data_filter="test_batch", - label_filter="test_batch", - alignment_bytes_size=4 * 1024, - volume_bytes_size=4 * 1024 * 1024, - ) - executor.make_swds() + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + root_dir = Path(__file__).parent.parent / "data" + + with (root_dir / "test_batch").open("rb") as f: + content = pickle.load(f, encoding="bytes") + for data, label in zip(content[b"data"], content[b"labels"]): + annotations = {"label": label} + yield data.tobytes(), annotations diff --git a/example/cifar10/code/ppl.py b/example/cifar10/code/ppl.py index b8806f1d94..ca2316f0a8 100644 --- a/example/cifar10/code/ppl.py +++ b/example/cifar10/code/ppl.py @@ -1,4 +1,3 @@ -import os from pathlib import Path import numpy as np @@ -6,6 +5,7 @@ from PIL import Image from torchvision import transforms +from starwhale.api.job import Context from starwhale.api.model import PipelineHandler from starwhale.api.metric import multi_classification @@ -19,9 +19,9 @@ class CIFAR10Inference(PipelineHandler): - def __init__(self, device="cpu") -> None: - super().__init__(merge_label=True, ignore_error=True) - self.device = torch.device(device) + def __init__(self, context: Context) -> None: + super().__init__(context=context) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = self._load_model(self.device) def ppl(self, data, **kw): @@ -29,9 +29,6 @@ def ppl(self, data, **kw): output = self.model(data) return self._post(output) - def handle_label(self, label, **kw): - return [int(l) for l in label] - @multi_classification( confusion_matrix_normalize="all", show_hamming_loss=True, @@ -39,14 +36,14 @@ def handle_label(self, label, **kw): show_roc_auc=True, all_labels=[i for i in range(0, 10)], ) - def cmp(self, _data_loader): - _result, _label, _pr = [], [], [] - for _data in _data_loader: - _label.extend([int(l) for l in _data[self._label_field]]) - (pred, pr) = _data[self._ppl_data_field] - _result.extend([int(l) for l in pred]) - _pr.extend([l for l in pr]) - return _label, _result, _pr + def cmp(self, ppl_result): + result, label, pr = [], [], [] + for _data in ppl_result: + label.append(_data["annotations"]["label"]) + (pred, pr) = _data["result"] + result.extend(pred) + pr.extend(pr) + return label, result, pr def _pre(self, input: bytes): batch_size = 1 diff --git a/example/cifar10/dataset.yaml b/example/cifar10/dataset.yaml index fc85a10001..4bc9c939f6 100644 --- a/example/cifar10/dataset.yaml +++ b/example/cifar10/dataset.yaml @@ -1,15 +1,11 @@ name: cifar10 -data_dir: data -data_filter: "test_batch" -label_filter: "test_batch" - process: code.data_slicer:CIFAR10Slicer pip_req: requirements.txt desc: CIFAR10 data and label test dataset tag: - - bin + - bin attr: alignment_size: 4k diff --git a/example/mnist/dataset.yaml b/example/mnist/dataset.yaml index 046fb9c337..861cf86a16 100644 --- a/example/mnist/dataset.yaml +++ b/example/mnist/dataset.yaml @@ -3,8 +3,8 @@ name: mnist data_dir: data data_filter: "t10k-image*" label_filter: "t10k-label*" -#process: mnist.process:DataSetProcessExecutor -process: 
mnist.process:RawDataSetProcessExecutor +process: mnist.process:DataSetProcessExecutor +#process: mnist.process:RawDataSetProcessExecutor #process: mnist.process:LinkRawDataSetProcessExecutor desc: MNIST data and label test dataset diff --git a/example/mnist/mnist/ppl.py b/example/mnist/mnist/ppl.py index 4ad8ccd3ff..0f6c4f499b 100644 --- a/example/mnist/mnist/ppl.py +++ b/example/mnist/mnist/ppl.py @@ -1,14 +1,14 @@ -import os from pathlib import Path import numpy as np import torch -from PIL import Image +from PIL import Image as PILImage from torchvision import transforms from starwhale.api.job import Context from starwhale.api.model import PipelineHandler from starwhale.api.metric import multi_classification +from starwhale.api.dataset import Image try: from .model import Net @@ -16,24 +16,19 @@ from model import Net ROOTDIR = Path(__file__).parent.parent -IMAGE_WIDTH = 28 -ONE_IMAGE_SIZE = IMAGE_WIDTH * IMAGE_WIDTH class MNISTInference(PipelineHandler): - def __init__(self, context: Context, device="cpu") -> None: - super().__init__(context=context, merge_label=True, ignore_error=True) - self.device = torch.device(device) + def __init__(self, context: Context) -> None: + super().__init__(context=context) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = self._load_model(self.device) - def ppl(self, data, **kw): - data = self._pre(data) - output = self.model(data) + def ppl(self, img: Image, **kw): + data_tensor = self._pre(img) + output = self.model(data_tensor) return self._post(output) - def handle_label(self, label, **kw): - return int(label) - @multi_classification( confusion_matrix_normalize="all", show_hamming_loss=True, @@ -41,21 +36,19 @@ def handle_label(self, label, **kw): show_roc_auc=True, all_labels=[i for i in range(0, 10)], ) - def cmp(self, _data_loader): - _result, _label, _pr = [], [], [] - for _data in _data_loader: - _label.append(_data[self._label_field]) - # unpack data according to the return value of function ppl - (pred, pr) = _data[self._ppl_data_field] - _result.extend(pred) - _pr.extend(pr) - return _label, _result, _pr + def cmp(self, ppl_result): + result, label, pr = [], [], [] + for _data in ppl_result: + label.append(_data["annotations"]["label"]) + result.extend(_data["result"][0]) + pr.extend(_data["result"][1]) + return label, result, pr - def _pre(self, input: bytes): - _tensor = torch.tensor(bytearray(input), dtype=torch.uint8).reshape( - IMAGE_WIDTH, IMAGE_WIDTH + def _pre(self, input: Image): + _tensor = torch.tensor(bytearray(input.to_bytes()), dtype=torch.uint8).reshape( + input.shape[0], input.shape[1] # type: ignore ) - _image_array = Image.fromarray(_tensor.numpy()) + _image_array = PILImage.fromarray(_tensor.numpy()) _image = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] )(_image_array) @@ -79,13 +72,9 @@ def _load_model(self, device): context = Context( workdir=Path("."), - dataset_uris=["mnist/version/latest"], + dataset_uris=["mnist/version/small"], project="self", version="latest", - kw={ - "status_dir": "/tmp/mnist/status", - "log_dir": "/tmp/mnist/log", - }, ) mnist = MNISTInference(context) mnist._starwhale_internal_run_ppl() diff --git a/example/mnist/mnist/process.py b/example/mnist/mnist/process.py index fdf7e83128..e4f2df4a80 100644 --- a/example/mnist/mnist/process.py +++ b/example/mnist/mnist/process.py @@ -1,82 +1,91 @@ import struct +import typing as t from pathlib import Path from starwhale.api.dataset import ( Link, - MIMEType, S3LinkAuth, + 
GrayscaleImage, SWDSBinBuildExecutor, UserRawBuildExecutor, ) -def _do_iter_label_slice(path: str): - fpath = Path(path) - - with fpath.open("rb") as f: - _, number = struct.unpack(">II", f.read(8)) - print(f">label({fpath.name}) split {number} group") - - while True: - content = f.read(1) - if not content: - break - yield struct.unpack(">B", content)[0] - - class DataSetProcessExecutor(SWDSBinBuildExecutor): - def iter_data_slice(self, path: str): - fpath = Path(path) - - with fpath.open("rb") as f: - _, number, height, width = struct.unpack(">IIII", f.read(16)) - print(f">data({fpath.name}) split {number} group") - - while True: - content = f.read(height * width) - if not content: - break - yield content + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + root_dir = Path(__file__).parent.parent / "data" + + with (root_dir / "t10k-images-idx3-ubyte").open("rb") as data_file, ( + root_dir / "t10k-labels-idx1-ubyte" + ).open("rb") as label_file: + _, data_number, height, width = struct.unpack(">IIII", data_file.read(16)) + _, label_number = struct.unpack(">II", label_file.read(8)) + print( + f">data({data_file.name}) split data:{data_number}, label:{label_number} group" + ) + image_size = height * width - def iter_label_slice(self, path: str): - return _do_iter_label_slice(path) + for i in range(0, min(data_number, label_number)): + _data = data_file.read(image_size) + _label = struct.unpack(">B", label_file.read(1))[0] + yield GrayscaleImage( + _data, + display_name=f"{i}", + shape=(height, width, 1), + ), {"label": _label} class RawDataSetProcessExecutor(UserRawBuildExecutor): - def iter_data_slice(self, path: str): - fpath = Path(path) + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + root_dir = Path(__file__).parent.parent / "data" + data_fpath = root_dir / "t10k-images-idx3-ubyte" + label_fpath = root_dir / "t10k-labels-idx1-ubyte" - with fpath.open("rb") as f: - _, number, height, width = struct.unpack(">IIII", f.read(16)) - size = height * width - offset = 16 - - for _ in range(number): - yield Link(offset=offset, size=size, mime_type=MIMEType.GRAYSCALE) - offset += size - - def iter_label_slice(self, path: str): - return _do_iter_label_slice(path) + with data_fpath.open("rb") as data_file, label_fpath.open("rb") as label_file: + _, data_number, height, width = struct.unpack(">IIII", data_file.read(16)) + _, label_number = struct.unpack(">II", label_file.read(8)) + image_size = height * width + offset = 16 -class LinkRawDataSetProcessExecutor(RawDataSetProcessExecutor): + for i in range(0, min(data_number, label_number)): + _data = Link( + uri=str(data_fpath.absolute()), + offset=offset, + size=image_size, + data_type=GrayscaleImage( + display_name=f"{i}", shape=(height, width, 1) + ), + with_local_fs_data=True, + ) + _label = struct.unpack(">B", label_file.read(1))[0] + yield _data, {"label": _label} + offset += image_size + + +class LinkRawDataSetProcessExecutor(UserRawBuildExecutor): _auth = S3LinkAuth(name="mnist", access_key="minioadmin", secret="minioadmin") _endpoint = "10.131.0.1:9000" _bucket = "users" - def iter_all_dataset_slice(self): - offset = 16 - size = 28 * 28 - uri = ( - f"s3://{self._endpoint}@{self._bucket}/dataset/mnist/t10k-images-idx3-ubyte" - ) - for _ in range(10000): - link = Link( - f"{uri}", - self._auth, - offset=offset, - size=size, - mime_type=MIMEType.GRAYSCALE, - ) - yield link - offset += size + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + root_dir = 
Path(__file__).parent.parent / "data" + + with (root_dir / "t10k-labels-idx1-ubyte").open("rb") as label_file: + _, label_number = struct.unpack(">II", label_file.read(8)) + + offset = 16 + image_size = 28 * 28 + + uri = f"s3://{self._endpoint}@{self._bucket}/dataset/mnist/t10k-images-idx3-ubyte" + for i in range(label_number): + _data = Link( + f"{uri}", + self._auth, + offset=offset, + size=image_size, + data_type=GrayscaleImage(display_name=f"{i}", shape=(28, 28, 1)), + ) + _label = struct.unpack(">B", label_file.read(1))[0] + yield _data, {"label": _label} + offset += image_size diff --git a/example/nmt/code/dataset.py b/example/nmt/code/dataset.py index ba5af87895..12ff9fef5b 100644 --- a/example/nmt/code/dataset.py +++ b/example/nmt/code/dataset.py @@ -1,58 +1,22 @@ -from io import open -from time import process_time_ns +import typing as t +from pathlib import Path -from starwhale.api.dataset import BuildExecutor +from starwhale.api.dataset import Text, BuildExecutor -try: - from .helper import filterComment, normalizeString -except ImportError: - from helper import filterComment, normalizeString - - -def prepareData(path): - print("preapring data...") - # Read the file and split into lines - lines = open(path, encoding="utf-8").read().strip().split("\n") - - # Split every line into pairs and normalize - pairs = [ - [normalizeString(s) for s in l.split("\t") if not filterComment(s) and s] - for l in lines - ] - - return pairs +from .helper import normalizeString class DataSetProcessExecutor(BuildExecutor): - def iter_data_slice(self, path: str): - pairs = prepareData(path) - index = 0 - lines = len(pairs) - while True: - last_index = index - index += 1 - index = min(index, lines - 1) - print("data:%s, %s" % (last_index, index)) - data_batch = [src for src, tgt in pairs[last_index:index]] - join = "\n".join(data_batch) - - print("res-data:%s" % join) - yield join.encode() - if index >= lines - 1: - break - - def iter_label_slice(self, path: str): - pairs = prepareData(path) - index = 0 - lines = len(pairs) - while True: - last_index = index - index += 1 - index = min(index, lines - 1) - print("label:%s, %s" % (last_index, index)) - data_batch = [tgt for src, tgt in pairs[last_index:index]] - join = "\n".join(data_batch) - print("res-label:%s" % join) - yield join.encode() - if index >= lines - 1: - break + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + root_dir = Path(__file__).parent.parent / "data" + + with (root_dir / "test_eng-fra.txt").open("r") as f: + for line in f.readlines(): + line = line.strip() + if not line or line.startswith("CC-BY"): + continue + + _data, _label = line.split("\t", 1) + data = Text(normalizeString(_data), encoding="utf-8") + annotations = {"label": normalizeString(_label)} + yield data, annotations diff --git a/example/nmt/code/helper.py b/example/nmt/code/helper.py index 94f6f8256c..27646b9185 100644 --- a/example/nmt/code/helper.py +++ b/example/nmt/code/helper.py @@ -1,18 +1,20 @@ -import time +import re import math +import time import unicodedata + import torch -import re + try: - from .config import MAX_LENGTH, EOS_token + from .config import EOS_token, MAX_LENGTH except ImportError: - from config import MAX_LENGTH, EOS_token + from config import EOS_token, MAX_LENGTH def asMinutes(s): m = math.floor(s / 60) s -= m * 60 - return '%dm %ds' % (m, s) + return "%dm %ds" % (m, s) def timeSince(since, percent): @@ -20,24 +22,31 @@ def timeSince(since, percent): s = now - since es = s / (percent) rs = es - s - return '%s (- %s)' % 
(asMinutes(s), asMinutes(rs)) - + return "%s (- %s)" % (asMinutes(s), asMinutes(rs)) eng_prefixes = ( - "i am ", "i m ", - "he is", "he s ", - "she is", "she s ", - "you are", "you re ", - "we are", "we re ", - "they are", "they re " + "i am ", + "i m ", + "he is", + "he s ", + "she is", + "she s ", + "you are", + "you re ", + "we are", + "we re ", + "they are", + "they re ", ) def filterPair(p): - return len(p[0].split(' ')) < MAX_LENGTH and \ - len(p[1].split(' ')) < MAX_LENGTH and \ - p[1].startswith(eng_prefixes) + return ( + len(p[0].split(" ")) < MAX_LENGTH + and len(p[1].split(" ")) < MAX_LENGTH + and p[1].startswith(eng_prefixes) + ) def filterPairs(pairs): @@ -45,7 +54,7 @@ def filterPairs(pairs): def indexesFromSentence(lang, sentence): - return [lang.word2index[word] for word in sentence.split(' ')] + return [lang.word2index[word] for word in sentence.split(" ")] def tensorFromSentence(lang, sentence, device): @@ -63,11 +72,11 @@ def tensorsFromPair(input_lang, output_lang, pair, device): # Turn a Unicode string to plain ASCII, thanks to # https://stackoverflow.com/a/518232/2809427 def unicodeToAscii(s): - return ''.join( - c for c in unicodedata.normalize('NFD', s) - if unicodedata.category(c) != 'Mn' + return "".join( + c for c in unicodedata.normalize("NFD", s) if unicodedata.category(c) != "Mn" ) + # Lowercase, trim, and remove non-letter characters @@ -77,5 +86,6 @@ def normalizeString(s): s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) return s + def filterComment(s): - return s.startswith("CC-BY") \ No newline at end of file + return s.startswith("CC-BY") diff --git a/example/nmt/code/ppl.py b/example/nmt/code/ppl.py index bfb8322c36..c9d7802348 100644 --- a/example/nmt/code/ppl.py +++ b/example/nmt/code/ppl.py @@ -1,27 +1,26 @@ import os import torch -from regex import W +from starwhale.api.job import Context from starwhale.api.model import PipelineHandler try: from .eval import evaluate_batch from .model import EncoderRNN, AttnDecoderRNN - from .vocab import Lang, Vocab from .calculatebleu import BLEU except ImportError: from eval import evaluate_batch from model import EncoderRNN, AttnDecoderRNN - from vocab import Lang, Vocab from calculatebleu import BLEU _ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) class NMTPipeline(PipelineHandler): - def __init__(self) -> None: - super().__init__(merge_label=True, ignore_error=True) + def __init__(self, context: Context) -> None: + super().__init__(context=context) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.vocab = self._load_vocab() self.encoder = self._load_encoder_model(self.device) @@ -41,20 +40,14 @@ def ppl(self, data, **kw): self.decoder, ) - def handle_label(self, label, **kw): - labels = label.decode().split("\n") - print("src labels: %s" % len(labels)) - return labels - def cmp(self, _data_loader): _result, _label = [], [] for _data in _data_loader: - _label.extend(_data[self._label_field]) - (result) = _data[self._ppl_data_field] + _label.extend(_data["annotations"]["label"]) + (result) = _data["result"] _result.extend(result) bleu = BLEU(_result, [_label]) - return {"summary": {"bleu_score": bleu}} def _load_vocab(self): diff --git a/example/nmt/dataset.yaml b/example/nmt/dataset.yaml index 4485a98f53..ccfbbe9375 100644 --- a/example/nmt/dataset.yaml +++ b/example/nmt/dataset.yaml @@ -1,14 +1,10 @@ name: nmt -data_dir: data -data_filter: "test_eng-fra.txt" -label_filter: "test_eng-fra.txt" - process: code.dataset:DataSetProcessExecutor desc: nmt data and label test dataset tag: - 
- bin + - bin attr: alignment_size: 4K diff --git a/example/speech_command/code/data_slicer.py b/example/speech_command/code/data_slicer.py index 0812b2efcf..d3921bccf5 100644 --- a/example/speech_command/code/data_slicer.py +++ b/example/speech_command/code/data_slicer.py @@ -1,72 +1,27 @@ -import os -import pickle import typing as t from pathlib import Path -from starwhale.api.dataset import BuildExecutor - - -class FileBytes: - def __init__(self, p): - self.file_path = p - self.content_bytes = open(p, "rb").read() - - -def _pickle_data(audio_file_paths): - all_bytes = [FileBytes(audio_f) for audio_f in audio_file_paths] - return pickle.dumps(all_bytes) - - -def _pickle_label(audio_file_paths): - all_strings = [ - os.path.basename(os.path.dirname(str(audio_f))) for audio_f in audio_file_paths - ] - return pickle.dumps(all_strings) +from starwhale.api.dataset import Audio, MIMEType, BuildExecutor class SpeechCommandsSlicer(BuildExecutor): - def load_list(self, file_filter): - filepath = self.data_dir / file_filter - with open(filepath) as fileobj: - return [self.data_dir / line.strip() for line in fileobj] - - def _iter_files( - self, file_filter: str, sort_key: t.Optional[t.Any] = None - ) -> t.Generator[Path, None, None]: - _key = sort_key - if _key is not None and not callable(_key): - raise Exception(f"data_sort_func({_key}) is not callable.") - - _files = sorted(self.load_list(file_filter), key=_key) - for p in _files: - if not p.is_file(): - continue - yield p - - def iter_data_slice(self, path: str): - pass - - def iter_label_slice(self, path: str): - pass - - def iter_all_dataset_slice(self) -> t.Generator[t.Any, None, None]: - datafiles = [p for p in self.iter_data_files()] - idx = 0 - data_size = len(datafiles) - while True: - last_idx = idx - idx += 1 - if idx > data_size: - break - yield _pickle_data(datafiles[last_idx:idx]) - - def iter_all_label_slice(self) -> t.Generator[t.Any, None, None]: - datafiles = [p for p in self.iter_data_files()] - idx = 0 - data_size = len(datafiles) - while True: - last_idx = idx - idx += 1 - if idx > data_size: - break - yield _pickle_label(datafiles[last_idx:idx]) + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + dataset_dir = ( + Path(__file__).parent.parent + / "data" + / "SpeechCommands" + / "speech_commands_v0.02" + ) + + with (dataset_dir / "testing_list.txt").open() as f: + for item in f.readlines(): + item = item.strip() + if not item: + continue + + data_path = dataset_dir / item + data = Audio( + data_path, display_name=item, shape=(1,), mime_type=MIMEType.WAV + ) + annotations = {"label": data_path.parent.name} + yield data, annotations diff --git a/example/speech_command/code/ppl.py b/example/speech_command/code/ppl.py index f5f0618fe8..471714b64a 100644 --- a/example/speech_command/code/ppl.py +++ b/example/speech_command/code/ppl.py @@ -6,6 +6,7 @@ import torch import torchaudio +from starwhale.api.job import Context from starwhale.api.model import PipelineHandler from starwhale.api.metric import multi_classification @@ -53,12 +54,12 @@ class M5Inference(PipelineHandler): - def __init__(self, device="cpu") -> None: - super().__init__(merge_label=True, ignore_error=True) - self.device = torch.device(device) + def __init__(self, context: Context) -> None: + super().__init__(context=context) + self.device = torch.device("cpu") self.model = self._load_model(self.device) self.transform = torchaudio.transforms.Resample(orig_freq=16000, new_freq=8000) - self.transform = self.transform.to(device) + self.transform 
= self.transform.to("cpu") def ppl(self, data, **kw): audios = self._pre(data) @@ -71,9 +72,6 @@ def ppl(self, data, **kw): result.append("ERROR") return result - def handle_label(self, label, **kw): - return pickle.loads(label) - @multi_classification( confusion_matrix_normalize="all", show_hamming_loss=True, @@ -81,14 +79,13 @@ def handle_label(self, label, **kw): show_roc_auc=False, all_labels=labels, ) - def cmp(self, _data_loader): - _result, _label, _pr = [], [], [] - for _data in _data_loader: - _label.extend(_data[self._label_field]) - (result) = _data[self._ppl_data_field] - _result.extend(result) - # _pr.extend(_data["pr"]) - return _result, _label + def cmp(self, ppl_result): + result, label = [], [] + for _data in ppl_result: + label.append(_data["annotations"]["label"]) + (result) = _data["result"] + result.extend(result) + return result, label def _pre(self, input: bytes): audios = pickle.loads(input) diff --git a/example/speech_command/dataset.yaml b/example/speech_command/dataset.yaml index 2e3182a8af..cb28c60eaf 100644 --- a/example/speech_command/dataset.yaml +++ b/example/speech_command/dataset.yaml @@ -1,14 +1,10 @@ name: SpeechCommands -data_dir: data/SpeechCommands/speech_commands_v0.02 -data_filter: "testing_list.txt" -label_filter: "testing_list.txt" - process: code.data_slicer:SpeechCommandsSlicer desc: SpeechCommands data and label test dataset tag: - - bin + - bin attr: alignment_size: 4k diff --git a/example/text_cls_AG_NEWS/code/data_slicer.py b/example/text_cls_AG_NEWS/code/data_slicer.py index 63d8793bd4..84596981df 100644 --- a/example/text_cls_AG_NEWS/code/data_slicer.py +++ b/example/text_cls_AG_NEWS/code/data_slicer.py @@ -1,25 +1,18 @@ -from starwhale.api.dataset import BuildExecutor +import re +import csv +import typing as t +from pathlib import Path -from . import ag_news +from starwhale.api.dataset import Text, BuildExecutor -def yield_data(path, label=False): - data = ag_news.load_ag_data(path) - idx = 0 - data_size = len(data) - while True: - last_idx = idx - idx += 1 - if idx > data_size: - break - data_batch = [lbl if label else txt for lbl, txt in data[last_idx:idx]] - join = "#@#@#@#".join(data_batch) - yield join.encode() +class AGNewsSlicer(BuildExecutor): + def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]: + root_dir = Path(__file__).parent.parent / "data" - -class AGNEWSSlicer(BuildExecutor): - def iter_data_slice(self, path: str): - yield from yield_data(path) - - def iter_label_slice(self, path: str): - yield from yield_data(path, True) + with (root_dir / "test.csv").open("r", encoding="utf-8") as f: + for row in csv.reader(f, delimiter=",", quotechar='"'): + annotations = {"label": row[0]} + data = " ".join(row[1:]) + data = re.sub("^\s*(.-)\s*$", "%1", data).replace("\\n", "\n") + yield Text(content=data), annotations diff --git a/example/text_cls_AG_NEWS/code/ppl.py b/example/text_cls_AG_NEWS/code/ppl.py index a491063768..316adfc818 100644 --- a/example/text_cls_AG_NEWS/code/ppl.py +++ b/example/text_cls_AG_NEWS/code/ppl.py @@ -3,27 +3,24 @@ import torch from torchtext.data.utils import get_tokenizer +from starwhale.api.job import Context from starwhale.api.model import PipelineHandler from starwhale.api.metric import multi_classification try: - from . import predict -except ImportError: - import predict - -try: - from . import model + from . 
import model, predict
 except ImportError:
     import model
+    import predict
 
 _ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
 
 
 class TextClassificationHandler(PipelineHandler):
-    def __init__(self, device="cpu") -> None:
-        super().__init__(merge_label=True, ignore_error=True)
-        self.device = torch.device(device)
+    def __init__(self, context: Context) -> None:
+        super().__init__(context=context)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     @torch.no_grad()
     def ppl(self, data, **kw):
@@ -33,10 +30,6 @@ def ppl(self, data, **kw):
             map(lambda text: predict.predict(text, _model, vocab, tokenizer, 2), texts)
         )
 
-    def handle_label(self, label, **kw):
-        labels = label.decode().split("#@#@#@#")
-        return [int(label) for label in labels]
-
     @multi_classification(
         confusion_matrix_normalize="all",
         show_hamming_loss=True,
@@ -44,14 +37,13 @@ def handle_label(self, label, **kw):
         show_roc_auc=False,
         all_labels=[i for i in range(1, 5)],
     )
-    def cmp(self, _data_loader):
-        _result, _label = [], []
-        for _data in _data_loader:
-            print(_data)
-            _label.extend([int(l) for l in _data[self._label_field]])
-            (result) = _data[self._ppl_data_field]
-            _result.extend([int(r) for r in result])
-        return _label, _result
+    def cmp(self, ppl_result):
+        result, label = [], []
+        for _data in ppl_result:
+            label.append(int(_data["annotations"]["label"]))
+            pred = _data["result"]
+            result.extend([int(r) for r in pred])
+        return label, result
 
     def _load_model(self, device):
         model_path = _ROOT_DIR + "/models/model.i"
diff --git a/example/text_cls_AG_NEWS/dataset.yaml b/example/text_cls_AG_NEWS/dataset.yaml
index e6ed0b3c1a..1916d08ce1 100644
--- a/example/text_cls_AG_NEWS/dataset.yaml
+++ b/example/text_cls_AG_NEWS/dataset.yaml
@@ -1,15 +1,10 @@
 name: AG_NEWS
 
-data_dir: data
-data_filter: "test.csv"
-label_filter: "test.csv"
-
-process: code.data_slicer:AGNEWSSlicer
-pip_req: requirements.txt
+process: code.data_slicer:AGNewsSlicer
 
 desc: AG_NEWS data and label test dataset
 
 tag:
-  - bin
+  - bin
 
 attr:
   alignment_size: 4k
diff --git a/scripts/run_demo.sh b/scripts/run_demo.sh
index 1c9c66ef34..836f771a2d 100755
--- a/scripts/run_demo.sh
+++ b/scripts/run_demo.sh
@@ -93,7 +93,7 @@ length_must_equal() {
 
 build_rc_and_check() {
     length_must_equal 0 "$1"
-    swcli "$1" build .
+    swcli -vvv "$1" build .
     length_must_equal 1 "$1"
 }
 
@@ -103,7 +103,7 @@ build_rc_and_check dataset
 
 echo "do ppl and cmp"
 length_must_equal 0 eval "job list"
-swcli eval run --model mnist/version/latest --dataset mnist/version/latest
+swcli -vvv eval run --model mnist/version/latest --dataset mnist/version/latest
 length_must_equal 1 eval "job list"
 
 #echo "check result"
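Taken with the `swcli eval run` invocation above, the refactored handler contract this patch exercises end-to-end can be summarized in a short sketch: `ppl` now receives a Starwhale artifact rather than raw bytes, and `cmp` receives items keyed by `result` and `annotations`. The `_predict` helper below is a placeholder assumption, not an API from this patch.

```python
import typing as t

from starwhale.api.job import Context
from starwhale.api.model import PipelineHandler
from starwhale.api.dataset import Image


class ExampleInference(PipelineHandler):
    def __init__(self, context: Context) -> None:
        super().__init__(context=context)

    def ppl(self, data: Image, **kw: t.Any) -> t.Any:
        # The loader hands the handler an artifact; raw bytes are one call away.
        raw: bytes = data.to_bytes()
        return self._predict(raw)

    def cmp(self, ppl_result: t.Any) -> t.Any:
        result, label = [], []
        for item in ppl_result:
            # Each item carries the ppl output and the row's annotations dict.
            label.append(item["annotations"]["label"])
            result.append(item["result"])
        return label, result

    def _predict(self, raw: bytes) -> int:
        # Placeholder model call, assumed for illustration only.
        return 0
```

Keeping the per-item value (`item["result"]`) distinct from the `result` accumulator ensures the lists handed to metric decorators such as `multi_classification` are not clobbered inside the loop.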