Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added Plotly materialization example #544

Merged
merged 2 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions examples/plotly/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Plotly materializer extension

By importing `hamilton.plugins.plotly_extensions`, you can register two additional materializers for Plotly figures. The `to.plotly()` creates static image files ([docs](https://plotly.com/python/static-image-export/)) and the `to.html()` outputs interactive HTML files ([docs](https://plotly.com/python/interactive-html-export/)).

## How to
You need to install `plotly` (low-level API) to annotate your function with `plotly.graph_objects.Figure` even if you are using `plotly_express` (high-level API) to generate figures.
```python
# 1. define a function returning a `plotly.graph_objects.Figure` in a python module.
def confusion_matrix(...) -> plotly.graph_objects.Figure:
return plotly.express.imshow(...)

# 2. import the module and create the Hamilton driver
dr = (
driver.Builder()
.with_config({...})
.with_modules(MODULE_NAME)
.build()
)

# 3. define the materializers
from hamilton.io.materialization import to

materializers = [
to.plotly(
dependencies=["confusion_matrix_figure"],
id="confusion_matrix_png",
path="./static.png",
),
to.html(
dependencies=["confusion_matrix_figure"],
id="confusion_matrix_html",
path="./interactive.html",
),
]

# 4. materialize figures
dr.materialize(*materializers)
```

## Notes
Here are a few things to consider when using the plotly materializers:
- Any plotly figure is a subclass of `plotly.graph_objects.Figure`, including anything from `plotly.express`, `plotly.graph_objects`, `plotly.figure_factory`.
- `to.plotly()` supports all filetypes of the plotly rendering engine (PNG, SVG, etc.). The output type will be automatically inferred from the `path` value passed to the materializer. Or, you can specify the file type explicitly as `kwarg`.
- `to.html()` outputs an interactive HTML file. These files will be at least ~3Mb each since they include they bundle the plotly JS library. You can reduce that by using the `include_plotlyjs` `kwarg`. Read more about it in the documentation at `https://plotly.com/python/interactive-html-export/`
- `to.html()` will include the data that's being visually displayed, including what's part of the tooltips, which can grow filesize quickly.
14 changes: 14 additions & 0 deletions examples/plotly/interactive.html

Large diffs are not rendered by default.

135 changes: 135 additions & 0 deletions examples/plotly/model_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from typing import Dict

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn import base, datasets, linear_model, metrics, svm, utils
from sklearn.model_selection import train_test_split

from hamilton import function_modifiers


@function_modifiers.config.when(data_loader="iris")
def data__iris() -> utils.Bunch:
return datasets.load_digits()


@function_modifiers.config.when(data_loader="digits")
def data__digits() -> utils.Bunch:
return datasets.load_digits()


def target(data: utils.Bunch) -> np.ndarray:
return data.target


def target_names(data: utils.Bunch) -> np.ndarray:
return data.target_names


def feature_matrix(data: utils.Bunch) -> np.ndarray:
return data.data


@function_modifiers.config.when(clf="svm")
def prefit_clf__svm(gamma: float = 0.001) -> base.ClassifierMixin:
"""Returns an unfitted SVM classifier object.

:param gamma: ...
:return:
"""
return svm.SVC(gamma=gamma)


@function_modifiers.config.when(clf="logistic")
def prefit_clf__logreg(penalty: str) -> base.ClassifierMixin:
"""Returns an unfitted Logistic Regression classifier object.

:param penalty:
:return:
"""
return linear_model.LogisticRegression(penalty)


@function_modifiers.extract_fields(
{"X_train": np.ndarray, "X_test": np.ndarray, "y_train": np.ndarray, "y_test": np.ndarray}
)
def train_test_split_func(
feature_matrix: np.ndarray,
target: np.ndarray,
test_size_fraction: float,
shuffle_train_test_split: bool,
) -> Dict[str, np.ndarray]:
"""Function that creates the training & test splits.

It this then extracted out into constituent components and used downstream.

:param feature_matrix:
:param target:
:param test_size_fraction:
:param shuffle_train_test_split:
:return:
"""
X_train, X_test, y_train, y_test = train_test_split(
feature_matrix, target, test_size=test_size_fraction, shuffle=shuffle_train_test_split
)
return {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test}


def y_test_with_labels(y_test: np.ndarray, target_names: np.ndarray) -> np.ndarray:
"""Adds labels to the target output."""
return np.array([target_names[idx] for idx in y_test])


def fit_clf(
prefit_clf: base.ClassifierMixin, X_train: np.ndarray, y_train: np.ndarray
) -> base.ClassifierMixin:
"""Calls fit on the classifier object; it mutates it."""
prefit_clf.fit(X_train, y_train)
return prefit_clf


def predicted_output(fit_clf: base.ClassifierMixin, X_test: np.ndarray) -> np.ndarray:
"""Exercised the fit classifier to perform a prediction."""
return fit_clf.predict(X_test)


def predicted_output_with_labels(
predicted_output: np.ndarray, target_names: np.ndarray
) -> np.ndarray:
"""Replaces the predictions with the desired labels."""
return np.array([target_names[idx] for idx in predicted_output])


def classification_report(
predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray
) -> str:
"""Returns a classification report."""
return metrics.classification_report(y_test_with_labels, predicted_output_with_labels)


def confusion_matrix(
predicted_output_with_labels: np.ndarray, y_test_with_labels: np.ndarray
) -> np.ndarray:
"""Returns a confusion matrix report."""
return metrics.confusion_matrix(y_test_with_labels, predicted_output_with_labels)


def confusion_matrix_figure(confusion_matrix: np.ndarray, target_names: np.ndarray) -> go.Figure:
"""Create a plotly interactive heatmap of the confusion matrix"""
class_indices = np.arange(len(target_names))
return px.imshow(
confusion_matrix,
x=class_indices,
y=class_indices,
labels=dict(
x="Predicted labels",
y="True labels",
color="Count",
),
)


def model_parameters(fit_clf: base.ClassifierMixin) -> dict:
"""Returns a dictionary of model parameters."""
return fit_clf.get_params()
150 changes: 150 additions & 0 deletions examples/plotly/notebook.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions examples/plotly/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
plotly
sf-hamilton[visualization]
Binary file added examples/plotly/static.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.