Skip to content

Commit

Permalink
Update example_dag
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksYermak committed Mar 8, 2022
1 parent 26172a4 commit 4d37a2a
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 13 deletions.
12 changes: 6 additions & 6 deletions airflow/providers/google/cloud/example_dags/example_vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
* PYTHON_PACKAGE_GSC_URI - path to test model in archive.
* LOCAL_TRAINING_SCRIPT_PATH - path to local training script.
* DATASET_ID - ID of dataset which will be used in training process.
* MODEL_ID - ID of model which will be used in predict process.
* MODEL_ARTIFACT_URI - The artifact_uri should be the path to a GCS directory containing saved model
artifacts.
"""
import os
from datetime import datetime
Expand Down Expand Up @@ -186,7 +189,8 @@
{"numeric": {"column_name": "PhotoAmt"}},
]

MODEL_ID = "9182492194534064128"
MODEL_ID = os.environ.get("MODEL_ID", "test-model-id")
MODEL_ARTIFACT_URI = os.environ.get("MODEL_ARTIFACT_URI", "path_to_folder_with_model_artifacts")
MODEL_NAME = f"projects/{PROJECT_ID}/locations/{REGION}/models/{MODEL_ID}"
JOB_DISPLAY_NAME = f"temp_create_batch_prediction_job_test_{uuid4()}"
BIGQUERY_SOURCE = f"bq://{PROJECT_ID}.test_iowa_liquor_sales_forecasting_us.2021_sales_predict"
Expand Down Expand Up @@ -219,11 +223,7 @@
}
MODEL_OBJ = {
"display_name": f"model-{str(uuid4())}",
# The artifact_uri should be the path to a GCS directory containing
# saved model artifacts. The bucket must be accessible for the
# project's AI Platform service account and in the same region as
# the api endpoint.
"artifact_uri": f"{STAGING_BUCKET}/aiplatform-custom-training-2021-11-26-12:12:09.339/model",
"artifact_uri": MODEL_ARTIFACT_URI,
"container_spec": {
"image_uri": MODEL_SERVING_CONTAINER_URI,
"command": [],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ class GetBatchPredictionJobOperator(BaseOperator):
"""

template_fields = ("region", "project_id", "impersonation_chain")
operator_extra_links = (VertexAIBatchPredictionJobLink(),)

def __init__(
self,
Expand Down Expand Up @@ -408,6 +409,9 @@ def execute(self, context: 'Context'):
metadata=self.metadata,
)
self.log.info("Batch prediction job was gotten.")
VertexAIBatchPredictionJobLink.persist(
context=context, task_instance=self, batch_prediction_job_id=self.batch_prediction_job
)
return BatchPredictionJob.to_dict(result)
except NotFound:
self.log.info("The Batch prediction job %s does not exist.", self.batch_prediction_job)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.vertex_ai.endpoint_service import EndpointServiceHook
from airflow.providers.google.cloud.links.vertex_ai import VertexAIEndpointLink, VertexAIEndpointListLink
from airflow.providers.google.cloud.links.vertex_ai import (
VertexAIEndpointLink,
VertexAIEndpointListLink,
VertexAIModelLink,
)

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -237,6 +241,7 @@ class DeployModelOperator(BaseOperator):
"""

template_fields = ("region", "endpoint_id", "project_id", "impersonation_chain")
operator_extra_links = (VertexAIModelLink(),)

def __init__(
self,
Expand Down Expand Up @@ -292,6 +297,7 @@ def execute(self, context: 'Context'):
self.log.info("Model was deployed. Deployed Model ID: %s", deployed_model_id)

self.xcom_push(context, key="deployed_model_id", value=deployed_model_id)
VertexAIModelLink.persist(context=context, task_instance=self, model_id=deployed_model_id)
return deploy_model


Expand Down Expand Up @@ -588,6 +594,7 @@ class UpdateEndpointOperator(BaseOperator):
"""

template_fields = ("region", "endpoint_id", "project_id", "impersonation_chain")
operator_extra_links = (VertexAIEndpointLink(),)

def __init__(
self,
Expand Down Expand Up @@ -636,4 +643,5 @@ def execute(self, context: 'Context'):
metadata=self.metadata,
)
self.log.info("Endpoint was updated")
VertexAIEndpointLink.persist(context=context, task_instance=self, endpoint_id=self.endpoint_id)
return Endpoint.to_dict(result)
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,7 @@ class ListHyperparameterTuningJobOperator(BaseOperator):
"project_id",
"impersonation_chain",
]
operator_extra_links = [
VertexAIHyperparameterTuningJobListLink(),
]
operator_extra_links = (VertexAIHyperparameterTuningJobListLink(),)

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,7 @@ class ListModelsOperator(BaseOperator):
"""

template_fields = ("region", "project_id", "impersonation_chain")
operator_extra_links = [
VertexAIModelListLink(),
]
operator_extra_links = (VertexAIModelListLink(),)

def __init__(
self,
Expand Down

0 comments on commit 4d37a2a

Please sign in to comment.