Skip to content

Commit

Permalink
FIX #100 - Make PipelineML._extract_pipeline_catalog private
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Oct 25, 2020
1 parent e12a74c commit 31604f0
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 18 deletions.
22 changes: 11 additions & 11 deletions docs/source/05_python_objects/03_Pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ from kedro_mlflow.mlflow import KedroPipelineModel
catalog = load_context(".").io

# artifacts are all the inputs of the inference pipelines that are persisted in the catalog
-pipeline_catalog = pipeline_training.extract_pipeline_catalog(catalog)
-artifacts = {name: Path(dataset._filepath).resolve().as_uri()
-for name, dataset in pipeline_catalog._data_sets.items()
-if name != pipeline_training.model_input_name}
-
-
-mlflow.pyfunc.log_model(artifact_path="model",
-python_model=KedroPipelineModel(pipeline_ml=pipeline_training,
-catalog=pipeline_catalog),
-artifacts=artifacts,
-conda_env={"python": "3.7.0"})
+artifacts = pipeline_training.extract_pipeline_artifacts(catalog)
+
+mlflow.pyfunc.log_model(
+artifact_path="model",
+python_model=KedroPipelineModel(
+pipeline_ml=pipeline_training,
+catalog=catalog
+),
+artifacts=artifacts,
+conda_env={"python": "3.7.0"}
+)
```
2 changes: 1 addition & 1 deletion kedro_mlflow/framework/hooks/pipeline_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def after_pipeline_run(
"""

if isinstance(pipeline, PipelineML):
-pipeline_catalog = pipeline.extract_pipeline_catalog(catalog)
+pipeline_catalog = pipeline._extract_pipeline_catalog(catalog)
artifacts = pipeline.extract_pipeline_artifacts(pipeline_catalog)
mlflow.pyfunc.log_model(
artifact_path=pipeline.model_name,
Expand Down
2 changes: 1 addition & 1 deletion kedro_mlflow/mlflow/kedro_pipeline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class KedroPipelineModel(PythonModel):
def __init__(self, pipeline_ml: PipelineML, catalog: DataCatalog):

self.pipeline_ml = pipeline_ml
-self.initial_catalog = pipeline_ml.extract_pipeline_catalog(catalog)
+self.initial_catalog = pipeline_ml._extract_pipeline_catalog(catalog)
self.loaded_catalog = DataCatalog()
# we have the guarantee that there is only one output in inference
self.output_name = list(pipeline_ml.inference.outputs())[0]
Expand Down
4 changes: 2 additions & 2 deletions kedro_mlflow/pipeline/pipeline_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _check_inference(self, inference: Pipeline) -> None:
)
)

-def extract_pipeline_catalog(self, catalog: DataCatalog) -> DataCatalog:
+def _extract_pipeline_catalog(self, catalog: DataCatalog) -> DataCatalog:
sub_catalog = DataCatalog()
for data_set_name in self.inference.inputs():
if data_set_name == self.input_name:
Expand Down Expand Up @@ -182,7 +182,7 @@ def extract_pipeline_catalog(self, catalog: DataCatalog) -> DataCatalog:
return sub_catalog

def extract_pipeline_artifacts(self, catalog: DataCatalog):
-pipeline_catalog = self.extract_pipeline_catalog(catalog)
+pipeline_catalog = self._extract_pipeline_catalog(catalog)
artifacts = {
name: Path(dataset._filepath.as_posix())
.resolve()
Expand Down
6 changes: 3 additions & 3 deletions tests/pipeline/test_pipeline_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def test_filtering_generate_invalid_pipeline_ml(
],
)
def test_catalog_extraction(pipeline_ml_obj, catalog, result):
-filtered_catalog = pipeline_ml_obj.extract_pipeline_catalog(catalog)
+filtered_catalog = pipeline_ml_obj._extract_pipeline_catalog(catalog)
assert set(filtered_catalog.list()) == result


Expand All @@ -309,7 +309,7 @@ def test_catalog_extraction_missing_inference_input(pipeline_ml_with_tag):
KedroMlflowPipelineMLDatasetsError,
match="since it is an input for inference pipeline",
):
-pipeline_ml_with_tag.extract_pipeline_catalog(catalog)
+pipeline_ml_with_tag._extract_pipeline_catalog(catalog)


def test_catalog_extraction_unpersisted_inference_input(pipeline_ml_with_tag):
Expand All @@ -320,7 +320,7 @@ def test_catalog_extraction_unpersisted_inference_input(pipeline_ml_with_tag):
KedroMlflowPipelineMLDatasetsError,
match="The datasets of the training pipeline must be persisted locally",
):
-pipeline_ml_with_tag.extract_pipeline_catalog(catalog)
+pipeline_ml_with_tag._extract_pipeline_catalog(catalog)


def test_too_many_free_inputs():
Expand Down

0 comments on commit 31604f0

Please sign in to comment.