GH-457: classification data loader
aakbik committed May 13, 2019
1 parent 558ad40 commit 6042d95
Showing 2 changed files with 156 additions and 14 deletions.
152 changes: 147 additions & 5 deletions flair/datasets.py
@@ -149,6 +149,85 @@ def __init__(
        )


class ClassificationCorpus(TaggedCorpus):
    def __init__(
        self,
        data_folder: Union[str, Path],
        train_file=None,
        test_file=None,
        dev_file=None,
        use_tokenizer: bool = True,
        max_tokens_per_doc=-1,
    ):
        """
        Instantiates a TaggedCorpus from text classification-formatted task data.
        :param data_folder: base folder with the task data
        :param train_file: the name of the train file
        :param test_file: the name of the test file
        :param dev_file: the name of the dev file; if None, dev data is sampled from train
        :param use_tokenizer: whether to tokenize the texts when creating Sentence objects
        :param max_tokens_per_doc: take at most this many tokens per document; if -1, documents are kept as is
        :return: a TaggedCorpus with annotated train, dev and test data
        """

        if type(data_folder) == str:
            data_folder: Path = Path(data_folder)

        if train_file is not None:
            train_file = data_folder / train_file
        if test_file is not None:
            test_file = data_folder / test_file
        if dev_file is not None:
            dev_file = data_folder / dev_file

        # automatically identify train / test / dev files
        if train_file is None:
            for file in data_folder.iterdir():
                file_name = file.name
                if "train" in file_name:
                    train_file = file
                if "test" in file_name:
                    test_file = file
                if "dev" in file_name:
                    dev_file = file
                if "testa" in file_name:
                    dev_file = file
                if "testb" in file_name:
                    test_file = file

        log.info("Reading data from {}".format(data_folder))
        log.info("Train: {}".format(train_file))
        log.info("Dev: {}".format(dev_file))
        log.info("Test: {}".format(test_file))

        train: Dataset = ClassificationDataset(
            train_file,
            use_tokenizer=use_tokenizer,
            max_tokens_per_doc=max_tokens_per_doc,
        )
        test: Dataset = ClassificationDataset(
            test_file,
            use_tokenizer=use_tokenizer,
            max_tokens_per_doc=max_tokens_per_doc,
        )

        if dev_file is not None:
            dev: Dataset = ClassificationDataset(
                dev_file,
                use_tokenizer=use_tokenizer,
                max_tokens_per_doc=max_tokens_per_doc,
            )
        else:
            train_length = len(train)
            dev_size: int = round(train_length / 10)
            splits = random_split(train, [train_length - dev_size, dev_size])
            train = splits[0]
            dev = splits[1]

        super(ClassificationCorpus, self).__init__(
            train, dev, test, name=data_folder.name
        )


class ColumnDataset(Dataset):
    def __init__(
        self,
@@ -219,9 +298,6 @@ def __len__(self):
    def __getitem__(self, index: int = 0) -> Sentence:
        return self.sentences[index]

    def __iter__(self):
        return iter(self.sentences)


class UniversalDependenciesDataset(Dataset):
    def __init__(self, path_to_conll_file: Path):
@@ -274,8 +350,60 @@ def __len__(self):
    def __getitem__(self, index: int = 0) -> Sentence:
        return self.sentences[index]

    def __iter__(self):
        return iter(self.sentences)

class ClassificationDataset(Dataset):
    def __init__(
        self, path_to_file: Union[str, Path], max_tokens_per_doc=-1, use_tokenizer=True
    ):
        """
        Reads a data file for text classification. The file should contain one document/text per line.
        Each line should have the following format:
        __label__<class_name> <text>
        If you have a multi-class task, you can have as many labels as you want at the beginning of the line, e.g.,
        __label__<class_name_1> __label__<class_name_2> <text>
        :param path_to_file: the path to the data file
        :param max_tokens_per_doc: take at most this many tokens per document; if -1, documents are kept as is
        :param use_tokenizer: whether to tokenize the texts when creating Sentence objects
        :return: list of sentences
        """
        if type(path_to_file) == str:
            path_to_file: Path = Path(path_to_file)

        assert path_to_file.exists()

        label_prefix = "__label__"
        self.sentences = []

        with open(str(path_to_file), encoding="utf-8") as f:
            for line in f:
                words = line.split()

                labels = []
                l_len = 0

                # collect leading __label__ tokens and track their character length
                for i in range(len(words)):
                    if words[i].startswith(label_prefix):
                        l_len += len(words[i]) + 1
                        label = words[i].replace(label_prefix, "")
                        labels.append(label)
                    else:
                        break

                # the remainder of the line is the document text
                text = line[l_len:].strip()

                if text and labels:
                    sentence = Sentence(
                        text, labels=labels, use_tokenizer=use_tokenizer
                    )
                    if len(sentence) > max_tokens_per_doc and max_tokens_per_doc > 0:
                        sentence.tokens = sentence.tokens[:max_tokens_per_doc]
                    if len(sentence.tokens) > 0:
                        self.sentences.append(sentence)

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index: int = 0) -> Sentence:
        return self.sentences[index]


class CONLL_03(ColumnCorpus):
@@ -292,6 +420,20 @@ def __init__(self, base_path=None, tag_to_biloes: str = "ner"):
        )


class GERMEVAL(ColumnCorpus):
    def __init__(self, base_path=None, tag_to_biloes: str = "ner"):
        columns = {1: "text", 2: "ner"}

        # default dataset folder is the cache root
        if not base_path:
            base_path = Path(flair.cache_root) / "datasets"
        data_folder = base_path / "germeval"

        super(GERMEVAL, self).__init__(
            data_folder, columns, tag_to_biloes=tag_to_biloes
        )


class CONLL_2000(ColumnCorpus):
    def __init__(self, base_path=None, tag_to_biloes: str = "np"):
        columns = {0: "text", 1: "pos", 2: "np"}
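For orientation, here is a minimal usage sketch (not part of this commit) of the ClassificationCorpus added above. The folder layout, file names and labels are illustrative assumptions; the data files are assumed to use the FastText-style __label__ format described in the ClassificationDataset docstring.

# Illustrative only: paths, file names and labels are assumptions, not part of the commit.
# Each data file holds one document per line, e.g.:
#   __label__pos this movie was great
#   __label__neg __label__long the film dragged on forever
from pathlib import Path

import flair.datasets

corpus = flair.datasets.ClassificationCorpus(
    Path("resources/tasks/imdb"),  # hypothetical data folder
    train_file="train.txt",        # optional: files are auto-detected by name if omitted
    dev_file="dev.txt",            # optional: if no dev file, 10% of train is split off
    test_file="test.txt",
    max_tokens_per_doc=512,        # truncate long documents; -1 keeps them whole
)

print(len(corpus.train), len(corpus.dev), len(corpus.test))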
18 changes: 9 additions & 9 deletions tests/test_data_fetchers.py
@@ -9,7 +9,7 @@

def test_load_imdb_data(tasks_base_path):
    # get training, test and dev data
    corpus = NLPTaskDataFetcher.load_corpus("imdb", tasks_base_path)
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")

    assert len(corpus.train) == 5
    assert len(corpus.dev) == 5
@@ -26,8 +26,11 @@ def test_load_ag_news_data(tasks_base_path):


def test_load_sequence_labeling_data(tasks_base_path):

    # get training, test and dev data
    corpus = NLPTaskDataFetcher.load_corpus(NLPTask.FASHION, tasks_base_path)
    corpus = flair.datasets.ColumnCorpus(
        tasks_base_path / "fashion", column_format={0: "text", 2: "ner"}
    )

    assert len(corpus.train) == 6
    assert len(corpus.dev) == 1
@@ -36,7 +39,7 @@ def test_load_sequence_labeling_data(tasks_base_path):

def test_load_germeval_data(tasks_base_path):
    # get training, test and dev data
    corpus = NLPTaskDataFetcher.load_corpus(NLPTask.GERMEVAL, tasks_base_path)
    corpus = flair.datasets.GERMEVAL(tasks_base_path)

    assert len(corpus.train) == 2
    assert len(corpus.dev) == 1
@@ -57,9 +60,6 @@ def test_load_no_dev_data(tasks_base_path):
    corpus = flair.datasets.ColumnCorpus(
        tasks_base_path / "fashion_nodev", column_format={0: "text", 2: "ner"}
    )
    # corpus = NLPTaskDataFetcher.load_column_corpus(
    #     tasks_base_path / "fashion_nodev", {0: "text", 2: "ner"}
    # )

    assert len(corpus.train) == 5
    assert len(corpus.dev) == 1
@@ -68,9 +68,9 @@

def test_load_no_dev_data_explicit(tasks_base_path):
    # get training, test and dev data
    corpus = NLPTaskDataFetcher.load_column_corpus(
    corpus = flair.datasets.ColumnCorpus(
        tasks_base_path / "fashion_nodev",
        {0: "text", 2: "ner"},
        column_format={0: "text", 2: "ner"},
        train_file="train.tsv",
        test_file="test.tsv",
    )
@@ -93,7 +93,7 @@ def test_multi_corpus(tasks_base_path):

def test_download_load_data(tasks_base_path):
    # get training, test and dev data for full English UD corpus from web
    corpus = NLPTaskDataFetcher.load_corpus(NLPTask.UD_ENGLISH)
    corpus = flair.datasets.UD_ENGLISH()

    assert len(corpus.train) == 12543
    assert len(corpus.dev) == 2002
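As a side note, here is a small standalone sketch (not part of the commit) that mirrors the label-prefix parsing in the ClassificationDataset added above, to show how labels are separated from the document text. The sample line is invented for illustration.

# Illustrative only: mirrors the __label__ parsing logic of ClassificationDataset.
label_prefix = "__label__"
line = "__label__pos __label__short a fine little film\n"  # invented sample line

words = line.split()
labels = []
offset = 0  # character length of the label section, including trailing spaces

for word in words:
    if word.startswith(label_prefix):
        offset += len(word) + 1
        labels.append(word.replace(label_prefix, ""))
    else:
        break

text = line[offset:].strip()
# labels == ["pos", "short"], text == "a fine little film"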
