feat(dataset): support annotations and data type for dataset #1130

Merged
2 changes: 2 additions & 0 deletions .gitignore
@@ -50,6 +50,8 @@ example/PennFudanPed/data/
example/PennFudanPed/models/
example/nmt/models/
example/nmt/data/
example/speech_command/data/
example/speech_command/models/

# UI
node_modules
2 changes: 0 additions & 2 deletions client/starwhale/api/_impl/data_store.py
@@ -896,9 +896,7 @@ def __init__(
yield r

def dump(self) -> None:
logger.debug(f"start dump, tables size:{len(self.tables.values())}")
for table in list(self.tables.values()):
logger.debug(f"dump {table.table_name} to {self.root_path}")
table.dump(self.root_path)


20 changes: 16 additions & 4 deletions client/starwhale/api/_impl/dataset/__init__.py
@@ -1,14 +1,20 @@
from starwhale.core.dataset.type import (
Link,
Text,
Audio,
Image,
Binary,
LinkType,
MIMEType,
DataField,
ClassLabel,
S3LinkAuth,
BoundingBox,
GrayscaleImage,
LocalFSLinkAuth,
DefaultS3LinkAuth,
COCOObjectAnnotation,
)

from .mnist import MNISTBuildExecutor
from .loader import get_data_loader, SWDSBinDataLoader, UserRawDataLoader
from .builder import BuildExecutor, SWDSBinBuildExecutor, UserRawBuildExecutor

@@ -20,11 +26,17 @@
"S3LinkAuth",
"MIMEType",
"LinkType",
"DataField",
"BuildExecutor", # SWDSBinBuildExecutor alias
"UserRawBuildExecutor",
"SWDSBinBuildExecutor",
"MNISTBuildExecutor",
"SWDSBinDataLoader",
"UserRawDataLoader",
"Binary",
"Text",
"Audio",
"Image",
"ClassLabel",
"BoundingBox",
"GrayscaleImage",
"COCOObjectAnnotation",
]
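
Aside: the expanded export list above means user code can pull the new artifact and annotation types straight from the dataset package. A minimal sketch of the new import surface, assuming constructor arguments that are illustrative only and not part of this diff:

from starwhale.api._impl.dataset import (
    Image,
    BoundingBox,
)

with open("sample.png", "rb") as f:  # hypothetical sample file
    img = Image(f.read())  # assumed: the artifact wraps raw bytes

annotations = {
    "label": 3,  # values may be plain data or annotation types
    "bbox": BoundingBox(10, 20, 100, 200),  # assumed argument order: x, y, w, h
}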
172 changes: 60 additions & 112 deletions client/starwhale/api/_impl/dataset/builder.py
@@ -1,6 +1,3 @@
from __future__ import annotations

import sys
import struct
import typing as t
import tempfile
@@ -19,8 +16,10 @@
from starwhale.core.dataset import model
from starwhale.core.dataset.type import (
Link,
Binary,
LinkAuth,
MIMEType,
BaseArtifact,
DatasetSummary,
D_ALIGNMENT_SIZE,
D_FILE_VOLUME_SIZE,
@@ -45,10 +44,7 @@ def __init__(
dataset_name: str,
dataset_version: str,
project_name: str,
data_dir: Path = Path("."),
workdir: Path = Path("./sw_output"),
data_filter: str = "*",
label_filter: str = "*",
alignment_bytes_size: int = D_ALIGNMENT_SIZE,
volume_bytes_size: int = D_FILE_VOLUME_SIZE,
append: bool = False,
@@ -58,10 +54,6 @@ def __init__(
) -> None:
# TODO: add more docstring for args
# TODO: validate group upper and lower?
self.data_dir = data_dir
self.data_filter = data_filter
self.label_filter = label_filter

self.workdir = workdir
self.data_output_dir = workdir / "data"
ensure_dir(self.data_output_dir)
@@ -118,6 +110,10 @@ def __exit__(

print("cleanup done.")

@abstractmethod
def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
raise NotImplementedError

@abstractmethod
def make_swds(self) -> DatasetSummary:
raise NotImplementedError
@@ -128,59 +124,16 @@ def _merge_forked_summary(self, s: DatasetSummary) -> DatasetSummary:
s.rows += _fs.rows
s.unchanged_rows += _fs.rows
s.data_byte_size += _fs.data_byte_size
s.label_byte_size += _fs.label_byte_size
s.annotations = list(set(s.annotations) | set(_fs.annotations))
s.include_link |= _fs.include_link
s.include_user_raw |= _fs.include_user_raw

return s

def _iter_files(
self, filter: str, sort_key: t.Optional[t.Any] = None
) -> t.Generator[Path, None, None]:
_key = sort_key
if _key is not None and not callable(_key):
raise Exception(f"data_sort_func({_key}) is not callable.")

_files = sorted(self.data_dir.rglob(filter), key=_key)
for p in _files:
if not p.is_file():
continue
yield p

def iter_data_files(self) -> t.Generator[Path, None, None]:
return self._iter_files(self.data_filter, self.data_sort_func())

def iter_label_files(self) -> t.Generator[Path, None, None]:
return self._iter_files(self.label_filter, self.label_sort_func())

def iter_all_dataset_slice(self) -> t.Generator[t.Any, None, None]:
for p in self.iter_data_files():
for d in self.iter_data_slice(str(p.absolute())):
yield p, d

def iter_all_label_slice(self) -> t.Generator[t.Any, None, None]:
for p in self.iter_label_files():
for d in self.iter_label_slice(str(p.absolute())):
yield p, d

@abstractmethod
def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]:
raise NotImplementedError

@abstractmethod
def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, None]:
raise NotImplementedError

@property
def data_format_type(self) -> DataFormatType:
raise NotImplementedError

def data_sort_func(self) -> t.Any:
return None

def label_sort_func(self) -> t.Any:
return None


class SWDSBinBuildExecutor(BaseBuildExecutor):
"""
@@ -244,39 +197,45 @@ def make_swds(self) -> DatasetSummary:
ds_copy_candidates[fno] = dwriter_path

increased_rows = 0
total_label_size, total_data_size = 0, 0
total_data_size = 0
dataset_annotations: t.Dict[str, t.Any] = {}

for idx, ((_, data), (_, label)) in enumerate(
zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()),
start=self._forked_last_idx + 1,
for idx, (row_data, row_annotations) in enumerate(
self.iter_item(), start=self._forked_last_idx + 1
):
if isinstance(data, (tuple, list)):
_data_content, _data_mime_type = data
if not isinstance(row_annotations, dict):
raise FormatError(f"annotations({row_annotations}) must be dict type")

_artifact: BaseArtifact
if isinstance(row_data, bytes):
_artifact = Binary(row_data, self.default_data_mime_type)
elif isinstance(row_data, BaseArtifact):
_artifact = row_data
else:
_data_content, _data_mime_type = data, self.default_data_mime_type
raise NoSupportError(f"data type {type(row_data)}")

if not isinstance(_data_content, bytes):
raise FormatError("data content must be bytes type")
if not dataset_annotations:
# TODO: check annotations type and name
dataset_annotations = row_annotations

_bin_section = self._write(dwriter, _data_content)
_bin_section = self._write(dwriter, _artifact.to_bytes())
self.tabular_dataset.put(
TabularDatasetRow(
id=idx,
data_uri=str(fno),
label=label,
data_format=self.data_format_type,
object_store_type=ObjectStoreType.LOCAL,
data_offset=_bin_section.raw_data_offset,
data_size=_bin_section.raw_data_size,
_swds_bin_offset=_bin_section.offset,
_swds_bin_size=_bin_section.size,
data_origin=DataOriginType.NEW,
data_mime_type=_data_mime_type or self.default_data_mime_type,
data_type=_artifact.astype(),
annotations=row_annotations,
)
)

total_data_size += _bin_section.size
total_label_size += sys.getsizeof(label)

wrote_size += _bin_section.size
if wrote_size > self.volume_bytes_size:
@@ -303,10 +262,10 @@ def make_swds(self) -> DatasetSummary:
summary = DatasetSummary(
rows=increased_rows,
increased_rows=increased_rows,
label_byte_size=total_label_size,
data_byte_size=total_data_size,
include_user_raw=False,
include_link=False,
annotations=list(dataset_annotations.keys()),
)
return self._merge_forked_summary(summary)

@@ -335,39 +294,53 @@ def _copy_files(

_dest_path.symlink_to(_obj_path)

# TODO: tune performance scan after put in a second
for row in self.tabular_dataset.scan(*row_pos):
self.tabular_dataset.update(
row_id=row.id, data_uri=map_fno_sign[int(row.data_uri)]
)

def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]:
with Path(path).open() as f:
yield f.read()

def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, None]:
yield Path(path).name


BuildExecutor = SWDSBinBuildExecutor
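
With iter_data_slice/iter_label_slice removed, an SWDS-bin dataset is now described by implementing iter_item() alone, yielding (data, annotations) pairs where data is bytes or a BaseArtifact and annotations is a dict. A hedged sketch of what a subclass could look like; the path layout and annotation keys below are invented for illustration:

import typing as t
from pathlib import Path

from starwhale.api._impl.dataset import SWDSBinBuildExecutor


class PedestrianBuildExecutor(SWDSBinBuildExecutor):
    # Hypothetical subclass: one (data, annotations) pair per image file.
    def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
        root = Path("example/PennFudanPed/data/PNGImages")  # invented layout
        for fpath in sorted(root.glob("*.png")):
            data = fpath.read_bytes()  # raw bytes are wrapped in Binary by make_swds
            annotations = {"file_name": fpath.name}  # any dict is accepted
            yield data, annotations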


class UserRawBuildExecutor(BaseBuildExecutor):
def make_swds(self) -> DatasetSummary:
increased_rows = 0
total_label_size, total_data_size = 0, 0
total_data_size = 0
auth_candidates = {}
include_link = False

map_path_sign: t.Dict[str, t.Tuple[str, Path]] = {}
dataset_annotations: t.Dict[str, t.Any] = {}

for idx, (data, (_, label)) in enumerate(
zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()),
for idx, (row_data, row_annotations) in enumerate(
self.iter_item(),
start=self._forked_last_idx + 1,
):
if isinstance(data, Link):
_remote_link = data
if not isinstance(row_annotations, dict):
raise FormatError(f"annotations({row_annotations}) must be dict type")

if not dataset_annotations:
# TODO: check annotations type and name
dataset_annotations = row_annotations

if not isinstance(row_data, Link):
raise FormatError(f"data({row_data}) must be Link type")

if row_data.with_local_fs_data:
_local_link = row_data
_data_fpath = _local_link.uri
if _data_fpath not in map_path_sign:
map_path_sign[_data_fpath] = DatasetStorage.save_data_file(
_data_fpath
)
data_uri, _ = map_path_sign[_data_fpath]
auth = ""
object_store_type = ObjectStoreType.LOCAL
else:
_remote_link = row_data
data_uri = _remote_link.uri
data_offset, data_size = _remote_link.offset, _remote_link.size
if _remote_link.auth:
auth = _remote_link.auth.name
auth_candidates[
@@ -377,42 +350,23 @@ def make_swds(self) -> DatasetSummary:
auth = ""
object_store_type = ObjectStoreType.REMOTE
include_link = True
data_mime_type = _remote_link.mime_type
elif isinstance(data, (tuple, list)):
_data_fpath, _local_link = data
if _data_fpath not in map_path_sign:
map_path_sign[_data_fpath] = DatasetStorage.save_data_file(
_data_fpath
)

if not isinstance(_local_link, Link):
raise NoSupportError("data only support Link type")

data_mime_type = _local_link.mime_type
data_offset, data_size = _local_link.offset, _local_link.size
data_uri, _ = map_path_sign[_data_fpath]
auth = ""
object_store_type = ObjectStoreType.LOCAL
else:
raise FormatError(f"data({data}) type error, no list, tuple or Link")

self.tabular_dataset.put(
TabularDatasetRow(
id=idx,
data_uri=data_uri,
label=label,
data_format=self.data_format_type,
object_store_type=object_store_type,
data_offset=data_offset,
data_size=data_size,
data_offset=row_data.offset,
data_size=row_data.size,
data_origin=DataOriginType.NEW,
auth_name=auth,
data_mime_type=data_mime_type,
data_type=row_data.astype(),
annotations=row_annotations,
)
)

total_data_size += data_size
total_label_size += sys.getsizeof(label)
total_data_size += row_data.size
increased_rows += 1

self._copy_files(map_path_sign)
@@ -422,10 +376,10 @@ def make_swds(self) -> DatasetSummary:
summary = DatasetSummary(
rows=increased_rows,
increased_rows=increased_rows,
label_byte_size=total_label_size,
data_byte_size=total_data_size,
include_link=include_link,
include_user_raw=True,
annotations=list(dataset_annotations.keys()),
)
return self._merge_forked_summary(summary)

@@ -444,12 +398,6 @@ def _copy_auth(self, auth_candidates: t.Dict[str, LinkAuth]) -> None:
for auth in auth_candidates.values():
f.write("\n".join(auth.dump_env()))

def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]:
yield 0, Path(path).stat().st_size

def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, None]:
yield Path(path).name

@property
def data_format_type(self) -> DataFormatType:
return DataFormatType.USER_RAW
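
For the user-raw path the contract is the mirror image: iter_item() must yield a Link (remote, or local via with_local_fs_data) together with a dict of annotations. A rough sketch, assuming Link accepts uri/offset/size as keyword arguments mirroring the attributes read in make_swds above; the URI, record layout, and annotation key are invented:

import typing as t

from starwhale.api._impl.dataset import Link, UserRawBuildExecutor


class RemoteAudioBuildExecutor(UserRawBuildExecutor):
    # Hypothetical subclass: every row points into one pre-existing remote object.
    def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
        for idx in range(128):
            link = Link(
                uri="s3://my-bucket/speech_command/all.bin",  # invented object
                offset=idx * 32768,  # invented fixed-size record layout
                size=32768,
            )
            yield link, {"command": "yes"}  # any dict is accepted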