Skip to content

Commit

Permalink
FIX #158 - Autopickle input parameters of inference pipeline in PipelineML
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Feb 20, 2021
1 parent ef3b0e2 commit 5d0fb2d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 25 deletions.
38 changes: 22 additions & 16 deletions kedro_mlflow/framework/hooks/pipeline_hook.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Union

import mlflow
Expand Down Expand Up @@ -139,24 +140,29 @@ def after_pipeline_run(
"""

if isinstance(pipeline, PipelineML):
pipeline_catalog = pipeline._extract_pipeline_catalog(catalog)
artifacts = pipeline.extract_pipeline_artifacts(pipeline_catalog)
with TemporaryDirectory() as tmp_dir:
pipeline_catalog = pipeline._extract_pipeline_catalog(catalog)
artifacts = pipeline.extract_pipeline_artifacts(
pipeline_catalog, temp_folder=tmp_dir
)

if pipeline.model_signature == "auto":
input_data = pipeline_catalog.load(pipeline.input_name)
model_signature = infer_signature(model_input=input_data)
else:
model_signature = pipeline.model_signature
if pipeline.model_signature == "auto":
input_data = pipeline_catalog.load(pipeline.input_name)
model_signature = infer_signature(model_input=input_data)
else:
model_signature = pipeline.model_signature

mlflow.pyfunc.log_model(
artifact_path=pipeline.model_name,
python_model=KedroPipelineModel(
pipeline_ml=pipeline, catalog=pipeline_catalog, **pipeline.kwargs
),
artifacts=artifacts,
conda_env=_format_conda_env(pipeline.conda_env),
signature=model_signature,
)
mlflow.pyfunc.log_model(
artifact_path=pipeline.model_name,
python_model=KedroPipelineModel(
pipeline_ml=pipeline,
catalog=pipeline_catalog,
**pipeline.kwargs,
),
artifacts=artifacts,
conda_env=_format_conda_env(pipeline.conda_env),
signature=model_signature,
)
# Close the mlflow active run at the end of the pipeline to avoid interactions with further runs
mlflow.end_run()

Expand Down
39 changes: 30 additions & 9 deletions kedro_mlflow/pipeline/pipeline_ml.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, Iterable, Optional, Union

from kedro.extras.datasets.pickle import PickleDataSet
from kedro.io import DataCatalog, MemoryDataSet
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node
Expand Down Expand Up @@ -171,6 +173,7 @@ def _check_inference(self, inference: Pipeline) -> None:
def _extract_pipeline_catalog(self, catalog: DataCatalog) -> DataCatalog:

# check that the pipeline is consistent in case its attributes have been
# modified manually
self._check_consistency()

sub_catalog = DataCatalog()
Expand All @@ -183,7 +186,9 @@ def _extract_pipeline_catalog(self, catalog: DataCatalog) -> DataCatalog:
else:
try:
data_set = catalog._data_sets[data_set_name]
if isinstance(data_set, MemoryDataSet):
if isinstance(
data_set, MemoryDataSet
) and not data_set_name.startswith("params:"):
raise KedroMlflowPipelineMLDatasetsError(
"""
The datasets of the training pipeline must be persisted locally
Expand All @@ -210,15 +215,31 @@ def _extract_pipeline_catalog(self, catalog: DataCatalog) -> DataCatalog:

return sub_catalog

def extract_pipeline_artifacts(
    self, catalog: DataCatalog, temp_folder: Union[str, Path, TemporaryDirectory]
):
    """Build the ``{name: file URI}`` mapping of artifacts mlflow must log.

    Every dataset of the inference pipeline (except the pipeline input) is
    exposed through its local filepath. ``params:*`` entries are
    MemoryDataSets and have no filepath, so they are first pickled inside
    ``temp_folder`` to make them loggable as file artifacts.

    Args:
        catalog: the DataCatalog of the whole kedro project.
        temp_folder: an existing (temporary) directory where parameters are
            pickled. Accepts either a path (``str`` / ``Path``) — e.g. the
            string yielded by ``with TemporaryDirectory() as tmp_dir`` — or
            a ``TemporaryDirectory`` instance.

    Returns:
        Dict[str, str]: dataset name -> file URI of the persisted artifact.
    """
    pipeline_catalog = self._extract_pipeline_catalog(catalog)

    # BUGFIX: the caller passes the *string* yielded by
    # ``with TemporaryDirectory() as tmp_dir``, so unconditionally reading
    # ``temp_folder.name`` raised AttributeError. Normalize both forms:
    # a TemporaryDirectory object exposes its path via ``.name``; anything
    # else is treated as a path directly.
    if isinstance(temp_folder, TemporaryDirectory):
        folder = Path(temp_folder.name)
    else:
        folder = Path(temp_folder)

    artifacts = {}
    for name, dataset in pipeline_catalog._data_sets.items():
        if name == self.input_name:
            # the pipeline input is provided at inference time, never logged
            continue
        if name.startswith("params:"):
            # parameters live in MemoryDataSets: persist them locally so
            # mlflow can access them. The "params:" prefix (7 chars) is
            # stripped from the artifact filename.
            absolute_param_path = folder / f"params_{name[7:]}.pkl"
            persisted_dataset = PickleDataSet(
                filepath=absolute_param_path.as_posix()
            )
            persisted_dataset.save(dataset.load())
            artifact_path = absolute_param_path.as_uri()
        else:
            # cannot be a MemoryDataSet here: _extract_pipeline_catalog
            # rejects non-parameter MemoryDataSets, so a filepath exists.
            # resolve() first — directly converting a PurePosixPath on
            # Windows yields a path considered relative.
            artifact_path = (
                Path(dataset._filepath.as_posix()).resolve().as_uri()
            )

        artifacts[name] = artifact_path

    return artifacts

def _check_consistency(self) -> None:
Expand Down

0 comments on commit 5d0fb2d

Please sign in to comment.