feat(dataset): support dataset mime_type config (#1033)
support dataset mime_type config
tianweidut authored Aug 28, 2022
1 parent c88a05b commit 9996167
Showing 7 changed files with 61 additions and 23 deletions.
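In short: dataset.yaml can now declare a default MIME type for data items via attr.data_mime_type, and a build executor can override it per item. A minimal sketch of the two yield styles, pieced together from the hunks below (function names are illustrative; Link, MIMEType, and the import path are taken from the test diff in this commit):

    from starwhale.api.dataset import Link, MIMEType

    # swds-bin style: yield raw bytes (attr.data_mime_type applies),
    # or a (bytes, MIMEType) pair to override the default per item
    def iter_bin_slices():
        data = b"\x00" * 28 * 28
        yield data                         # falls back to the dataset default
        yield data, MIMEType.GRAYSCALE     # per-item override

    # user-raw style (from the updated tests): iter_data_slice yields a Link
    # that now carries the mime type alongside offset/size
    def iter_data_slice(path: str):
        yield Link(offset=16, size=28 * 28, mime_type=MIMEType.GRAYSCALE)
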
48 changes: 35 additions & 13 deletions client/starwhale/api/_impl/dataset/builder.py
@@ -15,11 +15,12 @@
 from starwhale.base.uri import URI
 from starwhale.utils.fs import empty_dir, ensure_dir
 from starwhale.base.type import DataFormatType, DataOriginType, ObjectStoreType
-from starwhale.utils.error import FormatError
+from starwhale.utils.error import FormatError, NoSupportError
 from starwhale.core.dataset import model
 from starwhale.core.dataset.type import (
     Link,
     LinkAuth,
+    MIMEType,
     DatasetSummary,
     D_ALIGNMENT_SIZE,
     D_FILE_VOLUME_SIZE,
@@ -53,6 +54,7 @@ def __init__(
         append: bool = False,
         append_from_version: str = "",
         append_from_uri: t.Optional[URI] = None,
+        data_mime_type: MIMEType = MIMEType.UNDEFINED,
     ) -> None:
         # TODO: add more docstring for args
         # TODO: validate group upper and lower?
@@ -70,6 +72,7 @@ def __init__(
 
         self.alignment_bytes_size = alignment_bytes_size
         self.volume_bytes_size = volume_bytes_size
+        self.default_data_mime_type = data_mime_type
 
         self.project_name = project_name
         self.dataset_name = dataset_name
@@ -237,10 +240,15 @@ def make_swds(self) -> DatasetSummary:
             zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()),
             start=self._forked_last_idx + 1,
         ):
-            if not isinstance(data, bytes):
-                raise FormatError("data must be bytes type")
+            if isinstance(data, (tuple, list)):
+                _data_content, _data_mime_type = data
+            else:
+                _data_content, _data_mime_type = data, self.default_data_mime_type
+
+            if not isinstance(_data_content, bytes):
+                raise FormatError("data content must be bytes type")
 
-            data_offset, data_size = self._write(dwriter, data)
+            data_offset, data_size = self._write(dwriter, _data_content)
             self.tabular_dataset.put(
                 TabularDatasetRow(
                     id=idx,
@@ -251,6 +259,7 @@ def make_swds(self) -> DatasetSummary:
                     data_offset=data_offset,
                     data_size=data_size,
                     data_origin=DataOriginType.NEW,
+                    data_mime_type=_data_mime_type or self.default_data_mime_type,
                 )
             )
 
@@ -344,20 +353,32 @@ def make_swds(self) -> DatasetSummary:
             start=self._forked_last_idx + 1,
         ):
             if isinstance(data, Link):
-                data_uri = data.uri
-                data_offset, data_size = data.offset, data.size
-                if data.auth:
-                    auth = data.auth.name
-                    auth_candidates[f"{data.auth.ltype}.{data.auth.name}"] = data.auth
+                _remote_link = data
+                data_uri = _remote_link.uri
+                data_offset, data_size = _remote_link.offset, _remote_link.size
+                if _remote_link.auth:
+                    auth = _remote_link.auth.name
+                    auth_candidates[
+                        f"{_remote_link.auth.ltype}.{_remote_link.auth.name}"
+                    ] = _remote_link.auth
                 else:
                     auth = ""
                 object_store_type = ObjectStoreType.REMOTE
                 include_link = True
+                data_mime_type = _remote_link.mime_type
             elif isinstance(data, (tuple, list)):
-                data_path, (data_offset, data_size) = data
-                if data_path not in map_path_sign:
-                    map_path_sign[data_path] = DatasetStorage.save_data_file(data_path)
-                data_uri, _ = map_path_sign[data_path]
+                _data_fpath, _local_link = data
+                if _data_fpath not in map_path_sign:
+                    map_path_sign[_data_fpath] = DatasetStorage.save_data_file(
+                        _data_fpath
+                    )
+
+                if not isinstance(_local_link, Link):
+                    raise NoSupportError("data only support Link type")
+
+                data_mime_type = _local_link.mime_type
+                data_offset, data_size = _local_link.offset, _local_link.size
+                data_uri, _ = map_path_sign[_data_fpath]
                 auth = ""
                 object_store_type = ObjectStoreType.LOCAL
             else:
@@ -374,6 +395,7 @@ def make_swds(self) -> DatasetSummary:
                     data_size=data_size,
                     data_origin=DataOriginType.NEW,
                     auth_name=auth,
+                    data_mime_type=data_mime_type,
                 )
             )
 
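Two behavioral notes on the hunks above. For swds-bin, each item may now be bytes or a (bytes, MIMEType) pair, with the stored row falling back to the executor-wide default when no per-item type is given. For user-raw, the second element of each pair consumed by make_swds must now be a Link; the old bare (offset, size) tuple raises NoSupportError, which is why the mnist example and tests below switch their yields. A before/after sketch of what make_swds unpacks in the user-raw case (variable names follow the diff):

    # before: a plain tuple carried only the byte range
    data_path, (data_offset, data_size) = data

    # after: a Link carries the byte range plus the mime type
    data_path, link = data                           # must be a Link instance
    data_offset, data_size = link.offset, link.size
    data_mime_type = link.mime_type
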
2 changes: 1 addition & 1 deletion client/starwhale/api/_impl/dataset/mnist.py
@@ -30,4 +30,4 @@ def iter_label_slice(self, path: str) -> t.Generator[bytes, None, None]:
                 content = f.read(1)
                 if not content:
                     break
-                yield content
+                yield struct.unpack(">B", content)[0]
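The label generator now decodes each one-byte label to an int instead of yielding raw bytes; struct.unpack(">B", b) reads a single unsigned byte (note the t.Generator[bytes, None, None] annotation in the signature is now stale, since ints are yielded). For example:

    import struct

    struct.unpack(">B", b"\x07")[0]  # -> 7
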
5 changes: 4 additions & 1 deletion client/starwhale/core/dataset/model.py
@@ -290,7 +290,9 @@ def _call_make_swds(
 
         # TODO: add more import format support, current is module:class
         logger.info(f"[info:swds]try to import {swds_config.process} @ {workdir}")
-        _cls = import_cls(workdir, swds_config.process, BaseBuildExecutor)
+        _cls: t.Type[BaseBuildExecutor] = import_cls(
+            workdir, swds_config.process, BaseBuildExecutor
+        )
 
         with _cls(
             dataset_name=self.uri.object.name,
@@ -305,6 +307,7 @@ def _call_make_swds(
             append=append,
             append_from_version=append_from_version,
             append_from_uri=append_from_uri,
+            data_mime_type=swds_config.attr.data_mime_type,
         ) as _obj:
             console.print(
                 f":ghost: import [red]{swds_config.process}@{workdir.resolve()}[/] to make swds..."
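This is the plumbing that carries the YAML value into the executor: swds_config.attr is parsed from the attr section of dataset.yaml, so the data_mime_type configured there becomes the executor's default_data_mime_type. Assumed flow (only the data_mime_type kwarg is actually visible in this diff):

    # dataset.yaml
    #   attr:
    #     data_mime_type: "x/grayscale"   # presumably parsed to MIMEType.GRAYSCALE
    # -> swds_config.attr.data_mime_type
    # -> BaseBuildExecutor(..., data_mime_type=swds_config.attr.data_mime_type)
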
10 changes: 8 additions & 2 deletions client/starwhale/core/dataset/type.py
@@ -160,7 +160,7 @@ def create_by_file_suffix(cls, name: str) -> MIMEType:
 class Link:
     def __init__(
         self,
-        uri: str,
+        uri: str = "",
         auth: t.Optional[LinkAuth] = DefaultS3LinkAuth,
         offset: int = FilePosition.START,
         size: int = -1,
@@ -235,15 +235,21 @@ def __init__(
         self,
         volume_size: t.Union[int, str] = D_FILE_VOLUME_SIZE,
         alignment_size: t.Union[int, str] = D_ALIGNMENT_SIZE,
+        data_mime_type: MIMEType = MIMEType.UNDEFINED,
         **kw: t.Any,
     ) -> None:
         self.volume_size = convert_to_bytes(volume_size)
         self.alignment_size = convert_to_bytes(alignment_size)
+        self.data_mime_type = data_mime_type
         self.kw = kw
 
     def as_dict(self) -> t.Dict[str, t.Any]:
         # TODO: refactor an asdict mixin class
         _rd = deepcopy(self.__dict__)
-        _rd.pop("kw")
+        _rd.pop("kw", None)
+        for k, v in _rd.items():
+            if isinstance(v, Enum):
+                _rd[k] = v.value
         return _rd
 
 
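The as_dict changes are what let the new enum field serialize into the manifest: pop("kw", None) no longer assumes the key exists, and Enum members are flattened to their raw values. A sketch of the round-trip, using the constructor from the hunk above (the enclosing class name is not visible in this hunk, so DatasetAttr is assumed, as is MIMEType.GRAYSCALE carrying the "x/grayscale" value used in dataset.yaml):

    attr = DatasetAttr(
        volume_size="4M", alignment_size="1k", data_mime_type=MIMEType.GRAYSCALE
    )
    attr.as_dict()
    # -> {"volume_size": 4194304, "alignment_size": 1024,
    #     "data_mime_type": "x/grayscale"}
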
12 changes: 9 additions & 3 deletions client/tests/sdk/test_dataset.py
@@ -1,9 +1,15 @@
 import os
 import json
+import struct
 from pathlib import Path
 
 from starwhale.utils.fs import ensure_dir, blake2b_file
-from starwhale.api.dataset import MNISTBuildExecutor, UserRawBuildExecutor
+from starwhale.api.dataset import (
+    Link,
+    MIMEType,
+    MNISTBuildExecutor,
+    UserRawBuildExecutor,
+)
 from starwhale.core.dataset.store import DatasetStorage
 from starwhale.core.dataset.tabular import TabularDataset
 from starwhale.api._impl.dataset.builder import (
@@ -27,7 +33,7 @@ def iter_data_slice(self, path: str):
         file_size = Path(path).stat().st_size
         offset = 16
         while offset < file_size:
-            yield offset, size
+            yield Link(offset=offset, size=size, mime_type=MIMEType.GRAYSCALE)
             offset += size
 
     def iter_label_slice(self, path: str):
@@ -39,7 +45,7 @@ def iter_label_slice(self, path: str):
                 content = f.read(1)
                 if not content:
                     break
-                yield content
+                yield struct.unpack(">B", content)[0]
 
 
 class TestDatasetBuildExecutor(BaseTestCase):
5 changes: 3 additions & 2 deletions example/mnist/dataset.yaml
@@ -12,5 +12,6 @@ tag:
   - bin
 
 attr:
-  alignment_size: 4k
-  volume_size: 2M
+  alignment_size: 1k
+  volume_size: 4M
+  data_mime_type: "x/grayscale"
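The new attr key sets the dataset-wide default: every row built without an explicit per-item MIME type gets this value, and the "x/grayscale" string presumably maps to the MIMEType.GRAYSCALE member used in the Python hunks above. The size attributes also changed here (alignment 4k -> 1k, volume 2M -> 4M), seemingly a separate tuning change bundled into this commit.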
2 changes: 1 addition & 1 deletion example/mnist/mnist/process.py
@@ -52,7 +52,7 @@ def iter_data_slice(self, path: str):
         offset = 16
 
         for _ in range(number):
-            yield offset, size
+            yield Link(offset=offset, size=size, mime_type=MIMEType.GRAYSCALE)
             offset += size
 
     def iter_label_slice(self, path: str):
