Allow loss masking for defined spans of characters #113

Merged on Feb 7, 2025 (51 commits)
Commits (51)
9367fcd
convert character spans to token spans
sohamparikh Jan 14, 2025
515dcb5
handle null spans
sohamparikh Jan 14, 2025
3457ba2
handle spans in data iterator, fix test
sohamparikh Jan 15, 2025
c7373b9
bump dataset version
sohamparikh Jan 16, 2025
0699e0f
create a document class
sohamparikh Jan 16, 2025
419acd7
make loss masking work for prepare and training
sohamparikh Jan 24, 2025
acad1e4
merge main
sohamparikh Jan 24, 2025
daa2ad7
bos and eos options for tokenizer
sohamparikh Jan 25, 2025
bb175bf
loss masking for triton cross entropy
sohamparikh Jan 27, 2025
0e7ad8b
fix random data tests
sohamparikh Jan 28, 2025
989a8f8
revert precommit versions
sohamparikh Jan 28, 2025
9633f88
fix memmap dataset test
sohamparikh Jan 28, 2025
4f955ff
fix remaining dataset tests
sohamparikh Jan 28, 2025
70e40e8
Merge branch 'main' into soham/loss-masking-spans
sohamparikh Jan 28, 2025
1ac5052
compose tests
sohamparikh Jan 28, 2025
aebb5a0
handle special tokens from config
sohamparikh Jan 28, 2025
d8e3ae1
fix fim to handle bos and eos
sohamparikh Jan 28, 2025
a887dd6
address review comments
sohamparikh Jan 28, 2025
40a80f6
fix memmap tests
sohamparikh Jan 28, 2025
e908303
fix fim tests
sohamparikh Jan 28, 2025
20ffae8
special tokens mode -> sequence delimiters
sohamparikh Jan 29, 2025
753e731
GPTDataBatch -> GPTBatch
sohamparikh Jan 29, 2025
cce0701
GPTMemmapDocument, GPTMemmapSample -> GPTSample
sohamparikh Jan 29, 2025
0583dec
make loss masking opt-in in cross-entropy
sohamparikh Jan 30, 2025
7c40bf2
make spans opt-in during prepare
sohamparikh Jan 30, 2025
1998b9f
make spans opt-in for train
sohamparikh Jan 30, 2025
913a9d3
revert tests and random dataset
sohamparikh Jan 30, 2025
23dc7eb
partially fix existing tests
sohamparikh Jan 30, 2025
6712d5e
fix existing tests
sohamparikh Jan 30, 2025
6802627
remove get_span_sizes
sohamparikh Jan 30, 2025
fbf5157
typing for custom model
sohamparikh Jan 30, 2025
4bcb488
fix memmap tests
sohamparikh Jan 30, 2025
8494b6a
test for spans
sohamparikh Jan 30, 2025
246456c
Merge branch 'main' into soham/loss-masking-spans
sohamparikh Jan 30, 2025
769d466
fix triton cross-entropy
sohamparikh Jan 30, 2025
599e073
Merge remote-tracking branch 'refs/remotes/origin/soham/loss-masking-…
sohamparikh Jan 30, 2025
5cbb342
Update data.py
sohamparikh Jan 31, 2025
a04cc94
fix loss mask in cross entropy
sohamparikh Feb 5, 2025
2f2495d
review comments
sohamparikh Feb 5, 2025
f95f67c
Merge branch 'main' into soham/loss-masking-spans
sohamparikh Feb 5, 2025
348a17c
fix fused cross-entropy
sohamparikh Feb 5, 2025
06f81c7
cleaner collating
sohamparikh Feb 5, 2025
15b9033
fix collate
sohamparikh Feb 5, 2025
5222a5e
merge conflicts, change fim tests
sohamparikh Feb 6, 2025
b277c8d
run pre-commit on all
sohamparikh Feb 6, 2025
a72c813
misc
jlamypoirier Feb 7, 2025
a13bf2d
misc
jlamypoirier Feb 7, 2025
c1bbadf
misc
jlamypoirier Feb 7, 2025
69a59c4
Simplfy tests
jlamypoirier Feb 7, 2025
3aa735b
fix
jlamypoirier Feb 7, 2025
0821abe
fix loss mask
sohamparikh Feb 7, 2025
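The cross-entropy and data-loader changes referenced in the commits above are not part of the diff excerpt below, which only covers the three selected commits. As rough orientation, here is a minimal sketch of the end result the PR is after, turning per-document token spans into a loss mask; the function name and the reduction shown are illustrative assumptions, not code from this PR.

import numpy as np

def build_loss_mask(num_tokens: int, token_spans: np.ndarray) -> np.ndarray:
    # Spans are inclusive [start, end] pairs of token indices; masked tokens get False.
    loss_mask = np.ones(num_tokens, dtype=bool)
    for start, end in token_spans:
        loss_mask[start : end + 1] = False
    return loss_mask

# Example: mask tokens 2..4 of a 6-token document, then drop them from the loss,
# e.g. loss = (per_token_loss * mask).sum() / mask.sum()
mask = build_loss_mask(6, np.array([[2, 4]], dtype=np.int32))
assert mask.tolist() == [True, True, False, False, False, True]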
Files changed (showing changes from 3 of the 51 commits)
47 changes: 40 additions & 7 deletions fast_llm/data/dataset/gpt/memmap.py
@@ -48,6 +48,25 @@ def _init(self, name: str, prefix: pathlib.Path | str):
            offset=offset + self._document_sizes.nbytes,
        )

        self._num_spans = np.frombuffer(
            self._index_bin_buffer,
            dtype=np.int32,
            count=self._num_documents,
            offset=offset + self._document_sizes.nbytes + self._pointers.nbytes,
        )
        spans = []
        offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes
        for n_spans in self._num_spans:
            span = np.frombuffer(
                self._index_bin_buffer,
                dtype=np.int32,
                count=n_spans * 2,
                offset=offset,
            ).reshape(-1, 2)
            spans.append(span)
            offset += span.nbytes
        self._spans = spans

        self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

@@ -64,11 +83,14 @@ def __del__(self):
        del self._index_bin_buffer_mmap

    def get(self, idx, offset=0, length=None):
-       return np.frombuffer(
-           self._bin_buffer,
-           dtype=self._dtype,
-           count=self._document_sizes[idx] - offset if length is None else length,
-           offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize,
+       return (
+           np.frombuffer(
+               self._bin_buffer,
+               dtype=self._dtype,
+               count=self._document_sizes[idx] - offset if length is None else length,
+               offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize,
+           ),
+           self._spans[idx],
        )

    @property
@@ -92,20 +114,23 @@ def get_document_sizes(self) -> "np.ndarray":
        return self._document_sizes

    @classmethod
-   def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[np.ndarray]):
+   def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[tuple[np.ndarray, np.ndarray]]):
        # Initialize metadata
        dtype = None
        num_documents = 0
        lengths = []
        pointers = []
        offset = 0
        # number of spans for each document
        num_spans = []
        spans = []

        prefix = pathlib.Path(prefix)
        prefix.parent.mkdir(parents=True, exist_ok=True)

        # Write the binary data file (.bin) lazily
        with prefix.with_suffix(".bin").open("wb") as bin_stream:
-           for document in documents:
+           for document, mask_spans in documents:
                # Infer dtype from the first document
                if dtype is None:
                    dtype = document.dtype
@@ -121,12 +146,16 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[np
                doc_length = len(document)
                lengths.append(doc_length)
                pointers.append(offset)
                num_spans.append(len(mask_spans))
                spans.append(mask_spans)
                offset += doc_length * np.dtype(dtype).itemsize
                num_documents += 1

        # Finalize metadata arrays
        lengths = np.array(lengths, dtype=np.int32)
        pointers = np.array(pointers, dtype=np.int64)
        num_spans = np.array(num_spans, dtype=np.int32)
        spans = np.vstack(spans, dtype=np.int32)

        # Write the index file (.idx)
        with prefix.with_suffix(".idx").open("wb") as idx_stream:
@@ -142,5 +171,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[np
            idx_stream.write(lengths.tobytes(order="C"))
            # Sequence (document) begin offsets in the bin file
            idx_stream.write(pointers.tobytes(order="C"))
            # Number of spans per document
            idx_stream.write(num_spans.tobytes(order="C"))
            # Span indices for each document
            idx_stream.write(spans.tobytes(order="C"))
            # Document indices, unused but needed for compatibility with Megatron-LM
            idx_stream.write(np.arange(num_documents + 1, dtype=np.int64).tobytes(order="C"))
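Read together, the reader in _init and the writer in write_dataset imply the following .idx layout; this summary is an editorial sketch (the header bytes written before these arrays sit in collapsed context lines), not code from the PR.

# .idx layout implied by this diff, after the header:
#   document_sizes    int32[num_documents]
#   pointers          int64[num_documents]        # byte offsets into the .bin file
#   num_spans         int32[num_documents]        # new in this PR
#   spans             int32[sum(num_spans), 2]    # new in this PR: one (start, end) row per span
#   document_indices  int64[num_documents + 1]    # unused, kept for Megatron-LM compatibility
# The reader walks num_spans and advances its byte offset by n_spans * 2 * 4 bytes per document.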
3 changes: 3 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -44,6 +44,9 @@ class GPTHuggingfaceDatasetConfig(Config):
        desc="Field of the dataset to use.",
        hint=FieldHint.optional,
    )
    spans_field: None | str = Field(
        default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
    )
    data_type: DataType | None = Field(
        default=None,
        desc="Data type of the dataset field."
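For context, spans_field is expected to name a column holding character offsets into the text field. A hypothetical row (column names and values invented for illustration, with inclusive end offsets as interpreted by the preparator below) could look like:

# Hypothetical Hugging Face dataset row for field="text", spans_field="spans":
row = {
    "text": "System: be concise.\nUser: hi\nAssistant: hello",
    "spans": [[0, 18], [20, 27]],  # characters of the system and user turns, excluded from the loss
}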
40 changes: 35 additions & 5 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -22,14 +22,37 @@ class GPTMemmapDatasetPreparator(DatasetPreparator):
    _tokenizer: Tokenizer
    _data_type: DataType

    def _tokenize_with_spans(self, sample):
        """
        Perform span-aware tokenization and return the tokenized input_ids along with token spans.
        """
        char_spans = sample.get(self._config.dataset.spans_field, [])
        text = sample[self._config.dataset.field]
        input_ids = []
        token_spans = []
        char_pos = 0
        for start, end in char_spans:
            if char_pos < start:
                curr_text = text[char_pos:start]
                tokenized_text = self._tokenizer.tokenize(curr_text)
                input_ids.extend(tokenized_text)
            curr_text = text[start : end + 1]
            tokenized_text = self._tokenizer.tokenize(curr_text)
            # Record the span as inclusive token positions within input_ids before appending its tokens.
            token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1))
            input_ids.extend(tokenized_text)
            char_pos = end + 1
        if char_pos < len(text):
            curr_text = text[char_pos:]
            tokenized_text = self._tokenizer.tokenize(curr_text)
            input_ids.extend(tokenized_text)
        return np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32)

    def _tokenize_batch(self, batch):
-       input_ids = [
-           np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy)
-           for text in batch[self._config.dataset.field]
-       ]
+       input_ids, token_spans = zip(*[self._tokenize_with_spans(sample) for sample in batch])
        num_tokens = [len(x) for x in input_ids]
        return {
            "input_ids": input_ids,
            "token_spans": token_spans,
            "num_tokens": num_tokens,
        }

@@ -40,7 +63,9 @@ def _save_shard(self, args) -> dict:

        def _document_generator():
            for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"):
-               yield np.array(item["input_ids"], dtype=self._data_type.numpy)
+               yield np.array(item["input_ids"], dtype=self._data_type.numpy), np.array(
+                   item["token_spans"], dtype=np.int32
+               )

        GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator())

@@ -126,6 +151,11 @@ def run(self):
        )
        if self._config.dataset.field not in dataset.column_names:
            raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.")
        if (
            self._config.dataset.spans_field is not None
            and self._config.dataset.spans_field not in dataset.column_names
        ):
            raise ValueError(f"Dataset does not have spans field '{self._config.dataset.spans_field}'.")

        # Tokenize the dataset in parallel
        tokenized_dataset = dataset.map(
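As a sanity check on the span bookkeeping in _tokenize_with_spans, here is a toy walk-through with a hypothetical whitespace tokenizer (one id per word); it is not the preparator's Tokenizer, just an illustration of how character spans become inclusive token spans.

text = "answer the question briefly"
char_spans = [(0, 18)]  # characters 0..18 cover "answer the question" (inclusive end)
vocab = {"answer": 0, "the": 1, "question": 2, "briefly": 3}
tokenize = lambda t: [vocab[w] for w in t.split()]

input_ids, token_spans, char_pos = [], [], 0
for start, end in char_spans:
    if char_pos < start:
        input_ids.extend(tokenize(text[char_pos:start]))
    span_tokens = tokenize(text[start : end + 1])
    token_spans.append((len(input_ids), len(input_ids) + len(span_tokens) - 1))
    input_ids.extend(span_tokens)
    char_pos = end + 1
if char_pos < len(text):
    input_ids.extend(tokenize(text[char_pos:]))

assert input_ids == [0, 1, 2, 3]
assert token_spans == [(0, 2)]  # tokens 0..2 are the ones the loss will skip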
20 changes: 17 additions & 3 deletions tests/test_memmap_dataset.py
@@ -10,12 +10,26 @@

@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
def test_gpt_memmap_dataset(dtype):
-   documents = [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)]
+   documents = list(
+       zip(
+           [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)],
+           np.array([[]] * 100, dtype=np.int32),
+       )
+   )
    with tempfile.TemporaryDirectory() as temp_dir:
        prefix = pathlib.Path(temp_dir)
        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
        dataset = GPTMemmapDataset(name="foo", prefix=prefix)
-       for i, document in enumerate(documents):
+       for i, (document, spans) in enumerate(documents):
            memmap_document, memmap_spans = dataset.get(i)
            assert np.array_equal(
-               dataset.get(i), document, equal_nan=True
+               memmap_document, document, equal_nan=True
            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."
            if len(spans) > 0:
                assert np.array_equal(
                    memmap_spans, spans, equal_nan=True
                ), f"Mismatch for non-empty spans {i}: {spans} != {dataset.get(i)}."
            else:
                assert np.array_equal(
                    memmap_spans.flatten(), spans.flatten(), equal_nan=True
                ), f"Mismatch for empty spans {i}: {spans} != {dataset.get(i)}."