Skip to content
This repository has been archived by the owner on Dec 19, 2024. It is now read-only.

MLFlow: Lookup client_id and host_id from Kubernetes env variable #133

Merged
merged 13 commits into from
Feb 1, 2021
39 changes: 14 additions & 25 deletions datasetinsights/io/tracker/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ class TrackerFactory:
"""Factory: responsible for creating and holding singleton instance
of tracker classes"""

TRACKER = "tracker"
HOST_ID = "host"
MLFLOW_TRACKER = "mlflow"
__singleton_lock = threading.Lock()
__tracker_instance = None
RUN_FAILED = "FAILED"
TRACKER = "tracker"

@staticmethod
def create(config=None, tracker_type=None):
Expand All @@ -31,22 +30,16 @@ def create(config=None, tracker_type=None):
"""
if TrackerFactory.MLFLOW_TRACKER == tracker_type:

tracker = config.get(TrackerFactory.TRACKER, None)
if tracker and tracker.get(TrackerFactory.MLFLOW_TRACKER, None):
mlflow_config = tracker.get(TrackerFactory.MLFLOW_TRACKER)
if mlflow_config.get(TrackerFactory.HOST_ID, None):
try:
mlf_tracker = TrackerFactory._mlflow_tracker_instance(
mlflow_config
).get_mlflow()
logger.info("initializing mlflow_tracker")
return mlf_tracker
except Exception as e:
logger.warning(
"failed mlflow initialization, "
"starting null_tracker",
e,
)
try:
mlf_tracker = TrackerFactory._mlflow_tracker_instance(
config
).get_mlflow()
logger.info("initializing mlflow_tracker")
return mlf_tracker
except Exception as e:
logger.warning(
"failed mlflow initialization, " "starting null_tracker", e,
)
86sanj marked this conversation as resolved.
Show resolved Hide resolved

logger.info("initializing null_tracker")
return TrackerFactory._null_tracker()
Expand All @@ -55,23 +48,19 @@ def create(config=None, tracker_type=None):
raise InvalidTrackerError

@staticmethod
def _mlflow_tracker_instance(mlflow_config):
def _mlflow_tracker_instance(config):

"""Static instance access method.

Args:
host_id: MlTracker server host
client_id: MLFlow tracking server client id
exp_name: name of the experiment
config : config object, holds server details
Returns:
tracker singleton instance.
"""
if not TrackerFactory.__tracker_instance:
with TrackerFactory.__singleton_lock:
if not TrackerFactory.__tracker_instance:
TrackerFactory.__tracker_instance = MLFlowTracker(
mlflow_config
)
TrackerFactory.__tracker_instance = MLFlowTracker(config)
logger.info("getting tracker instance")
return TrackerFactory.__tracker_instance

Expand Down
60 changes: 48 additions & 12 deletions datasetinsights/io/tracker/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,27 @@ class MLFlowTracker:
EXP_NAME = "experiment"
RUN_NAME = "run"
DEFAULT_RUN_NAME = "run-" + TIMESTAMP_SUFFIX
TRACKER = "tracker"
MLFLOW_TRACKER = "mlflow"
DEFAULT_EXP_NAME = "datasetinsights"

def __init__(self, mlflow_config):
def __init__(self, config):
86sanj marked this conversation as resolved.
Show resolved Hide resolved
"""constructor.
Args:
mlflow_config:map of mlflow configuration
config : config object, holds run details
"""
host_id = mlflow_config.get(MLFlowTracker.HOST_ID)
client_id = mlflow_config.get(MLFlowTracker.CLIENT_ID, None)
exp_name = mlflow_config.get(MLFlowTracker.EXP_NAME, None)
run_name = mlflow_config.get(MLFlowTracker.RUN_NAME, None)
if not run_name:
run_name = MLFlowTracker.DEFAULT_RUN_NAME
logger.info(f"setting default mlflow run name: {run_name}")
client_id, host_id, run_name, exp_name = MLFlowTracker._get_variables(
config
)

if client_id:
_refresh_token(client_id)
thread = RefreshTokenThread(client_id)
thread.daemon = True
thread.start()
mlflow.set_tracking_uri(host_id)
if exp_name:
mlflow.set_experiment(experiment_name=exp_name)
logger.info(f"setting mlflow experiment name: {exp_name}")
mlflow.set_experiment(experiment_name=exp_name)
logger.info(f"setting mlflow experiment name: {exp_name}")

self.__mlflow = mlflow
self.__mlflow.start_run(run_name=run_name)
Expand All @@ -80,6 +79,43 @@ def get_mlflow(self):
logger.info("get mlflow")
return self.__mlflow

@staticmethod
def _get_variables(config):
    """Resolve mlflow connection settings from env vars and yaml config.

    Defaults come from the MLFLOW_CLIENT_ID / MLFLOW_HOST_ID environment
    variables (e.g. injected via Kubernetes secrets). Non-empty values in
    the ``tracker.mlflow`` section of the yaml config override them; empty
    yaml keys fall back to the env-var values instead of erasing them.

    Args:
        config : config object, holds run details

    Returns:
        tuple: (client_id, host_id, run_name, exp_name)

    Raises:
        ValueError: if host_id is configured in neither source.
    """
    client_id = os.environ.get("MLFLOW_CLIENT_ID", None)
    host_id = os.environ.get("MLFLOW_HOST_ID", None)
    run_name = MLFlowTracker.DEFAULT_RUN_NAME
    exp_name = MLFlowTracker.DEFAULT_EXP_NAME
    logger.debug(
        f"client_id:{client_id} and host_id: {host_id} from "
        f"kubernetes env variable"
    )
    tracker = config.get(MLFlowTracker.TRACKER, None)
    if tracker and tracker.get(MLFlowTracker.MLFLOW_TRACKER, None):
        mlflow_config = tracker.get(MLFlowTracker.MLFLOW_TRACKER)
        # `or` keeps the env-var value when the yaml key is present but
        # empty (plain .get(key, default) would return None and clobber it)
        host_id = mlflow_config.get(MLFlowTracker.HOST_ID, None) or host_id
        client_id = (
            mlflow_config.get(MLFlowTracker.CLIENT_ID, None) or client_id
        )
        run_name = mlflow_config.get(MLFlowTracker.RUN_NAME, None) or run_name
        exp_name = mlflow_config.get(MLFlowTracker.EXP_NAME, None) or exp_name
        logger.debug(
            f"client_id:{client_id} and host_id:{host_id} from yaml config"
        )
    logger.info(
        f"client_id:{client_id} and host_id:{host_id} connecting to mlflow"
    )
    if not host_id:
        # no f-string needed: message has no placeholders (lint F541)
        logger.warning("host_id not found")
        raise ValueError("host_id not configured")
    return client_id, host_id, run_name, exp_name


class RefreshTokenThread(threading.Thread):
""" Its service thread which keeps running till main thread runs
Expand Down
8 changes: 7 additions & 1 deletion kubeflow/compiled/evaluate_the_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: evaluate-the-model-
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2020-11-04T17:16:55.837308',
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2021-01-04T15:53:39.220225',
pipelines.kubeflow.org/pipeline_spec: '{"description": "Evaluate the model", "inputs":
[{"default": "unitytechnologies/datasetinsights:latest", "name": "docker", "optional":
true, "type": "String"}, {"default": "https://storage.googleapis.com/datasetinsights/data/groceries/v3.zip",
Expand Down Expand Up @@ -47,6 +47,12 @@ spec:
--kfp-ui-metadata-filename=kfp_ui_metadata.json, --kfp-metrics-filename=kfp_metrics.json]
command: [datasetinsights, evaluate]
env:
- name: MLFLOW_HOST_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_HOST_ID}
- name: MLFLOW_CLIENT_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_CLIENT_ID}
- {name: GOOGLE_APPLICATION_CREDENTIALS, value: /secret/gcp-credentials/user-gcp-sa.json}
- {name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE, value: /secret/gcp-credentials/user-gcp-sa.json}
image: '{{inputs.parameters.docker}}'
Expand Down
8 changes: 7 additions & 1 deletion kubeflow/compiled/train_on_real_world_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: train-on-real-world-dataset-
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2020-11-04T17:16:53.753331',
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2021-01-04T15:53:36.873123',
pipelines.kubeflow.org/pipeline_spec: '{"description": "Train on real world dataset",
"inputs": [{"default": "unitytechnologies/datasetinsights:latest", "name": "docker",
"optional": true, "type": "String"}, {"default": "https://storage.googleapis.com/datasetinsights/data/groceries/v3.zip",
Expand Down Expand Up @@ -72,6 +72,12 @@ spec:
--kfp-ui-metadata-filename=kfp_ui_metadata.json, '--checkpoint-dir={{inputs.parameters.checkpoint_dir}}']
command: [datasetinsights, train]
env:
- name: MLFLOW_HOST_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_HOST_ID}
- name: MLFLOW_CLIENT_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_CLIENT_ID}
- {name: GOOGLE_APPLICATION_CREDENTIALS, value: /secret/gcp-credentials/user-gcp-sa.json}
- {name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE, value: /secret/gcp-credentials/user-gcp-sa.json}
image: '{{inputs.parameters.docker}}'
Expand Down
8 changes: 7 additions & 1 deletion kubeflow/compiled/train_on_synthdet_sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: train-on-the-synthdet-sample-
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2020-11-04T17:16:52.999713',
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2021-01-04T15:53:36.094994',
pipelines.kubeflow.org/pipeline_spec: '{"description": "Train on the SynthDet
sample", "inputs": [{"default": "unitytechnologies/datasetinsights:latest",
"name": "docker", "optional": true, "type": "String"}, {"default": "https://storage.googleapis.com/datasetinsights/data/synthetic/SynthDet.zip",
Expand Down Expand Up @@ -73,6 +73,12 @@ spec:
command: [python, -m, torch.distributed.launch, --nproc_per_node=8, --use_env,
datasetinsights, train]
env:
- name: MLFLOW_HOST_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_HOST_ID}
- name: MLFLOW_CLIENT_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_CLIENT_ID}
- {name: GOOGLE_APPLICATION_CREDENTIALS, value: /secret/gcp-credentials/user-gcp-sa.json}
- {name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE, value: /secret/gcp-credentials/user-gcp-sa.json}
image: '{{inputs.parameters.docker}}'
Expand Down
8 changes: 7 additions & 1 deletion kubeflow/compiled/train_on_synthetic_and_real_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: train-on-synthetic-real-world-dataset-
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2020-11-04T17:16:54.469730',
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2021-01-04T15:53:37.650746',
pipelines.kubeflow.org/pipeline_spec: '{"description": "Train on Synthetic + Real
World Dataset", "inputs": [{"default": "unitytechnologies/datasetinsights:latest",
"name": "docker", "optional": true, "type": "String"}, {"default": "https://storage.googleapis.com/datasetinsights/data/groceries/v3.zip",
Expand Down Expand Up @@ -75,6 +75,12 @@ spec:
'--checkpoint-file={{inputs.parameters.checkpoint_file}}']
command: [datasetinsights, train]
env:
- name: MLFLOW_HOST_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_HOST_ID}
- name: MLFLOW_CLIENT_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_CLIENT_ID}
- {name: GOOGLE_APPLICATION_CREDENTIALS, value: /secret/gcp-credentials/user-gcp-sa.json}
- {name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE, value: /secret/gcp-credentials/user-gcp-sa.json}
image: '{{inputs.parameters.docker}}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: train-on-synthetic-dataset-unity-simulation-
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2020-11-04T17:16:55.152591',
annotations: {pipelines.kubeflow.org/kfp_sdk_version: 1.0.1, pipelines.kubeflow.org/pipeline_compilation_time: '2021-01-04T15:53:38.437907',
pipelines.kubeflow.org/pipeline_spec: '{"description": "Train on synthetic dataset
Unity Simulation", "inputs": [{"default": "unitytechnologies/datasetinsights:latest",
"name": "docker", "optional": true, "type": "String"}, {"default": "<unity-project-id>",
Expand Down Expand Up @@ -79,6 +79,12 @@ spec:
command: [python, -m, torch.distributed.launch, --nproc_per_node=8, --use_env,
datasetinsights, train]
env:
- name: MLFLOW_HOST_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_HOST_ID}
- name: MLFLOW_CLIENT_ID
valueFrom:
secretKeyRef: {name: dev-mlflow-secret, key: MLFLOW_CLIENT_ID}
- {name: GOOGLE_APPLICATION_CREDENTIALS, value: /secret/gcp-credentials/user-gcp-sa.json}
- {name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE, value: /secret/gcp-credentials/user-gcp-sa.json}
image: '{{inputs.parameters.docker}}'
Expand Down
15 changes: 15 additions & 0 deletions kubeflow/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@
KFP_UI_METADATA_FILENAME = "kfp_ui_metadata.json"
KFP_METRICS_FILENAME = "kfp_metrics.json"

def _mlflow_secret_env(key):
    """Build a Kubernetes container env-var spec that reads *key* from the
    ``dev-mlflow-secret`` secret (env-var name and secret key are identical).
    """
    return {
        "name": key,
        "valueFrom": {
            "secretKeyRef": {"name": "dev-mlflow-secret", "key": key}
        },
    }


# MLFlow tracking-server credentials injected into pipeline containers.
mlflow_host_env = _mlflow_secret_env("MLFLOW_HOST_ID")
mlflow_client_env = _mlflow_secret_env("MLFLOW_CLIENT_ID")


def volume_op(*, volume_size):
""" Create Kubernetes persistant volume to store data.
Expand Down Expand Up @@ -134,6 +147,7 @@ def train_op(
command=command,
arguments=arguments,
pvolumes={DATA_PATH: volume},
container_kwargs={"env": [mlflow_host_env, mlflow_client_env]},
file_outputs={
"mlpipeline-ui-metadata": os.path.join(
KFP_LOG_DIR, KFP_UI_METADATA_FILENAME
Expand Down Expand Up @@ -198,6 +212,7 @@ def evaluate_op(
f"--kfp-ui-metadata-filename={KFP_UI_METADATA_FILENAME}",
f"--kfp-metrics-filename={KFP_METRICS_FILENAME}",
],
container_kwargs={"env": [mlflow_host_env, mlflow_client_env]},
file_outputs={
"mlpipeline-metrics": os.path.join(
KFP_LOG_DIR, KFP_METRICS_FILENAME
Expand Down
6 changes: 3 additions & 3 deletions tests/configs/faster_rcnn_groceries_real_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ val_interval: 1
tracker:
mlflow:
experiment: datasetinsights
run:
client_id:
host:
run: test
client_id: test
host: test
42 changes: 23 additions & 19 deletions tests/test_mlflow_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ def test_refresh_token(mock_id_token):
mock_id_token.assert_called_once()


@patch("datasetinsights.io.tracker.mlflow.MLFlowTracker")
@patch("datasetinsights.io.tracker.factory.MLFlowTracker")
def test_get_mltracker_instance(mock_tracker, config):
mlflow_config = config["tracker"].get(TrackerFactory.MLFLOW_TRACKER)
tf = TrackerFactory()
instance1 = tf._mlflow_tracker_instance(mlflow_config)
instance2 = tf._mlflow_tracker_instance(mlflow_config)
instance1 = tf._mlflow_tracker_instance(config)
instance2 = tf._mlflow_tracker_instance(config)
assert instance1 == instance2


Expand All @@ -59,22 +58,20 @@ def test_get_nulltracker_instance(mock_tracker):
def test_factory_create_mltracker(mock_get_tracker, config):
mock_mlflow = MagicMock()
mock_get_tracker.return_value = mock_mlflow
mlflow_config = config["tracker"].get(TrackerFactory.MLFLOW_TRACKER)
config.tracker.mlflow.client_id = CLIENT_ID
config.tracker.mlflow.host = HOST_ID
config.tracker.mlflow.experiment = EXP_NAME
TrackerFactory.create(config, TrackerFactory.MLFLOW_TRACKER)
mock_get_tracker.assert_called_with(mlflow_config)
mock_get_tracker.assert_called_with(config)
mock_mlflow.get_mlflow.assert_called_once()


@patch(
"datasetinsights.io.tracker.factory.TrackerFactory."
"_mlflow_tracker_instance"
)
@patch("datasetinsights.io.tracker.factory.TrackerFactory._null_tracker")
def test_factory_create_nulltracker(mock_get_tracker, config):
config.tracker.mlflow.client_id = None
config.tracker.mlflow.host = None
config.tracker.mlflow.experiment = None
TrackerFactory.create(config, TrackerFactory.MLFLOW_TRACKER)
mock_get_tracker.assert_called_once()
def test_factory_create_nulltracker(mock_get_tracker, mock_instance, config):
with pytest.raises(Exception):
TrackerFactory.create(config, TrackerFactory.MLFLOW_TRACKER)
mock_get_tracker.assert_called_once()


@patch("datasetinsights.io.tracker.factory.NullTracker._stdout_handler")
Expand All @@ -91,8 +88,7 @@ def test_mLflow_tracker(mock_refresh, mock_mlflow, mock_thread_start, config):
config.tracker.mlflow.client_id = CLIENT_ID
config.tracker.mlflow.host = HOST_ID
config.tracker.mlflow.experiment = EXP_NAME
mlflow_config = config["tracker"].get(TrackerFactory.MLFLOW_TRACKER)
MLFlowTracker(mlflow_config)
MLFlowTracker(config)
mock_thread_start.assert_called_once()
mock_refresh.assert_called_once()
mock_mlflow.set_tracking_uri.assert_called_with(HOST_ID)
Expand All @@ -113,8 +109,16 @@ def test_mLflow_tracker_run(mock_refresh, mock_mlflow, mock_thread_run, config):
config.tracker.mlflow.client_id = CLIENT_ID
config.tracker.mlflow.host = HOST_ID
config.tracker.mlflow.experiment = EXP_NAME
mlflow_config = config["tracker"].get(TrackerFactory.MLFLOW_TRACKER)
MLFlowTracker(mlflow_config)
MLFlowTracker(config)
mock_thread_run.assert_called_once()
mock_refresh.assert_called_once()
mock_mlflow.set_tracking_uri.assert_called_with(HOST_ID)


def test_get_variables(config):
    """_get_variables returns the values supplied by the yaml config."""
    mlflow_cfg = config.tracker.mlflow
    mlflow_cfg.client_id = CLIENT_ID
    mlflow_cfg.host = HOST_ID
    mlflow_cfg.experiment = EXP_NAME
    mlflow_cfg.run = RUN_NAME
    assert MLFlowTracker._get_variables(config) == (
        CLIENT_ID,
        HOST_ID,
        RUN_NAME,
        EXP_NAME,
    )