Skip to content

Commit

Permalink
Make dfp_azure_pipeline inference output file configurable. (#1290)
Browse files Browse the repository at this point in the history
Closes #1287

Authors:
  - Devin Robison (https://github.com/drobison00)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)

URL: #1290
  • Loading branch information
drobison00 authored Oct 19, 2023
1 parent 78c6e3a commit ec6f12c
Showing 1 changed file with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
type=str,
default="DFP-azure-{user_id}",
help="The MLflow model name template to use when logging models. ")
@click.option('--inference_detection_file_name', type=str, default="dfp_detections_azure.csv")
def run_pipeline(train_users,
skip_user: typing.Tuple[str],
only_user: typing.Tuple[str],
Expand All @@ -159,6 +160,7 @@ def run_pipeline(train_users,
filter_threshold,
mlflow_experiment_name_template,
mlflow_model_name_template,
inference_detection_file_name,
**kwargs):
"""Runs the DFP pipeline."""
# To include the generic, we must be training all or generic
Expand All @@ -167,7 +169,7 @@ def run_pipeline(train_users,
# To include individual, we must be either training or inferring
include_individual = train_users != "generic"

# None indicates we arent training anything
# None indicates we aren't training anything
is_training = train_users != "none"

skip_users = list(skip_user)
Expand Down Expand Up @@ -353,7 +355,7 @@ def run_pipeline(train_users,
pipeline.add_stage(SerializeStage(config, exclude=['batch_count', 'origin_hash', '_row_hash', '_batch_id']))

# Write all anomalies to a CSV file
pipeline.add_stage(WriteToFileStage(config, filename="dfp_detections_azure.csv", overwrite=True))
pipeline.add_stage(WriteToFileStage(config, filename=inference_detection_file_name, overwrite=True))

# Run the pipeline
pipeline.run()
Expand Down

0 comments on commit ec6f12c

Please sign in to comment.