prepare_data.py
from functools import partial
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling


def tokenize_func(example, tokenizer, content):
    # Tokenize the text stored in the given column of a batch of examples.
    return tokenizer(example[content])


def group_text(examples, context_length):
    # Concatenate all tokenized sequences in the batch, then split the result
    # into fixed-size chunks of context_length tokens.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder; we could pad instead if the model supported it.
    # Customize this part to your needs.
    total_length = (total_length // context_length) * context_length
    # Split into chunks of context_length.
    result = {
        k: [t[i: i + context_length] for i in range(0, total_length, context_length)]
        for k, t in concatenated_examples.items()
    }
    return result
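
# Illustrative example (not part of the original file) of what group_text does:
#     >>> group_text({"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]]}, context_length=4)
#     {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]}
# The lists are concatenated to [1, 2, 3, 4, 5, 6, 7, 8] and split into
# context_length-sized chunks; a trailing remainder shorter than
# context_length would be dropped.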


def prepare_data(dataset_name, tokenizer, context_length, dataset_cache_dir=None):
    # Dispatch to the dataset-specific preparation routine.
    if dataset_name == "wikitext":
        train_dataset, val_dataset, tokenized_test, data_collator = prep_wikitext_2_raw_v1(context_length, tokenizer,
                                                                                            dataset_cache_dir)
    elif dataset_name == "ptb":
        train_dataset, val_dataset, tokenized_test, data_collator = prep_ptb(context_length, tokenizer,
                                                                             dataset_cache_dir)
    elif dataset_name == "c4":
        train_dataset, val_dataset, tokenized_test, data_collator = prep_c4(context_length, tokenizer,
                                                                            dataset_cache_dir)
    elif dataset_name == "alpaca":
        train_dataset, val_dataset, tokenized_test, data_collator = prep_alpaca(context_length, tokenizer,
                                                                                dataset_cache_dir)
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return train_dataset, val_dataset, tokenized_test, data_collator


def prep_wikitext_2_raw_v1(context_length, tokenizer, dataset_cache_dir=None):
    print("load wikitext dataset")
    # load_dataset expects the cache location via `cache_dir`.
    train_raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train",
                                     cache_dir=dataset_cache_dir)
    val_raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation",
                                   cache_dir=dataset_cache_dir)
    test_raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", cache_dir=dataset_cache_dir)
    # Tokenize the raw text, then pack it into fixed-length chunks for training.
    func = partial(tokenize_func, tokenizer=tokenizer, content="text")
    tokenized_train = train_raw_dataset.map(func, num_proc=4, batched=True, remove_columns="text")
    tokenized_val = val_raw_dataset.map(func, num_proc=4, batched=True, remove_columns="text")
    # The test split is tokenized as one long sequence for perplexity evaluation.
    tokenized_test = tokenizer("\n\n".join(test_raw_dataset["text"]), return_tensors="pt")
    func = partial(group_text, context_length=context_length)
    train_dataset = tokenized_train.map(func, num_proc=4, batch_size=1024, batched=True)
    val_dataset = tokenized_val.map(func, num_proc=4, batch_size=1024, batched=True)
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt")
    return train_dataset, val_dataset, tokenized_test, data_collator


def prep_ptb(context_length, tokenizer, dataset_cache_dir=None):
    print("load ptb dataset")
    train_raw_dataset = load_dataset("ptb_text_only", "penn_treebank", split='train',
                                     cache_dir=dataset_cache_dir)
    val_raw_dataset = load_dataset("ptb_text_only", "penn_treebank", split='validation',
                                   cache_dir=dataset_cache_dir)
    test_raw_dataset = load_dataset("ptb_text_only", "penn_treebank", split='test',
                                    cache_dir=dataset_cache_dir)
    func = partial(tokenize_func, tokenizer=tokenizer, content="sentence")
    tokenized_train_data = train_raw_dataset.map(func, num_proc=4, batched=True, remove_columns="sentence")
    tokenized_val_data = val_raw_dataset.map(func, num_proc=4, batched=True, remove_columns="sentence")
    tokenized_test_data = tokenizer("\n\n".join(test_raw_dataset['sentence']), return_tensors="pt")
    func = partial(group_text, context_length=context_length)
    train_dataset = tokenized_train_data.map(func, num_proc=4, batch_size=1024, batched=True)
    val_dataset = tokenized_val_data.map(func, num_proc=4, batch_size=1024, batched=True)
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt")
    return train_dataset, val_dataset, tokenized_test_data, data_collator


def prep_c4(context_length, tokenizer, dataset_cache_dir=None):
    print("load C4 dataset")
    # C4 is loaded from local JSON files and used only for evaluation here:
    # no packed train/validation datasets are built.
    train_raw_dataset = load_dataset("json", data_dir=dataset_cache_dir, data_files="c4-train.json")['train']
    val_raw_dataset = load_dataset("json", data_dir=dataset_cache_dir, data_files="c4-validation.json")['train']
    test_raw_dataset = val_raw_dataset
    tokenized_test_data = tokenizer("\n\n".join(test_raw_dataset['text'][0:2000]), return_tensors="pt")
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt")
    train_dataset = None
    val_dataset = None
    return train_dataset, val_dataset, tokenized_test_data, data_collator


def prep_alpaca(context_length, tokenizer, dataset_cache_dir=None):
    print("load Alpaca dataset")
    train_raw_dataset = load_dataset("tatsu-lab/alpaca", split='train',
                                     cache_dir=dataset_cache_dir)
    func = partial(tokenize_func, tokenizer=tokenizer, content="text")
    tokenized_train_data = train_raw_dataset.map(func, num_proc=1, batched=True,
                                                 remove_columns=["text", "instruction", "input", "output"])
    func = partial(group_text, context_length=context_length)
    train_dataset = tokenized_train_data.map(func, num_proc=1, batch_size=1024, batched=True)
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt")
    # Alpaca is used for training only: no validation or test sets are returned.
    return train_dataset, None, None, data_collator
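

# Usage sketch (not part of the original file): a minimal example of calling
# prepare_data with a Hugging Face tokenizer. The "gpt2" checkpoint and the
# context_length of 1024 are illustrative assumptions, not fixed by this script.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical checkpoint
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; the collator may need one
    train_dataset, val_dataset, tokenized_test, data_collator = prepare_data(
        "wikitext", tokenizer, context_length=1024
    )
    print(train_dataset)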