Skip to content

Commit

Permalink
Fix conda dependencies and serialization issues
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo de Leon <eddeleon@microsoft.com>
  • Loading branch information
Eduardo de Leon committed Apr 18, 2020
1 parent e2e4e4c commit c0e9122
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions python/interpret-core/interpret/glassbox/mlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,25 @@ def load_model(*args, **kwargs):
return mlflow.pyfunc.load_model(*args, **kwargs)


def _sanitize_explanation_data(data): # TODO Explanations should have a to_json()
if isinstance(data, dict):
for key, val in data.items():
data[key] = _sanitize_explanation_data(data[key])
return data

elif isinstance(data, list):
return [_sanitize_explanation_data[x] for x in data]
else:
# numpy type conversion to python https://stackoverflow.com/questions/9452775 primitive
return data.item() if hasattr(data, "item") else data


def _load_pyfunc(path):
import cloudpickle as pickle
with open(os.path.join(path, "model.pkl"), "rb") as f:
return pickle.load(f)


def _save_model(model, output_path):
import cloudpickle as pickle
if not os.path.exists(output_path):
Expand All @@ -25,7 +39,15 @@ def _save_model(model, output_path):
pickle.dump(model, stream)
try:
with open(os.path.join(output_path, "global_explanation.json"), "w") as stream:
json.dump(model.explain_global().data(-1)["mli"], stream)
data = model.explain_global().data(-1)["mli"]
if isinstance(data, list):
data = data[0]
if "global" not in data["explanation_type"]:
raise Exception("Invalid explanation, not global")
for key in data:
if isinstance(data[key], list):
data[key] = [float(x) for x in data[key]]
json.dump(data, stream)
except ValueError as e:
raise Exception("Unsupported glassbox model type {}. Failed with error {}.".format(type(model), e))

Expand All @@ -34,14 +56,18 @@ def log_model(path, model):
import mlflow.pyfunc
except ImportError as e:
raise Exception("Could not log_model to mlflow. Missing mlflow dependency, pip install mlflow to resolve the error: {}.".format(e))
import cloudpickle as pickle

with TemporaryDirectory() as tempdir:
_save_model(model, tempdir)

conda_env = {"name": "mlflow-env",
"channels": ["defaults"],
"dependencies": ["interpret=".format(interpret.version.__version__),
"cloudpickle==0.5.8"
"dependencies": ["pip",
{"pip": [
"interpret=={}".format(interpret.version.__version__),
"cloudpickle=={}".format(pickle.__version__)]
}
]
}
conda_path = os.path.join(tempdir, "conda.yaml") # TODO Open issue and bug fix for dict support
Expand Down

0 comments on commit c0e9122

Please sign in to comment.