Skip to content

Commit

Permalink
add args
Browse files Browse the repository at this point in the history
Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
  • Loading branch information
sarahyurick committed Oct 30, 2024
1 parent 687c9bc commit b79f9dd
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 8 deletions.
2 changes: 1 addition & 1 deletion nemo_curator/scripts/semdedup/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Please edit `config/sem_dedup_config.yaml` to configure the pipeline and run it

2) Compute embeddings:
```sh
semdedup_extract_embeddings --input-data-dir "$INPUT_DATA_DIR" --input-file-type "jsonl" --input-file-extension "json" --input-column "text" --config-file "$CONFIG_FILE"
semdedup_extract_embeddings --input-data-dir "$INPUT_DATA_DIR" --input-file-type "jsonl" --input-file-extension "json" --input-text-field "text" --config-file "$CONFIG_FILE"
```
**Input:** `input_data_dir/*.jsonl` and YAML file from step (1)

Expand Down
4 changes: 2 additions & 2 deletions nemo_curator/scripts/semdedup/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(args):
embedding_dataset = DocumentDataset(embedding_df)

clustering_model = ClusteringModel(
id_column=id_column, # TODO
id_column=args.id_column,
max_iter=semdedup_config.max_iter,
n_clusters=semdedup_config.n_clusters,
clustering_output_dir=clustering_output_dir,
Expand Down Expand Up @@ -96,7 +96,7 @@ def attach_args():
" max_iter for the maximum iterations for clustering,"
" kmeans_with_cos_dist for using KMeans with cosine distance,"
),
add_input_args=False,
add_input_args=True,
)
return parser

Expand Down
2 changes: 1 addition & 1 deletion nemo_curator/scripts/semdedup/compute_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main(args):
embedding_output_dir=os.path.join(
semdedup_config.cache_dir, semdedup_config.embeddings_save_loc
),
input_column=input_column, # TODO
input_column=args.input_text_field,
logger=logger,
write_to_filename=False, # TODO
)
Expand Down
6 changes: 3 additions & 3 deletions nemo_curator/scripts/semdedup/extract_dedup_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def main(args):
sorted_clusters_dir=os.path.join(
cache_dir, semdedup_config.clustering_save_loc, "sorted"
),
id_col=id_column, # TODO
id_col_type=id_column_type, # TODO
id_col=args.id_column,
id_col_type=args.id_column_type,
which_to_keep=semdedup_config.which_to_keep,
output_dir=os.path.join(
semdedup_config.cache_dir, semdedup_config.clustering_save_loc
Expand Down Expand Up @@ -75,7 +75,7 @@ def attach_args():
"eps_thresholds for epsilon thresholds to calculate if semantically similar or not"
"and eps_to_extract for the epsilon value to extract deduplicated data."
),
add_input_args=False,
add_input_args=True,
)
return parser

Expand Down
21 changes: 20 additions & 1 deletion nemo_curator/utils/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,23 @@ def add_arg_input_text_field(self):
"file that contains the text.",
)

def add_arg_id_column(self):
    """Register the ``--id-column`` command-line option on ``self.parser``.

    The option takes a string and defaults to ``"id"``. It names the field
    of each input datapoint that holds the ID.
    """
    self.parser.add_argument(
        "--id-column",
        type=str,
        default="id",
        help="The name of the field within each datapoint object of the input "
        "file that contains the ID.",
    )

def add_arg_id_column_type(self):
    """Register the ``--id-column-type`` command-line option on ``self.parser``.

    The option takes a string and defaults to ``"int"``. The value is stored
    as given; this method does not check it against the two values named in
    the help text.
    """
    self.parser.add_argument(
        "--id-column-type",
        type=str,
        default="int",
        help="The datatype of the ID field, either \"int\" or \"str\".",
    )

def add_arg_minhash_length(self):
self.parser.add_argument(
"--minhash-length",
Expand Down Expand Up @@ -556,10 +573,12 @@ def parse_semdedup_args(
argumentHelper = ArgumentHelper(parser)
argumentHelper.add_distributed_args()
if add_input_args:
argumentHelper.add_arg_input_data_dir(required=True)
argumentHelper.add_arg_input_data_dir()
argumentHelper.add_arg_input_file_extension()
argumentHelper.add_arg_input_file_type()
argumentHelper.add_arg_input_text_field()
argumentHelper.add_arg_id_column()
argumentHelper.add_arg_id_column_type()

argumentHelper.parser.add_argument(
"--config-file",
Expand Down

0 comments on commit b79f9dd

Please sign in to comment.