Skip to content

Commit

Permalink
chore(components/google-cloud): Add seed to tfp_anomaly_detection (#6244
Browse files Browse the repository at this point in the history
)
  • Loading branch information
g-luo authored Aug 6, 2021
1 parent 903b986 commit 7872bd4
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -620,15 +620,15 @@
"source": [
"@dsl.pipeline(\n",
" pipeline_root=PIPELINE_ROOT, name=PIPELINE_NAME)\n",
"def pipeline(input_url: str, memory_limit: str) -> None:\n",
"def pipeline(input_url: str, memory_limit: str, seed: int) -> None:\n",
" \"\"\"\n",
" Train model and return detected anomalies.\n",
" \"\"\"\n",
" input_task = kfp.dsl.importer(\n",
" artifact_uri=input_url,\n",
" artifact_class=Dataset)\n",
" preprocess_task = preprocess_op(input_dataset=input_task.output)\n",
" anomaly_detection_task = anomaly_detection_op(input_dataset=preprocess_task.output).set_memory_limit(memory_limit)\n",
" anomaly_detection_task = anomaly_detection_op(input_dataset=preprocess_task.output, seed=seed).set_memory_limit(memory_limit)\n",
" postprocess_op(input_dataset=input_task.output, predictions_dataset=anomaly_detection_task.output)"
],
"execution_count": null,
Expand Down Expand Up @@ -732,7 +732,8 @@
"source": [
"parameter_values = {\n",
" 'input_url': chosen_task_file,\n",
" 'memory_limit': '50G'\n",
" 'memory_limit': '50G',\n",
" 'seed': 0,\n",
"}\n",
"run_pipeline(pipeline, parameter_values=parameter_values)"
],
Expand Down

0 comments on commit 7872bd4

Please sign in to comment.