diff --git a/mlflow/getml/autologging.py b/mlflow/getml/autologging.py
index 27ed80ecfa8cf..1216b7aaa93f0 100644
--- a/mlflow/getml/autologging.py
+++ b/mlflow/getml/autologging.py
@@ -1,12 +1,19 @@
 import json
+import logging
 import threading
 from dataclasses import dataclass, field
 from typing import Any
 
 import mlflow
+from mlflow.data.pandas_dataset import PandasDataset
+from mlflow.entities.dataset_input import DatasetInput
+from mlflow.entities.input_tag import InputTag
 from mlflow.utils.autologging_utils import safe_patch
 from mlflow.utils.autologging_utils.client import MlflowAutologgingQueueingClient
+from mlflow.utils.mlflow_tags import MLFLOW_DATASET_CONTEXT
+
+_logger = logging.getLogger(__name__)
 
 
 @dataclass
 class LogInfo:
@@ -148,24 +155,14 @@ def _extract_engine_system_metrics(
             step += 1
             stop_event.wait(1)
 
-    def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs):
+    def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs) -> getml.Pipeline:
         autologging_client = MlflowAutologgingQueueingClient()
         assert (active_run := mlflow.active_run())
         run_id = active_run.info.run_id
 
-        pipeline_log_info = _extract_pipeline_informations(self)
-        # with open("my_dict.json", "w") as f:
-        #     json.dump(pipeline_log_info.params, f)
-        # mlflow.log_artifact("my_dict.json")
-        # mlflow.log_dict(pipeline_log_info.params, 'params.json')
-        autologging_client.log_params(
-            run_id=run_id,
-            params=pipeline_log_info.params,
-        )
-        if tags := pipeline_log_info.tags:
-            autologging_client.set_tags(run_id=run_id, tags=tags)
-
-        engine_metrics_to_be_tracked = _collect_available_engine_metrics()
+        engine_metrics_to_be_tracked = _log_pretraining_metadata(
+            autologging_client, self, run_id, *args
+        )
         if engine_metrics_to_be_tracked:
             stop_event = threading.Event()
             metrics_thread = threading.Thread(
@@ -193,7 +190,9 @@ def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs):
             autologging_client.flush(synchronous=True)
         return fit_output
 
-    def patched_score_method(original, self: getml.Pipeline, *args, **kwargs):
+    def patched_score_method(
+        original, self: getml.Pipeline, *args, **kwargs
+    ) -> getml.pipeline.Scores:
         target = self.data_model.population.roles.target[0]
         pop_df = args[0].population.to_pandas()
 
@@ -208,9 +207,54 @@ def patched_score_method(original, self: getml.Pipeline, *args, **kwargs):
             model_type=["regressor" if self.is_regression else "classifier"][0],
             evaluators=["default"],
         )
-
         return original(self, *args, **kwargs)
 
+    def _log_pretraining_metadata(
+        autologging_client: MlflowAutologgingQueueingClient,
+        self: getml.Pipeline,
+        run_id: str,
+        *args,
+    ) -> dict:
+        pipeline_log_info = _extract_pipeline_informations(self)
+        autologging_client.log_params(
+            run_id=run_id,
+            params=pipeline_log_info.params,
+        )
+        if tags := pipeline_log_info.tags:
+            autologging_client.set_tags(run_id=run_id, tags=tags)
+
+        engine_metrics_to_be_tracked = _collect_available_engine_metrics()
+
+        # `log_datasets` is captured from the enclosing autolog() scope.
+        if log_datasets:
+            try:
+                datasets = []
+                population_dataset: PandasDataset = mlflow.data.from_pandas(
+                    args[0].population.to_pandas(), name=args[0].population.base.name
+                )
+                tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value="Population")]
+                datasets.append(
+                    DatasetInput(dataset=population_dataset._to_mlflow_entity(), tags=tags)
+                )
+
+                for name, peripheral in args[0].peripheral.items():
+                    tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value="Peripheral")]
+                    peripheral_dataset: PandasDataset = mlflow.data.from_pandas(
+                        peripheral.to_pandas(), name=name
+                    )
+                    datasets.append(
+                        DatasetInput(dataset=peripheral_dataset._to_mlflow_entity(), tags=tags)
+                    )
+
+                autologging_client.log_inputs(run_id=run_id, datasets=datasets)
+            except Exception as e:
+                _logger.warning(
+                    "Failed to log training dataset information to MLflow Tracking. Reason: %s",
+                    e,
+                )
+
+        return engine_metrics_to_be_tracked
+
     _patch_pipeline_method(
         flavor_name=flavor_name,
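
For reviewers who want to see what the new `_log_pretraining_metadata` path records without a running getML engine, here is a minimal standalone sketch of the same dataset-input mechanism through MLflow's public API. `mlflow.log_input` is the public counterpart of the queued `autologging_client.log_inputs` call in the patch; the DataFrame, dataset name, and context value are illustrative only.

```python
# Minimal sketch of the dataset-input mechanism used by the patch, with a
# plain pandas DataFrame standing in for getML population/peripheral tables.
import mlflow
import pandas as pd

df = pd.DataFrame({"user_id": [1, 2], "churn": [0, 1]})

with mlflow.start_run() as run:
    # mlflow.data.from_pandas builds the same PandasDataset the patch creates
    # from each getML table.
    dataset = mlflow.data.from_pandas(df, name="population")
    # log_input attaches the dataset to the run with an MLFLOW_DATASET_CONTEXT
    # tag, mirroring the InputTag/DatasetInput plumbing above.
    mlflow.log_input(dataset, context="Population")

# The logged inputs are then visible on the run entity (and in the UI).
inputs = mlflow.MlflowClient().get_run(run.info.run_id).inputs
print([d.dataset.name for d in inputs.dataset_inputs])  # ['population']
```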