Skip to content

Commit

Permalink
Akoumparouli/nemo ux validate dataset asset accessibility (#10309)
Browse files Browse the repository at this point in the history
* Add validate_dataset_asset_accessibility

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Add CI tests for validate_dataset_asset_accessibility

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

* fix

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* fix for zipped lists

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

* fix

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
  • Loading branch information
akoumpa and akoumpa committed Sep 8, 2024
1 parent 30385aa commit 2404c4e
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
63 changes: 63 additions & 0 deletions nemo/collections/llm/gpt/data/pre_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional
Expand All @@ -34,6 +35,66 @@
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec


def is_number_tryexcept(s):
    """Return True if ``s`` can be parsed as a float, False otherwise (including None)."""
    if s is None:
        return False
    try:
        float(s)
    except ValueError:
        # Not a numeric string (e.g. a dataset path prefix).
        return False
    return True


def is_zipped_list(paths):
    """Return True if ``paths`` is a weight/prefix "zipped" list.

    Example of the zipped format:
    ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
    """
    candidate_weights = paths[::2]
    if not candidate_weights:
        return False
    numeric_flags = [is_number_tryexcept(item) for item in candidate_weights]
    if any(numeric_flags):
        # Either every even-indexed entry is a weight, or the list is malformed.
        assert all(numeric_flags), "Got malformatted zipped list"
    return numeric_flags[0]


def validate_dataset_asset_accessibility(paths):
    """Recursively validate that every dataset asset referenced by ``paths`` exists and is readable.

    Args:
        paths: a str/Path dataset prefix, a list/tuple of prefixes (optionally in the
            "zipped" weight/prefix format, e.g. ["30", "prefix1", "70", "prefix2"]),
            or a dict mapping split names to any of the above.

    Raises:
        ValueError: if ``paths`` is None, or a leaf value is not a str or Path.
        PermissionError: if an existing asset is not readable.
        FileNotFoundError: if neither the path itself nor its ``.bin``/``.idx``
            companion files exist.
    """
    if paths is None:
        raise ValueError("Expected path to have a value.")

    if isinstance(paths, (tuple, list)):
        if is_zipped_list(paths):
            # remove weights from paths.
            paths = paths[1::2]
        for p in paths:
            validate_dataset_asset_accessibility(p)
        return
    elif isinstance(paths, dict):
        for p in paths.values():
            validate_dataset_asset_accessibility(p)
        return

    # Fix: original had a typo ("isisntance") which raised NameError here
    # instead of the intended ValueError for unsupported path types.
    if not isinstance(paths, (str, Path)):
        raise ValueError("Expected path to be of string or Path type.")

    path = Path(paths)
    suffixes = ('.bin', '.idx')
    if path.is_dir():
        if not os.access(path, os.R_OK):
            raise PermissionError(f"Expected {str(path)} to be readable.")
        # Will let the downstream class confirm contents are ok.
        return
    if path.exists():
        if not os.access(path, os.R_OK):
            raise PermissionError(f"Expected {str(path)} to be readable.")
        return
    # Otherwise treat the path as a Megatron dataset prefix: both companion
    # files (<prefix>.bin and <prefix>.idx) must exist and be readable.
    for suffix in suffixes:
        file_path = Path(str(path) + suffix)
        if not file_path.exists():
            raise FileNotFoundError(f"Expected {str(file_path)} to exist.")
        if not os.access(file_path, os.R_OK):
            raise PermissionError(f"Expected {str(file_path)} to be readable.")


class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
"""PyTorch Lightning-compatible data module for pre-training
GPT-style models.
Expand Down Expand Up @@ -100,6 +161,8 @@ def __init__(

from megatron.core.datasets.utils import get_blend_from_list

validate_dataset_asset_accessibility(paths)

build_kwargs = {}
if isinstance(paths, dict):
if split is not None:
Expand Down
34 changes: 34 additions & 0 deletions tests/collections/llm/gpt/data/test_pre_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,37 @@ def test_multiple_data_distributions(tokenizer, trainer):

## this should succeed
data.setup(stage="dummy")


def test_validate_dataset_asset_accessibility_file_does_not_exist(tokenizer, trainer):
    """Constructing the data module with a non-existent path must raise FileNotFoundError."""
    raised_exception = False
    try:
        data = PreTrainingDataModule(
            paths=["/this/path/should/not/exist/"],
            seq_length=512,
            micro_batch_size=2,
            global_batch_size=2,
            tokenizer=tokenizer,
        )
        data.trainer = trainer
    except FileNotFoundError:
        raised_exception = True

    # Fix: `== True` comparison (flake8 E712) replaced by a plain truthiness assert.
    assert raised_exception, "Expected to raise a FileNotFoundError"


def test_validate_dataset_asset_accessibility_file_is_none(tokenizer, trainer):
    """Constructing the data module with paths=None must raise ValueError."""
    raised_exception = False
    try:
        data = PreTrainingDataModule(
            paths=None,
            seq_length=512,
            micro_batch_size=2,
            global_batch_size=2,
            tokenizer=tokenizer,
        )
        data.trainer = trainer
    except ValueError:
        raised_exception = True

    # Fix: `== True` comparison (flake8 E712) replaced by a plain truthiness assert.
    assert raised_exception, "Expected to raise a ValueError"

0 comments on commit 2404c4e

Please sign in to comment.