FIX #102 - MlflowMetricsDataSet now logs in the specified run_id when prefix is not provided
Galileo-Galilei committed Oct 25, 2020
1 parent e12a74c commit 4548dad
Showing 4 changed files with 103 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@
- kedro_mlflow now works fine with kedro jupyter notebook independently of the working directory (#64)
- You can use global variables in `mlflow.yml` which is now properly parsed if you use a `TemplatedConfigLoader` (#72)
- `mlflow init` now gets the conf path from `context.CONF_ROOT` instead of a hardcoded conf folder. This makes the package robust to Kedro changes.
- `MlflowMetricsDataSet` now saves to the specified `run_id` instead of the currently active run when the prefix is not specified (#102)

### Changed

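To illustrate the #102 entry above, here is a minimal usage sketch of the fixed behaviour: an `MlflowMetricsDataSet` created with an explicit `run_id` but no prefix now logs to that run rather than to whichever run is active at save time. The import path and the unprefixed metric key name are assumptions for the sketch, not part of this diff.

import mlflow
from mlflow.tracking import MlflowClient
from kedro_mlflow.io import MlflowMetricsDataSet  # import path assumed

# create a run up front and remember its id
with mlflow.start_run():
    existing_run_id = mlflow.active_run().info.run_id

# no prefix, only a run_id: before this fix the metric went to the active run
dataset = MlflowMetricsDataSet(run_id=existing_run_id)

with mlflow.start_run():  # a different run is active while saving
    dataset.save({"metric_key": {"value": 1.1, "step": 0}})

# the metric is expected to show up in the run passed explicitly, under the
# key "metric_key" since no prefix was given (assumed naming)
print(MlflowClient().get_run(existing_run_id).data.metrics)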
6 changes: 3 additions & 3 deletions kedro_mlflow/framework/hooks/pipeline_hook.py
@@ -38,7 +38,8 @@ def after_catalog_created(
catalog._data_sets[name] = MlflowMetricsDataSet(
run_id=dataset._run_id, prefix=name
)
- catalog._data_sets[name] = MlflowMetricsDataSet(prefix=name)
+ else:
+ catalog._data_sets[name] = MlflowMetricsDataSet(prefix=name)

@hook_impl
def before_pipeline_run(
@@ -74,8 +75,7 @@ def before_pipeline_run(

mlflow_conf = get_mlflow_config(self.context)
mlflow.set_tracking_uri(mlflow_conf.mlflow_tracking_uri)
# TODO : if the pipeline fails, we need to be able to end stop the mlflow run
# cannot figure out how to do this within hooks

run_name = (
mlflow_conf.run_opts["name"]
if mlflow_conf.run_opts["name"] is not None
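The hunk above only shows the tail of the conditional that rewires metrics datasets in `after_catalog_created`. Below is a self-contained sketch of how the corrected logic plausibly reads in full; the surrounding `isinstance` and `_prefix` guards are inferred from the visible fragments and the tests further down, not shown in this diff.

from kedro.io import DataCatalog
from kedro_mlflow.io import MlflowMetricsDataSet  # import path assumed


def _prefix_metrics_datasets(catalog: DataCatalog) -> None:
    # mirrors what the hook appears to do after the fix
    for name, dataset in catalog._data_sets.items():
        if isinstance(dataset, MlflowMetricsDataSet) and dataset._prefix is None:
            if dataset._run_id is not None:
                # keep the user-supplied run_id and use the dataset name as prefix
                catalog._data_sets[name] = MlflowMetricsDataSet(
                    run_id=dataset._run_id, prefix=name
                )
            else:
                # before the fix, this assignment ran unconditionally and
                # silently discarded the run_id set in the branch above
                catalog._data_sets[name] = MlflowMetricsDataSet(prefix=name)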
8 changes: 4 additions & 4 deletions tests/conftest.py
@@ -18,21 +18,21 @@ def _get_local_logging_config():
"formatters": {
"simple": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
},
"root": {"level": "INFO", "handlers": ["console"]},
"root": {"level": "ERROR", "handlers": ["console"]},
"loggers": {
"kedro": {"level": "INFO", "handlers": ["console"], "propagate": False}
"kedro": {"level": "ERROR", "handlers": ["console"], "propagate": False}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": "INFO",
"level": "ERROR",
"formatter": "simple",
"stream": "ext://sys.stdout",
}
},
"info_file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "INFO",
"level": "ERROR",
"formatter": "simple",
"filename": "logs/info.log",
},
102 changes: 95 additions & 7 deletions tests/framework/hooks/test_pipeline_hook.py
@@ -123,8 +123,8 @@ def preprocess_fun(data):
def train_fun(data, param):
return 2

- def metric_fun():
- return {"metric": {"value": 1.1, "step": 0}}
+ def metric_fun(data, model):
+ return {"metric_key": {"value": 1.1, "step": 0}}

def predict_fun(model, data):
return data * model
@@ -143,8 +143,18 @@ def predict_fun(model, data):
outputs="model",
tags=["training"],
),
- node(func=metric_fun, inputs=None, outputs="metrics",),
- node(func=metric_fun, inputs=None, outputs="another_metrics",),
+ node(
+ func=metric_fun,
+ inputs=["model", "data"],
+ outputs="my_metrics",
+ tags=["training"],
+ ),
+ node(
+ func=metric_fun,
+ inputs=["model", "data"],
+ outputs="another_metrics",
+ tags=["training"],
+ ),
node(
func=predict_fun,
inputs=["model", "data"],
@@ -177,7 +187,7 @@ def dummy_catalog(tmp_path):
"params:unused_param": MemoryDataSet("blah"),
"data": MemoryDataSet(),
"model": PickleDataSet((tmp_path / "model.csv").as_posix()),
"metrics": MlflowMetricsDataSet(),
"my_metrics": MlflowMetricsDataSet(),
"another_metrics": MlflowMetricsDataSet(prefix="foo"),
}
)
@@ -258,7 +268,7 @@ def test_mlflow_pipeline_hook_with_different_pipeline_types(
pipeline_hook.before_pipeline_run(
run_params=dummy_run_params, pipeline=pipeline_to_run, catalog=dummy_catalog
)
- runner.run(pipeline_to_run, dummy_catalog, dummy_run_params["run_id"])
+ runner.run(pipeline_to_run, dummy_catalog)
run_id = mlflow.active_run().info.run_id
pipeline_hook.after_pipeline_run(
run_params=dummy_run_params, pipeline=pipeline_to_run, catalog=dummy_catalog
@@ -281,10 +291,88 @@ def test_mlflow_pipeline_hook_with_different_pipeline_types(
assert nb_artifacts == 0
# check that metrics datasets without an explicit prefix use their dataset name as prefix
- assert dummy_catalog._data_sets["metrics"]._prefix == "metrics"
+ assert dummy_catalog._data_sets["my_metrics"]._prefix == "my_metrics"
assert dummy_catalog._data_sets["another_metrics"]._prefix == "foo"


def test_mlflow_pipeline_hook_metrics_with_run_id(
mocker,
monkeypatch,
tmp_path,
config_dir,
env_from_dict,
dummy_pipeline_ml,
dummy_run_params,
dummy_mlflow_conf,
):
# config_dir and dummy_mlflow_conf are conftest fixtures that set up a dummy Kedro project configuration
mocker.patch("kedro_mlflow.utils._is_kedro_project", return_value=True)
monkeypatch.chdir(tmp_path)

context = load_context(tmp_path)
mlflow_conf = get_mlflow_config(context)
mlflow.set_tracking_uri(mlflow_conf.mlflow_tracking_uri)

with mlflow.start_run():
existing_run_id = mlflow.active_run().info.run_id

dummy_catalog_with_run_id = DataCatalog(
{
"raw_data": MemoryDataSet(1),
"params:unused_param": MemoryDataSet("blah"),
"data": MemoryDataSet(),
"model": PickleDataSet((tmp_path / "model.csv").as_posix()),
"my_metrics": MlflowMetricsDataSet(run_id=existing_run_id),
"another_metrics": MlflowMetricsDataSet(
run_id=existing_run_id, prefix="foo"
),
}
)

pipeline_hook = MlflowPipelineHook()

runner = SequentialRunner()
pipeline_hook.after_catalog_created(
catalog=dummy_catalog_with_run_id,
# `after_catalog_created` does not use any of the arguments below,
# so we pass empty values.
conf_catalog={},
conf_creds={},
feed_dict={},
save_version="",
load_versions="",
run_id=dummy_run_params["run_id"],
)
pipeline_hook.before_pipeline_run(
run_params=dummy_run_params,
pipeline=dummy_pipeline_ml,
catalog=dummy_catalog_with_run_id,
)
runner.run(dummy_pipeline_ml, dummy_catalog_with_run_id)

current_run_id = mlflow.active_run().info.run_id

pipeline_hook.after_pipeline_run(
run_params=dummy_run_params,
pipeline=dummy_pipeline_ml,
catalog=dummy_catalog_with_run_id,
)

mlflow_client = MlflowClient(mlflow_conf.mlflow_tracking_uri)
all_runs_id = set(
[run.run_id for run in mlflow_client.list_run_infos(experiment_id="0")]
)

# the metrics are supposed to have been logged inside existing_run_id
run_data = mlflow_client.get_run(existing_run_id).data

# check that both the pre-existing run and the run created by the hook exist,
# and that the metrics were logged to the pre-existing run under their prefixes
assert all_runs_id == {current_run_id, existing_run_id}
assert run_data.metrics["my_metrics.metric_key"] == 1.1
assert run_data.metrics["foo.metric_key"] == 1.1


def test_generate_kedro_commands():
# TODO : add a better test because the formatting of record_data is subject to change
# We could check that the command is recorded and then rerun properly
