From 99961673275470a099f853d147f0a0fc8ebeffda Mon Sep 17 00:00:00 2001
From: tianwei
Date: Sun, 28 Aug 2022 17:10:53 +0800
Subject: [PATCH] feat(dataset): support dataset mime_type config (#1033)

support dataset mime_type config
---
 client/starwhale/api/_impl/dataset/builder.py | 48 ++++++++++++++-----
 client/starwhale/api/_impl/dataset/mnist.py   |  2 +-
 client/starwhale/core/dataset/model.py        |  5 +-
 client/starwhale/core/dataset/type.py         | 10 +++-
 client/tests/sdk/test_dataset.py              | 12 +++--
 example/mnist/dataset.yaml                    |  5 +-
 example/mnist/mnist/process.py                |  2 +-
 7 files changed, 61 insertions(+), 23 deletions(-)

diff --git a/client/starwhale/api/_impl/dataset/builder.py b/client/starwhale/api/_impl/dataset/builder.py
index d779afd7c6..dfcc0c2743 100644
--- a/client/starwhale/api/_impl/dataset/builder.py
+++ b/client/starwhale/api/_impl/dataset/builder.py
@@ -15,11 +15,12 @@
 from starwhale.base.uri import URI
 from starwhale.utils.fs import empty_dir, ensure_dir
 from starwhale.base.type import DataFormatType, DataOriginType, ObjectStoreType
-from starwhale.utils.error import FormatError
+from starwhale.utils.error import FormatError, NoSupportError
 from starwhale.core.dataset import model
 from starwhale.core.dataset.type import (
     Link,
     LinkAuth,
+    MIMEType,
     DatasetSummary,
     D_ALIGNMENT_SIZE,
     D_FILE_VOLUME_SIZE,
@@ -53,6 +54,7 @@ def __init__(
         append: bool = False,
         append_from_version: str = "",
         append_from_uri: t.Optional[URI] = None,
+        data_mime_type: MIMEType = MIMEType.UNDEFINED,
     ) -> None:
         # TODO: add more docstring for args
         # TODO: validate group upper and lower?
@@ -70,6 +72,7 @@
 
         self.alignment_bytes_size = alignment_bytes_size
         self.volume_bytes_size = volume_bytes_size
+        self.default_data_mime_type = data_mime_type
 
         self.project_name = project_name
         self.dataset_name = dataset_name
@@ -237,10 +240,15 @@ def make_swds(self) -> DatasetSummary:
             zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()),
             start=self._forked_last_idx + 1,
         ):
-            if not isinstance(data, bytes):
-                raise FormatError("data must be bytes type")
+            if isinstance(data, (tuple, list)):
+                _data_content, _data_mime_type = data
+            else:
+                _data_content, _data_mime_type = data, self.default_data_mime_type
+
+            if not isinstance(_data_content, bytes):
+                raise FormatError("data content must be bytes type")
 
-            data_offset, data_size = self._write(dwriter, data)
+            data_offset, data_size = self._write(dwriter, _data_content)
             self.tabular_dataset.put(
                 TabularDatasetRow(
                     id=idx,
@@ -251,6 +259,7 @@
                     data_offset=data_offset,
                     data_size=data_size,
                     data_origin=DataOriginType.NEW,
+                    data_mime_type=_data_mime_type or self.default_data_mime_type,
                 )
             )
 
@@ -344,20 +353,32 @@ def make_swds(self) -> DatasetSummary:
             start=self._forked_last_idx + 1,
         ):
             if isinstance(data, Link):
-                data_uri = data.uri
-                data_offset, data_size = data.offset, data.size
-                if data.auth:
-                    auth = data.auth.name
-                    auth_candidates[f"{data.auth.ltype}.{data.auth.name}"] = data.auth
+                _remote_link = data
+                data_uri = _remote_link.uri
+                data_offset, data_size = _remote_link.offset, _remote_link.size
+                if _remote_link.auth:
+                    auth = _remote_link.auth.name
+                    auth_candidates[
+                        f"{_remote_link.auth.ltype}.{_remote_link.auth.name}"
+                    ] = _remote_link.auth
                 else:
                     auth = ""
                 object_store_type = ObjectStoreType.REMOTE
                 include_link = True
+                data_mime_type = _remote_link.mime_type
             elif isinstance(data, (tuple, list)):
-                data_path, (data_offset, data_size) = data
-                if data_path not in map_path_sign:
-                    map_path_sign[data_path] = DatasetStorage.save_data_file(data_path)
-                data_uri, _ = map_path_sign[data_path]
+                _data_fpath, _local_link = data
+                if _data_fpath not in map_path_sign:
+                    map_path_sign[_data_fpath] = DatasetStorage.save_data_file(
+                        _data_fpath
+                    )
+
+                if not isinstance(_local_link, Link):
+                    raise NoSupportError("data only support Link type")
+
+                data_mime_type = _local_link.mime_type
+                data_offset, data_size = _local_link.offset, _local_link.size
+                data_uri, _ = map_path_sign[_data_fpath]
                 auth = ""
                 object_store_type = ObjectStoreType.LOCAL
             else:
@@ -374,6 +395,7 @@
                     data_size=data_size,
                     data_origin=DataOriginType.NEW,
                     auth_name=auth,
+                    data_mime_type=data_mime_type,
                 )
             )
 
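Note on the SWDS-bin change above: `iter_data_slice` items may now be either raw `bytes` (stamped with the executor-level `default_data_mime_type`) or a `(bytes, MIMEType)` tuple carrying a per-item MIME type. A minimal sketch of a builder using the new contract; the subclass name and the 28x28 record layout are illustrative assumptions, not part of this patch:

```python
import typing as t

from starwhale.api.dataset import MIMEType, MNISTBuildExecutor


class GrayscaleBuildExecutor(MNISTBuildExecutor):
    def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]:
        with open(path, "rb") as f:
            f.read(16)  # skip the idx3 image header (assumed MNIST-style file)
            while True:
                chunk = f.read(28 * 28)
                if not chunk:
                    break
                # a (bytes, MIMEType) tuple sets a per-item mime type;
                # yielding plain bytes falls back to default_data_mime_type
                yield chunk, MIMEType.GRAYSCALE
```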
diff --git a/client/starwhale/api/_impl/dataset/mnist.py b/client/starwhale/api/_impl/dataset/mnist.py
index 666feb466e..00953b6d55 100644
--- a/client/starwhale/api/_impl/dataset/mnist.py
+++ b/client/starwhale/api/_impl/dataset/mnist.py
@@ -30,4 +30,4 @@ def iter_label_slice(self, path: str) -> t.Generator[bytes, None, None]:
                 content = f.read(1)
                 if not content:
                     break
-                yield content
+                yield struct.unpack(">B", content)[0]
diff --git a/client/starwhale/core/dataset/model.py b/client/starwhale/core/dataset/model.py
index 41e6bf20c2..14478c2be8 100644
--- a/client/starwhale/core/dataset/model.py
+++ b/client/starwhale/core/dataset/model.py
@@ -290,7 +290,9 @@ def _call_make_swds(
 
         # TODO: add more import format support, current is module:class
         logger.info(f"[info:swds]try to import {swds_config.process} @ {workdir}")
-        _cls = import_cls(workdir, swds_config.process, BaseBuildExecutor)
+        _cls: t.Type[BaseBuildExecutor] = import_cls(
+            workdir, swds_config.process, BaseBuildExecutor
+        )
 
         with _cls(
             dataset_name=self.uri.object.name,
@@ -305,6 +307,7 @@
             append=append,
             append_from_version=append_from_version,
             append_from_uri=append_from_uri,
+            data_mime_type=swds_config.attr.data_mime_type,
         ) as _obj:
             console.print(
                 f":ghost: import [red]{swds_config.process}@{workdir.resolve()}[/] to make swds..."
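Config plumbing: the `attr` section of `dataset.yaml` now carries `data_mime_type`, which `_call_make_swds` forwards to the executor via `swds_config.attr.data_mime_type`. A sketch of the expected attr round-trip; the `DatasetAttr` class name is inferred from `swds_config.attr` (the `type.py` hunk below omits the class line), and the printed values assume `convert_to_bytes` size parsing and `MIMEType.GRAYSCALE` serializing to `"x/grayscale"`:

```python
from starwhale.core.dataset.type import MIMEType, DatasetAttr  # DatasetAttr name assumed

attr = DatasetAttr(
    volume_size="4M",
    alignment_size="1k",
    data_mime_type=MIMEType.GRAYSCALE,
)
# as_dict() now serializes Enum members by value, so the manifest stores a
# plain string rather than an Enum repr
print(attr.as_dict())
# expected: {'volume_size': 4194304, 'alignment_size': 1024,
#            'data_mime_type': 'x/grayscale'}
```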
diff --git a/client/starwhale/core/dataset/type.py b/client/starwhale/core/dataset/type.py
index ccd2a52691..ed9e2e50c4 100644
--- a/client/starwhale/core/dataset/type.py
+++ b/client/starwhale/core/dataset/type.py
@@ -160,7 +160,7 @@ def create_by_file_suffix(cls, name: str) -> MIMEType:
 class Link:
     def __init__(
         self,
-        uri: str,
+        uri: str = "",
         auth: t.Optional[LinkAuth] = DefaultS3LinkAuth,
         offset: int = FilePosition.START,
         size: int = -1,
@@ -235,15 +235,21 @@ def __init__(
         self,
         volume_size: t.Union[int, str] = D_FILE_VOLUME_SIZE,
         alignment_size: t.Union[int, str] = D_ALIGNMENT_SIZE,
+        data_mime_type: MIMEType = MIMEType.UNDEFINED,
         **kw: t.Any,
     ) -> None:
         self.volume_size = convert_to_bytes(volume_size)
         self.alignment_size = convert_to_bytes(alignment_size)
+        self.data_mime_type = data_mime_type
         self.kw = kw
 
     def as_dict(self) -> t.Dict[str, t.Any]:
+        # TODO: refactor an asdict mixin class
         _rd = deepcopy(self.__dict__)
-        _rd.pop("kw")
+        _rd.pop("kw", None)
+        for k, v in _rd.items():
+            if isinstance(v, Enum):
+                _rd[k] = v.value
         return _rd
 
 
diff --git a/client/tests/sdk/test_dataset.py b/client/tests/sdk/test_dataset.py
index 3585a50515..aa8f83af59 100644
--- a/client/tests/sdk/test_dataset.py
+++ b/client/tests/sdk/test_dataset.py
@@ -1,9 +1,15 @@
 import os
 import json
+import struct
 from pathlib import Path
 
 from starwhale.utils.fs import ensure_dir, blake2b_file
-from starwhale.api.dataset import MNISTBuildExecutor, UserRawBuildExecutor
+from starwhale.api.dataset import (
+    Link,
+    MIMEType,
+    MNISTBuildExecutor,
+    UserRawBuildExecutor,
+)
 from starwhale.core.dataset.store import DatasetStorage
 from starwhale.core.dataset.tabular import TabularDataset
 from starwhale.api._impl.dataset.builder import (
@@ -27,7 +33,7 @@ def iter_data_slice(self, path: str):
         file_size = Path(path).stat().st_size
 
         offset = 16
         while offset < file_size:
-            yield offset, size
+            yield Link(offset=offset, size=size, mime_type=MIMEType.GRAYSCALE)
             offset += size
 
@@ -39,7 +45,7 @@ def iter_label_slice(self, path: str):
                 content = f.read(1)
                 if not content:
                     break
-                yield content
+                yield struct.unpack(">B", content)[0]
 
 
 class TestDatasetBuildExecutor(BaseTestCase):
diff --git a/example/mnist/dataset.yaml b/example/mnist/dataset.yaml
index 40bf721db6..046fb9c337 100644
--- a/example/mnist/dataset.yaml
+++ b/example/mnist/dataset.yaml
@@ -12,5 +12,6 @@ tag:
   - bin
 
 attr:
-  alignment_size: 4k
-  volume_size: 2M
+  alignment_size: 1k
+  volume_size: 4M
+  data_mime_type: "x/grayscale"
diff --git a/example/mnist/mnist/process.py b/example/mnist/mnist/process.py
index 635dc200d8..fdf7e83128 100644
--- a/example/mnist/mnist/process.py
+++ b/example/mnist/mnist/process.py
@@ -52,7 +52,7 @@ def iter_data_slice(self, path: str):
 
         offset = 16
         for _ in range(number):
-            yield offset, size
+            yield Link(offset=offset, size=size, mime_type=MIMEType.GRAYSCALE)
             offset += size
 
     def iter_label_slice(self, path: str):
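For user-raw builds, the MNIST example and the test above now yield `Link` objects instead of bare `(offset, size)` tuples, so each slice carries its own MIME type; `make_swds` signs the raw file once and records `offset`/`size`/`mime_type` per row. The same pattern as a standalone sketch (the free-function form and default slice size are illustrative; `uri` can stay empty for local data now that `Link.uri` defaults to `""`):

```python
import os
import typing as t

from starwhale.api.dataset import Link, MIMEType


def iter_data_slice(path: str, size: int = 28 * 28) -> t.Generator[Link, None, None]:
    offset = 16  # skip the idx3 image-file header, as in the MNIST example
    file_size = os.path.getsize(path)
    while offset < file_size:
        # the builder resolves each slice against the signed raw file and
        # stores its mime_type in the tabular dataset row
        yield Link(offset=offset, size=size, mime_type=MIMEType.GRAYSCALE)
        offset += size
```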