Add HuggingfaceDatasetSplitReader for using Huggingface datasets
Added a new reader to allow reading Huggingface datasets as instances.
Mapped a limited set of `datasets.features` to `allennlp.data.fields`.

Verified for selected datasets and/or dataset configurations.

New dependency: "datasets==1.5.0"

Signed-off-by: Abhishek P (VMware) <pab@vmware.com>
Abhishek-P committed Apr 4, 2021
1 parent f82d3f1 commit 78a1487
Showing 4 changed files with 218 additions and 2 deletions.
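
Before the file-by-file diff, a minimal usage sketch of the new reader (illustrative only, not part of the commit; it assumes the `emotion` and `glue` datasets download successfully):

from allennlp.data.dataset_readers.hugging_face_datasets_reader import HuggingfaceDatasetSplitReader

# A dataset that needs no configuration name; the reader is bound to a single split.
reader = HuggingfaceDatasetSplitReader(dataset_name="emotion", split="train")

# A dataset that requires a configuration name.
glue_reader = HuggingfaceDatasetSplitReader(dataset_name="glue", config_name="cola")

# file_path is unused because the split is managed by the `datasets` package,
# so the new tests simply pass None.
for instance in reader.read(None):
    print(instance.fields.keys())
    break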
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Added

- Added `HuggingfaceDatasetSplitReader` for using Huggingface datasets in AllenNLP, with limited feature support
- Ported the following Huggingface `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`.

### Changed
@@ -264,7 +264,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added sampler class and parameter in beam search for non-deterministic search, with several
implementations, including `MultinomialSampler`, `TopKSampler`, `TopPSampler`, and
`GumbelSampler`. Utilizing `GumbelSampler` will give [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).

### Changed

- Pass batch metrics to `BatchCallback`.
170 changes: 170 additions & 0 deletions allennlp/data/dataset_readers/hugging_face_datasets_reader.py
@@ -0,0 +1,170 @@
from typing import Iterable, Optional

from allennlp.data import DatasetReader, Token
from allennlp.data.fields import TextField, LabelField, ListField
from allennlp.data.instance import Instance
from datasets import load_dataset
from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages
from datasets.features import Value


class HuggingfaceDatasetSplitReader(DatasetReader):
    """
    This reader implementation wraps the Huggingface `datasets` package to utilize its dataset
    management functionality and load the information in AllenNLP-friendly formats.
    Note: The reader works with only one split of the dataset, i.e. you would need to create a
    separate reader for each split.
    The following datasets and configurations have been verified and work with this reader:

    Dataset                   Dataset Configuration
    `xnli`                    `ar`
    `xnli`                    `en`
    `xnli`                    `de`
    `xnli`                    `all_languages`
    `glue`                    `cola`
    `glue`                    `mrpc`
    `glue`                    `sst2`
    `glue`                    `qqp`
    `glue`                    `mnli`
    `glue`                    `mnli_matched`
    `universal_dependencies`  `en_lines`
    `universal_dependencies`  `ko_kaist`
    `universal_dependencies`  `af_afribooms`
    `afrikaans_ner_corpus`    `NA`
    `swahili`                 `NA`
    `conll2003`               `NA`
    `dbpedia_14`              `NA`
    `trec`                    `NA`
    `emotion`                 `NA`
    """

    def __init__(
        self,
        max_instances: Optional[int] = None,
        manual_distributed_sharding: bool = False,
        manual_multiprocess_sharding: bool = False,
        serialization_dir: Optional[str] = None,
        dataset_name: Optional[str] = None,
        split: str = "train",
        config_name: Optional[str] = None,
    ) -> None:
        super().__init__(
            max_instances, manual_distributed_sharding, manual_multiprocess_sharding, serialization_dir
        )

        # It would be cleaner to create a separate reader object for each different dataset
        self.dataset = None
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.index = -1

        if config_name:
            self.dataset = load_dataset(self.dataset_name, self.config_name, split=split)
        else:
            self.dataset = load_dataset(self.dataset_name, split=split)

    def _read(self, file_path) -> Iterable[Instance]:
        """
        Reads the dataset and converts each entry to an AllenNLP-friendly instance.
        """
        for entry in self.dataset:
            yield self.text_to_instance(entry)

    def text_to_instance(self, *inputs) -> Instance:
        """
        Takes care of converting a dataset entry into an AllenNLP-friendly instance.
        Currently it is implemented in an ad-hoc manner that only converts the `datasets.features`
        required for the supported datasets; ideally we would design a clean mapping from each
        `datasets.feature` type to an AllenNLP field and then convert them one by one.
        Doing that would provide the best chance of covering the largest possible number of datasets.

        Currently this is how `datasets.features` types are mapped to AllenNLP fields:

        dataset.feature type             allennlp.data.fields
        `ClassLabel`                     `LabelField` in feature name namespace
        `Value.string`                   `TextField` with the value as a single `Token`
        `Value.*`                        `LabelField` with the value as label, in feature name namespace
        `Sequence.string`                `ListField` of `TextField` with each individual string as a token
        `Sequence.ClassLabel`            `ListField` of `LabelField` in feature name namespace
        `Translation`                    `ListField` of 2 `ListField`s (`LabelField`s and `TextField`s)
        `TranslationVariableLanguages`   `ListField` of 2 `ListField`s (`LabelField`s and `TextField`s)
        """

        # features indicate the different information available in each entry from the dataset
        # feature types decide what type of information they are
        # e.g. in a sentiment dataset, an entry could have one feature indicating the text
        # and another indicating the label
        features = self.dataset.features
        fields = dict()

        # TODO we need to support all different datasets features of https://huggingface.co/docs/datasets/features.html
        for feature in features:
            value = features[feature]

            # datasets ClassLabel maps to LabelField
            if isinstance(value, ClassLabel):
                field = LabelField(inputs[0][feature], label_namespace=feature, skip_indexing=True)

            # datasets Value can be of different types
            elif isinstance(value, Value):

                # String value maps to TextField
                if value.dtype == "string":
                    # Since TextField has to be made of Tokens add whole text as a token
                    # TODO Should we use simple heuristics to identify what is token and what is not?
                    field = TextField([Token(inputs[0][feature])])

                else:
                    field = LabelField(inputs[0][feature], label_namespace=feature, skip_indexing=True)

            elif isinstance(value, Sequence):
                # datasets Sequence of strings maps to ListField of TextField
                if value.feature.dtype == "string":
                    field_list = list()
                    for item in inputs[0][feature]:
                        item_field = TextField([Token(item)])
                        field_list.append(item_field)
                    if len(field_list) == 0:
                        continue
                    field = ListField(field_list)

                # datasets Sequence of ClassLabel maps to ListField of LabelField
                elif isinstance(value.feature, ClassLabel):
                    field_list = list()
                    for item in inputs[0][feature]:
                        item_field = LabelField(label=item, label_namespace=feature, skip_indexing=True)
                        field_list.append(item_field)
                    if len(field_list) == 0:
                        continue
                    field = ListField(field_list)

            # datasets Translation cannot be mapped directly, but its dict structure can be mapped to a ListField of 2 ListFields
            elif isinstance(value, Translation):
                if value.dtype == "dict":
                    input_dict = inputs[0][feature]
                    langs = list(input_dict.keys())
                    field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
                    langs_field = ListField(field_langs)
                    texts = list()
                    for lang in langs:
                        texts.append(TextField([Token(input_dict[lang])]))
                    field = ListField([langs_field, ListField(texts)])

            # TranslationVariableLanguages is functionally a pair of lists and hence mapped to a ListField of 2 ListFields
            elif isinstance(value, TranslationVariableLanguages):
                # Although its dtype is indicated as a dict, it is made up of a pair of lists
                if value.dtype == "dict":
                    input_dict = inputs[0][feature]
                    langs = input_dict["language"]
                    field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
                    langs_field = ListField(field_langs)
                    texts = list()
                    # Index positionally so that repeated languages still pick up the right translation
                    for index in range(len(langs)):
                        texts.append(TextField([Token(input_dict["translation"][index])]))
                    field = ListField([langs_field, ListField(texts)])

            else:
                raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")

            fields[feature] = field

        return Instance(fields)
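
To make the `Sequence` mappings above concrete, here is a sketch (again illustrative, not part of the commit) of what the reader yields for a token-level dataset such as `conll2003`, whose entries typically carry a `tokens` feature (`Sequence` of strings) and tag features such as `ner_tags` (`Sequence` of `ClassLabel`):

from allennlp.data.dataset_readers.hugging_face_datasets_reader import HuggingfaceDatasetSplitReader

reader = HuggingfaceDatasetSplitReader(dataset_name="conll2003", split="train")
instance = next(iter(reader.read(None)))

# "tokens" becomes a ListField of single-token TextFields;
# "ner_tags" (and the other tag columns) become ListFields of LabelFields,
# each label living in the namespace of its feature name.
print(type(instance["tokens"]).__name__)    # ListField
print(type(instance["ner_tags"]).__name__)  # ListField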
1 change: 1 addition & 0 deletions setup.py
@@ -73,6 +73,7 @@
"lmdb",
"more-itertools",
"wandb>=0.10.0,<0.11.0",
"datasets==1.5.0"
],
entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]},
include_package_data=True,
45 changes: 45 additions & 0 deletions tests/data/dataset_readers/huggingface_datasets_test.py
@@ -0,0 +1,45 @@
import pytest

from allennlp.data.dataset_readers.conll2003 import Conll2003DatasetReader
from allennlp.data.dataset_readers.hugging_face_datasets_reader import HuggingfaceDatasetSplitReader
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import ensure_list
from allennlp.common.testing import AllenNlpTestCase
import logging

logger = logging.getLogger(__name__)


# TODO these unit tests actually download the datasets and will be very slow
class HuggingfaceDatasetSplitReaderTest:

    SUPPORTED_DATASETS_WITHOUT_CONFIG = [
        "afrikaans_ner_corpus", "dbpedia_14", "trec", "swahili", "conll2003", "emotion"
    ]

    """
    Running the tests for supported datasets which do not require a config name to be specified
    """

    @pytest.mark.parametrize("dataset", SUPPORTED_DATASETS_WITHOUT_CONFIG)
    def test_read_for_datasets_without_config(self, dataset):
        huggingface_reader = HuggingfaceDatasetSplitReader(dataset_name=dataset)
        instances = list(huggingface_reader.read(None))
        assert len(instances) == len(huggingface_reader.dataset)

    # Not testing for all configurations, only some
    SUPPORTED_DATASET_CONFIGURATION = (
        ("glue", "cola"),
        ("universal_dependencies", "af_afribooms"),
        ("xnli", "all_languages"),
    )

    """
    Running the tests for supported datasets which require a config name to be specified
    """

    @pytest.mark.parametrize("dataset, config", SUPPORTED_DATASET_CONFIGURATION)
    def test_read_for_datasets_requiring_config(self, dataset, config):
        huggingface_reader = HuggingfaceDatasetSplitReader(dataset_name=dataset, config_name=config)
        instances = list(huggingface_reader.read(None))
        assert len(instances) == len(huggingface_reader.dataset)