From 579ce2749f97ce0fbcd90b68b2ae027c2933e474 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Wed, 30 Oct 2024 15:41:33 -0700 Subject: [PATCH] modify get_remaining_files function Signed-off-by: Sarah Yurick --- .../scripts/semdedup/compute_embeddings.py | 8 +++-- nemo_curator/utils/file_utils.py | 33 +++++++++++++++++-- nemo_curator/utils/script_utils.py | 2 +- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/nemo_curator/scripts/semdedup/compute_embeddings.py b/nemo_curator/scripts/semdedup/compute_embeddings.py index adead9d9..af5c8655 100644 --- a/nemo_curator/scripts/semdedup/compute_embeddings.py +++ b/nemo_curator/scripts/semdedup/compute_embeddings.py @@ -54,6 +54,7 @@ def main(args): input_file_path=args.input_data_dir, output_file_path=output_data_dir, input_file_type=input_file_extension, + output_file_type="parquet", num_files=semdedup_config.num_files, ) @@ -63,8 +64,9 @@ def main(args): return ddf = read_data( - # TODO - input_files=input_files, file_type=args.input_file_type, add_filename=False + input_files=input_files, + file_type=args.input_file_type, + add_filename=True, ) ddf = ddf.reset_index(drop=True) dataset = DocumentDataset(ddf) @@ -80,7 +82,7 @@ def main(args): ), input_column=args.input_text_field, logger=logger, - write_to_filename=False, # TODO + write_to_filename=True, ) embedding_dataset = embedding_creator(dataset=dataset) diff --git a/nemo_curator/utils/file_utils.py b/nemo_curator/utils/file_utils.py index 6bb45d2c..55364f12 100644 --- a/nemo_curator/utils/file_utils.py +++ b/nemo_curator/utils/file_utils.py @@ -17,7 +17,7 @@ import os import pathlib from functools import partial, reduce -from typing import List, Union +from typing import List, Optional, Union import dask.bag as db import dask.dataframe as dd @@ -73,7 +73,11 @@ def get_all_files_paths_under(root, recurse_subdirectories=True, followlinks=Fal # writing a file we can use the offset counter approach # in jaccard shuffle as a more robust way to restart jobs def get_remaining_files( - input_file_path, output_file_path, input_file_type, num_files=-1 + input_file_path: str, + output_file_path: str, + input_file_type: str, + output_file_type: Optional[str] = None, + num_files: int = -1, ): """ This function returns a list of the files that still remain to be read. @@ -82,6 +86,7 @@ def get_remaining_files( input_file_path: The path of the input files. output_file_path: The path of the output files. input_file_type: The type of the input files. + output_file_type: The type of the output files. num_files: The max number of files to be returned. If -1, all files are returned. Returns: A list of files that still remain to be read. @@ -96,10 +101,12 @@ def get_remaining_files( os.path.basename(entry.path) for entry in os.scandir(output_file_path) ] completed_files = set(completed_files) + input_files = [ entry.path for entry in os.scandir(input_file_path) - if os.path.basename(entry.path) not in completed_files + if os.path.basename(entry.path) + not in _update_filetype(completed_files, output_file_type, input_file_type) ] # Guard against non extension files if present in the input directory input_files = [f for f in input_files if f.endswith(input_file_type)] @@ -110,10 +117,30 @@ def get_remaining_files( left_to_sample = max(num_files - len_written_files, 0) else: left_to_sample = len(input_files) + input_files = input_files[:left_to_sample] return input_files +def _update_filetype(file_set, old_file_type, new_file_type): + if old_file_type is None or new_file_type is None: + return file_set + if old_file_type == new_file_type: + return file_set + + if not old_file_type.startswith("."): + old_file_type = "." + old_file_type + if not new_file_type.startswith("."): + new_file_type = "." + new_file_type + + updated_file_set = { + f"{os.path.splitext(file)[0]}{new_file_type}" + if file.endswith(old_file_type) else file + for file in file_set + } + return updated_file_set + + def get_batched_files( input_file_path, output_file_path, input_file_type, batch_size=64 ): diff --git a/nemo_curator/utils/script_utils.py b/nemo_curator/utils/script_utils.py index 6a257b9f..b4625bc0 100644 --- a/nemo_curator/utils/script_utils.py +++ b/nemo_curator/utils/script_utils.py @@ -171,7 +171,7 @@ def add_arg_id_column_type(self): "--id-column-type", type=str, default="int", - help="The datatype of the ID field, either \"int\" or \"str\".", + help='The datatype of the ID field, either "int" or "str".', ) def add_arg_minhash_length(self):