# babylm_dataset.py
from pathlib import Path
from random import randrange

import torch
from torch.utils.data import Dataset


class BabylmDataset(Dataset):
"""
Example usage:
tokenizer = GPT2TokenizerFast(tokenizer_file= str(tokenizer_path))
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.pad_token = "<pad>"
train_dataset = BabylmDataset(PATH / "data/babylm_10M", SEQ_LENGTH, tokenizer=tokenizer, random_chunk=True)
full_eval_dataset = BabylmDataset(PATH / "data/babylm_dev", SEQ_LENGTH, tokenizer=tokenizer, offset=0)
eval_indices = sample(range(len(full_eval_dataset)), EVAL_SAMPLES)
eval_dataset = Subset(full_eval_dataset, eval_indices)
"""
    def __init__(self, data_dir: str, seq_length: int, tokenizer, offset: int = 0,
                 random_chunk: bool = False, use_only=None):
        self.seq_length = seq_length
        self.offset = offset
        self.tokenizer = tokenizer
        self.random_chunk = random_chunk
        self.use_only = use_only if use_only else []

        # Cache the tokenized corpus on disk; the filename encodes the tokenizer
        # class, vocab size, and file filter so different configs don't collide.
        tokenizer_name = tokenizer.__class__.__name__
        use_only_name = "_".join(self.use_only) if self.use_only else "all"
        tokenized_file = Path(data_dir) / f"tokenized_{tokenizer_name}_{tokenizer.vocab_size}_{use_only_name}.pt"

        if tokenized_file.exists():
            print(f"Loading data from {tokenized_file}")
            self.data = torch.load(tokenized_file)
        else:
            # Tokenize every .train/.dev file under data_dir into one flat stream.
            data = []
            src_files = [str(f) for f in Path(data_dir).glob("**/*")
                         if f.is_file() and not f.name.endswith(".DS_Store")
                         and f.suffix in [".train", ".dev"]]
            for src_file in src_files:
                # When use_only is set, keep only files whose path contains one of the names.
                if self.use_only and not any(name in src_file for name in self.use_only):
                    continue
                text = Path(src_file).read_text(encoding="utf-8")
                encoded = self.tokenizer.encode(text)
                print("🔥", src_file, "len:", len(encoded))
                data.extend(encoded)
            self.data = torch.tensor(data)

            # Save the tokenized data so subsequent runs skip tokenization.
            print(f"Saving data to {tokenized_file}")
            torch.save(self.data, tokenized_file)

    def __len__(self):
        if self.random_chunk:
            # Reserve one chunk of headroom so the random offset of up to
            # seq_length - 1 in __getitem__ cannot read past the end of the data.
            return len(self.data) // self.seq_length - 1
        else:
            return (len(self.data) - self.offset) // self.seq_length

    def __getitem__(self, i):
        if self.random_chunk:
            # Sample a fresh offset in [0, seq_length) on every access, so the
            # chunk boundaries shift between epochs.
            offset = randrange(self.seq_length)
            return self.data[i * self.seq_length + offset:(i + 1) * self.seq_length + offset]
        else:
            return self.data[i * self.seq_length + self.offset:(i + 1) * self.seq_length + self.offset]
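

# A minimal usage sketch, not part of the original file. It assumes a
# `data/babylm_10M` directory of *.train files and a tokenizer saved at
# `tokenizer.json`; both paths and SEQ_LENGTH are hypothetical placeholders.
if __name__ == "__main__":
    from transformers import GPT2TokenizerFast

    SEQ_LENGTH = 128  # assumed context length for this sketch

    tokenizer = GPT2TokenizerFast(tokenizer_file="tokenizer.json")
    tokenizer.bos_token = "<s>"
    tokenizer.eos_token = "</s>"
    tokenizer.pad_token = "<pad>"

    dataset = BabylmDataset("data/babylm_10M", SEQ_LENGTH,
                            tokenizer=tokenizer, random_chunk=True)
    print(f"{len(dataset)} chunks of {SEQ_LENGTH} tokens")
    print(dataset[0].shape)  # torch.Size([SEQ_LENGTH])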