Add a HuggingFace dataset loader for Space (#16)
Zhou Fang authored Dec 27, 2023
1 parent 8430a33 commit 3ddac64
Showing 8 changed files with 110 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-ci.yml
@@ -35,7 +35,7 @@ jobs:
- name: Install test dependencies
run: |
python -m pip install --upgrade pip
- pip install mypy pylint pytest mock
+ pip install mypy pylint pytest mock datasets
- name: Install runtime dependencies and Space
working-directory: ./python
run: |
7 changes: 4 additions & 3 deletions python/pyproject.toml
@@ -53,8 +53,9 @@ disable = [
[tool.pylint.MAIN]
ignore = "space/core/proto"
ignored-modules = [
"space.core.proto",
"google.protobuf",
"substrait",
"array_record",
"datasets",
"google.protobuf",
"space.core.proto",
"substrait"
]
5 changes: 5 additions & 0 deletions python/src/space/core/datasets.py
@@ -61,3 +61,8 @@ def serializer(self) -> DictSerializer:
  def local(self) -> LocalRunner:
    """Get a runner that runs operations locally."""
    return LocalRunner(self._storage)
+
+  def index_files(self) -> List[str]:
+    """A list of full paths of index files."""
+    data_files = self._storage.data_files()
+    return [self._storage.full_path(f.path) for f in data_files.index_files]
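A minimal usage sketch of the new index_files() accessor, assuming a Space dataset already exists; the location string and variable names below are placeholders, not part of this change:

import space

# Load an existing Space dataset and print the full paths of its index
# (Parquet) files via the accessor added above.
ds = space.Dataset.load("/path/to/dataset")
for path in ds.index_files():
  print(path)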
13 changes: 13 additions & 0 deletions python/src/space/huggingface/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
28 changes: 28 additions & 0 deletions python/src/space/huggingface/load.py
@@ -0,0 +1,28 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""HuggingFace integration with Space."""

from datasets import dataset_dict, load_dataset # type: ignore[import-untyped]

import space


def load_space_dataset(location: str) -> dataset_dict.DatasetDict:
  """Load a HuggingFace dataset from a Space dataset.

  TODO: support version (snapshot), column selection, and filters.
  """
  space_ds = space.Dataset.load(location)
  return load_dataset("parquet", data_files={"train": space_ds.index_files()})
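A hedged usage sketch of the loader defined above, based only on this diff; the dataset location is a placeholder, and the "train" split name mirrors the data_files mapping in load_space_dataset:

from space.huggingface.load import load_space_dataset

# Read an existing Space dataset back as a HuggingFace DatasetDict; all of
# its index files are exposed as a single "train" split.
hf_datasets = load_space_dataset("/path/to/dataset")
print(hf_datasets["train"].num_rows)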
13 changes: 13 additions & 0 deletions python/src/space/tf/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
8 changes: 4 additions & 4 deletions python/tests/core/loaders/test_parquet.py
@@ -33,7 +33,7 @@ def test_append_parquet(self, tmp_path):
primary_keys=["int64"],
record_fields=[])

- dummy_data = [{
+ input_data = [{
"int64": [1, 2, 3],
"float64": [0.1, 0.2, 0.3],
"bool": [True, False, False],
@@ -49,9 +49,9 @@ def test_append_parquet(self, tmp_path):
input_dir = tmp_path / "parquet"
input_dir.mkdir(parents=True)
write_parquet_file(str(input_dir / "file0.parquet"), schema,
- [pa.Table.from_pydict(dummy_data[0])])
+ [pa.Table.from_pydict(input_data[0])])
write_parquet_file(str(input_dir / "file1.parquet"), schema,
- [pa.Table.from_pydict(dummy_data[1])])
+ [pa.Table.from_pydict(input_data[1])])

runner = ds.local()
response = runner.append_parquet(input_dir)
@@ -64,5 +64,5 @@ def test_append_parquet(self, tmp_path):
index_data = pa.concat_tables(
(list(runner.read()))).combine_chunks().sort_by("int64")
assert index_data == pa.concat_tables([
- pa.Table.from_pydict(d) for d in dummy_data
+ pa.Table.from_pydict(d) for d in input_data
]).combine_chunks().sort_by("int64")
42 changes: 42 additions & 0 deletions python/tests/huggingface/test_load.py
@@ -0,0 +1,42 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pyarrow as pa

from space import Dataset
from space.huggingface.load import load_space_dataset


def test_load_space_dataset(tmp_path):
  schema = pa.schema([("int64", pa.int64()), ("string", pa.string())])
  location = str(tmp_path / "dataset")
  ds = Dataset.create(location,
                      schema,
                      primary_keys=["int64"],
                      record_fields=[])

  input_data = [{
      "int64": [1, 2, 3],
      "string": ["a", "b", "c"]
  }, {
      "int64": [0, 10],
      "string": ["A", "z"]
  }]

  runner = ds.local()
  for data in input_data:
    runner.append(data)

  huggingface_ds = load_space_dataset(location)
  assert huggingface_ds["train"].data == runner.read_all()
