From 2404c4e9cce0c77e06b6dc6f8f191932db33f6b9 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Sun, 8 Sep 2024 11:38:11 -0700 Subject: [PATCH] Akoumparouli/nemo ux validate dataset asset accessibility (#10309) * Add validate_dataset_asset_accessibility Signed-off-by: Alexandros Koumparoulis * Add CI tests for validate_dataset_asset_accessibility Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa * fix Signed-off-by: Alexandros Koumparoulis * fix for zipped lists Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa * fix Signed-off-by: Alexandros Koumparoulis --------- Signed-off-by: Alexandros Koumparoulis Signed-off-by: akoumpa Co-authored-by: akoumpa --- nemo/collections/llm/gpt/data/pre_training.py | 63 +++++++++++++++++++ .../llm/gpt/data/test_pre_training_data.py | 34 ++++++++++ 2 files changed, 97 insertions(+) diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index ccb2d21729ed..534922efe3a3 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +import os import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -34,6 +35,66 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +def is_number_tryexcept(s): + """Returns True if string is a number.""" + if s is None: + return False + try: + float(s) + return True + except ValueError: + return False + + +def is_zipped_list(paths): + # ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"] + even = paths[::2] + if len(even) == 0: + return False + is_num = list(map(is_number_tryexcept, even)) + if any(is_num): + assert all(is_num), "Got malformatted zipped list" + return is_num[0] + + +def validate_dataset_asset_accessibility(paths): + if paths is None: + raise ValueError("Expected path to have a value.") + + if isinstance(paths, tuple) or isinstance(paths, list): + if is_zipped_list(paths): + # remove weights from paths. + paths = paths[1::2] + for p in paths: + validate_dataset_asset_accessibility(p) + return + elif isinstance(paths, dict): + for p in paths.values(): + validate_dataset_asset_accessibility(p) + return + + if not isinstance(paths, str) and not isisntance(paths, Path): + raise ValueError("Expected path to be of string or Path type.") + + path = Path(paths) + suffices = ('.bin', '.idx') + if path.is_dir(): + if not os.access(path, os.R_OK): + raise PermissionError(f"Expected {str(path)} to be readable.") + # Will let the downstream class confirm contents are ok. + return + if path.exists(): + if not os.access(path, os.R_OK): + raise PermissionError(f"Expected {str(path)} to be readable.") + return + for suffix in suffices: + file_path = Path(str(path) + suffix) + if not file_path.exists(): + raise FileNotFoundError(f"Expected {str(file_path)} to exist.") + if not os.access(file_path, os.R_OK): + raise PermissionError(f"Expected {str(file_path)} to be readable.") + + class PreTrainingDataModule(pl.LightningDataModule, IOMixin): """PyTorch Lightning-compatible data module for pre-training GPT-style models. @@ -100,6 +161,8 @@ def __init__( from megatron.core.datasets.utils import get_blend_from_list + validate_dataset_asset_accessibility(paths) + build_kwargs = {} if isinstance(paths, dict): if split is not None: diff --git a/tests/collections/llm/gpt/data/test_pre_training_data.py b/tests/collections/llm/gpt/data/test_pre_training_data.py index 31a7b51cdf53..24dacc7bf33c 100644 --- a/tests/collections/llm/gpt/data/test_pre_training_data.py +++ b/tests/collections/llm/gpt/data/test_pre_training_data.py @@ -78,3 +78,37 @@ def test_multiple_data_distributions(tokenizer, trainer): ## this should succeed data.setup(stage="dummy") + + +def test_validate_dataset_asset_accessibility_file_does_not_exist(tokenizer, trainer): + raised_exception = False + try: + data = PreTrainingDataModule( + paths=["/this/path/should/not/exist/"], + seq_length=512, + micro_batch_size=2, + global_batch_size=2, + tokenizer=tokenizer, + ) + data.trainer = trainer + except FileNotFoundError: + raised_exception = True + + assert raised_exception == True, "Expected to raise a FileNotFoundError" + + +def test_validate_dataset_asset_accessibility_file_is_none(tokenizer, trainer): + raised_exception = False + try: + data = PreTrainingDataModule( + paths=None, + seq_length=512, + micro_batch_size=2, + global_batch_size=2, + tokenizer=tokenizer, + ) + data.trainer = trainer + except ValueError: + raised_exception = True + + assert raised_exception == True, "Expected to raise a ValueError"