diff --git a/python/interpret-core/interpret/glassbox/mlflow/__init__.py b/python/interpret-core/interpret/glassbox/mlflow/__init__.py index 48d9a352a..86eb871eb 100644 --- a/python/interpret-core/interpret/glassbox/mlflow/__init__.py +++ b/python/interpret-core/interpret/glassbox/mlflow/__init__.py @@ -12,11 +12,25 @@ def load_model(*args, **kwargs): return mlflow.pyfunc.load_model(*args, **kwargs) +def _sanitize_explanation_data(data): # TODO Explanations should have a to_json() + if isinstance(data, dict): + for key, val in data.items(): + data[key] = _sanitize_explanation_data(data[key]) + return data + + elif isinstance(data, list): + return [_sanitize_explanation_data[x] for x in data] + else: + # numpy type conversion to python https://stackoverflow.com/questions/9452775 primitive + return data.item() if hasattr(data, "item") else data + + def _load_pyfunc(path): import cloudpickle as pickle with open(os.path.join(path, "model.pkl"), "rb") as f: return pickle.load(f) + def _save_model(model, output_path): import cloudpickle as pickle if not os.path.exists(output_path): @@ -25,7 +39,15 @@ def _save_model(model, output_path): pickle.dump(model, stream) try: with open(os.path.join(output_path, "global_explanation.json"), "w") as stream: - json.dump(model.explain_global().data(-1)["mli"], stream) + data = model.explain_global().data(-1)["mli"] + if isinstance(data, list): + data = data[0] + if "global" not in data["explanation_type"]: + raise Exception("Invalid explanation, not global") + for key in data: + if isinstance(data[key], list): + data[key] = [float(x) for x in data[key]] + json.dump(data, stream) except ValueError as e: raise Exception("Unsupported glassbox model type {}. Failed with error {}.".format(type(model), e)) @@ -34,14 +56,18 @@ def log_model(path, model): import mlflow.pyfunc except ImportError as e: raise Exception("Could not log_model to mlflow. Missing mlflow dependency, pip install mlflow to resolve the error: {}.".format(e)) + import cloudpickle as pickle with TemporaryDirectory() as tempdir: _save_model(model, tempdir) conda_env = {"name": "mlflow-env", "channels": ["defaults"], - "dependencies": ["interpret=".format(interpret.version.__version__), - "cloudpickle==0.5.8" + "dependencies": ["pip", + {"pip": [ + "interpret=={}".format(interpret.version.__version__), + "cloudpickle=={}".format(pickle.__version__)] + } ] } conda_path = os.path.join(tempdir, "conda.yaml") # TODO Open issue and bug fix for dict support