Skip to content

Commit

Permalink
fix(components): Write model resource_name to the output of training …
Browse files Browse the repository at this point in the history
…pipeline remote runner

PiperOrigin-RevId: 602426716
  • Loading branch information
KCFindstr authored and petethegreat committed Mar 27, 2024
1 parent 7dcd460 commit 7dd3d67
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
1 change: 1 addition & 0 deletions components/google-cloud/RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## Upcoming release
* Fix the missing output of pipeline remote runner. `AutoMLImageTrainingJobRunOp` now passes the model artifacts correctly to downstream components.

## Release 2.9.0
* Use `large_model_reference` for `model_reference_name` when uploading models from `preview.llm.rlhf_pipeline` instead of hardcoding value as `text-bison@001`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@ def _parse_args(args: List[str]):
args.append('--payload')
args.append('"{}"') # Unused but required by parser_util.
parser, _ = parser_util.parse_default_args(args)
# Parse the conditionally required arguments
# Parse the conditionally required arguments.
parser.add_argument(
'--executor_input',
dest='executor_input',
type=str,
# executor_input is only needed for components that emit output artifacts.
required=True,
default=argparse.SUPPRESS,
)
parser.add_argument(
'--display_name',
dest='display_name',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""GCP remote runner for AutoML image training pipelines based on the AI Platform SDK."""

import json
import logging
from typing import Any, Dict, Optional, Sequence

Expand All @@ -25,6 +26,7 @@
from google.cloud.aiplatform import training_jobs
from google.cloud.aiplatform_v1.types import model
from google.cloud.aiplatform_v1.types import training_pipeline
from google_cloud_pipeline_components.container.v1.aiplatform import remote_runner
from google_cloud_pipeline_components.container.v1.gcp_launcher import pipeline_remote_runner
from google_cloud_pipeline_components.container.v1.gcp_launcher.utils import error_util

Expand Down Expand Up @@ -195,6 +197,7 @@ def create_pipeline(
project: str,
location: str,
gcp_resources: str,
executor_input: str,
**kwargs: Dict[str, Any],
):
"""Create and poll AutoML Vision training pipeline status till it reaches a final state.
Expand Down Expand Up @@ -222,29 +225,39 @@ def create_pipeline(
project: Project name.
location: Location to start the training job.
gcp_resources: URI for storing GCP resources.
executor_input: Pipeline executor input.
**kwargs: Extra args for creating the payload.
"""
remote_runner = pipeline_remote_runner.PipelineRemoteRunner(
runner = pipeline_remote_runner.PipelineRemoteRunner(
type, project, location, gcp_resources
)

try:
# Create AutoML vision training pipeline if it does not exist
pipeline_name = remote_runner.check_if_pipeline_exists()
pipeline_name = runner.check_if_pipeline_exists()
if pipeline_name is None:
payload = create_payload(project, location, **kwargs)
logging.info(
'AutoML Vision training payload formatted: %s',
payload,
)
pipeline_name = remote_runner.create_pipeline(
pipeline_name = runner.create_pipeline(
create_pipeline_with_client,
payload,
)

# Poll AutoML Vision training pipeline status until
# "PipelineState.PIPELINE_STATE_SUCCEEDED"
remote_runner.poll_pipeline(get_pipeline_with_client, pipeline_name)
pipeline = runner.poll_pipeline(get_pipeline_with_client, pipeline_name)

except (ConnectionError, RuntimeError) as err:
error_util.exit_with_internal_error(err.args[0])
return # No-op, suppressing uninitialized `pipeline` variable lint error.

# Writes artifact output on success.
if not isinstance(pipeline, training_pipeline.TrainingPipeline):
raise ValueError('Internal error: no training pipeline was created.')
remote_runner.write_to_artifact(
json.loads(executor_input),
pipeline.model_to_upload.name,
)

0 comments on commit 7dd3d67

Please sign in to comment.