Skip to content

Commit

Permalink
modify get_remaining_files function
Browse files Browse the repository at this point in the history
Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
  • Loading branch information
sarahyurick committed Oct 30, 2024
1 parent b79f9dd commit 579ce27
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
8 changes: 5 additions & 3 deletions nemo_curator/scripts/semdedup/compute_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def main(args):
input_file_path=args.input_data_dir,
output_file_path=output_data_dir,
input_file_type=input_file_extension,
output_file_type="parquet",
num_files=semdedup_config.num_files,
)

Expand All @@ -63,8 +64,9 @@ def main(args):
return

ddf = read_data(
# TODO
input_files=input_files, file_type=args.input_file_type, add_filename=False
input_files=input_files,
file_type=args.input_file_type,
add_filename=True,
)
ddf = ddf.reset_index(drop=True)
dataset = DocumentDataset(ddf)
Expand All @@ -80,7 +82,7 @@ def main(args):
),
input_column=args.input_text_field,
logger=logger,
write_to_filename=False, # TODO
write_to_filename=True,
)

embedding_dataset = embedding_creator(dataset=dataset)
Expand Down
33 changes: 30 additions & 3 deletions nemo_curator/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import pathlib
from functools import partial, reduce
from typing import List, Union
from typing import List, Optional, Union

import dask.bag as db
import dask.dataframe as dd
Expand Down Expand Up @@ -73,7 +73,11 @@ def get_all_files_paths_under(root, recurse_subdirectories=True, followlinks=Fal
# writing a file we can use the offset counter approach
# in jaccard shuffle as a more robust way to restart jobs
def get_remaining_files(
input_file_path, output_file_path, input_file_type, num_files=-1
input_file_path: str,
output_file_path: str,
input_file_type: str,
output_file_type: Optional[str] = None,
num_files: int = -1,
):
"""
This function returns a list of the files that still remain to be read.
Expand All @@ -82,6 +86,7 @@ def get_remaining_files(
input_file_path: The path of the input files.
output_file_path: The path of the output files.
input_file_type: The type of the input files.
output_file_type: The type of the output files.
num_files: The max number of files to be returned. If -1, all files are returned.
Returns:
A list of files that still remain to be read.
Expand All @@ -96,10 +101,12 @@ def get_remaining_files(
os.path.basename(entry.path) for entry in os.scandir(output_file_path)
]
completed_files = set(completed_files)

input_files = [
entry.path
for entry in os.scandir(input_file_path)
if os.path.basename(entry.path) not in completed_files
if os.path.basename(entry.path)
not in _update_filetype(completed_files, output_file_type, input_file_type)
]
# Guard against non extension files if present in the input directory
input_files = [f for f in input_files if f.endswith(input_file_type)]
Expand All @@ -110,10 +117,30 @@ def get_remaining_files(
left_to_sample = max(num_files - len_written_files, 0)
else:
left_to_sample = len(input_files)

input_files = input_files[:left_to_sample]
return input_files


def _update_filetype(file_set, old_file_type, new_file_type):
if old_file_type is None or new_file_type is None:
return file_set
if old_file_type == new_file_type:
return file_set

if not old_file_type.startswith("."):
old_file_type = "." + old_file_type
if not new_file_type.startswith("."):
new_file_type = "." + new_file_type

updated_file_set = {
f"{os.path.splitext(file)[0]}{new_file_type}"
if file.endswith(old_file_type) else file
for file in file_set
}
return updated_file_set


def get_batched_files(
input_file_path, output_file_path, input_file_type, batch_size=64
):
Expand Down
2 changes: 1 addition & 1 deletion nemo_curator/utils/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def add_arg_id_column_type(self):
"--id-column-type",
type=str,
default="int",
help="The datatype of the ID field, either \"int\" or \"str\".",
help='The datatype of the ID field, either "int" or "str".',
)

def add_arg_minhash_length(self):
Expand Down

0 comments on commit 579ce27

Please sign in to comment.