Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add show_graph to display a GraphViz plot for expressions #19365

Merged
merged 8 commits into from
Nov 2, 2024
6 changes: 5 additions & 1 deletion crates/polars-plan/src/dsl/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::*;
use crate::plans::conversion::is_regex_projection;
use crate::plans::ir::tree_format::TreeFmtVisitor;
use crate::plans::visitor::{AexprNode, TreeWalker};
use crate::prelude::tree_format::TreeFmtVisitorDisplay;

/// Specialized expressions for Categorical dtypes.
pub struct MetaNameSpace(pub(crate) Expr);
Expand Down Expand Up @@ -159,10 +160,13 @@ impl MetaNameSpace {

/// Get a hold of an implementor of the `Display` trait that will format
/// the expression as a tree
pub fn into_tree_formatter(self) -> PolarsResult<impl Display> {
pub fn into_tree_formatter(self, display_as_dot: bool) -> PolarsResult<impl Display> {
let mut arena = Default::default();
let node = to_aexpr(self.0, &mut arena)?;
let mut visitor = TreeFmtVisitor::default();
if display_as_dot {
visitor.display = TreeFmtVisitorDisplay::DisplayDot;
}

AexprNode::new(node).visit(&mut visitor, &arena)?;

Expand Down
66 changes: 60 additions & 6 deletions crates/polars-plan/src/plans/ir/tree_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,12 +382,20 @@ impl<'a> TreeFmtNode<'a> {
}
}

/// Output format selected when a `TreeFmtVisitor` is rendered.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum TreeFmtVisitorDisplay {
    /// Render through the text renderer; this is the default.
    #[default]
    DisplayText,
    /// Render as a GraphViz DOT graph.
    DisplayDot,
}

/// State that is rendered by `tree_fmt` either as text or as a DOT graph,
/// depending on `display`.
#[derive(Default)]
pub(crate) struct TreeFmtVisitor {
// Converted into a `TreeView` by the renderers.
// NOTE(review): presumably one inner `Vec` per tree depth — confirm against the `Visitor` impl.
levels: Vec<Vec<String>>,
// NOTE(review): `prev_depth`, `depth` and `width` are bookkeeping for the
// `Visitor` impl, which is not visible here — confirm their exact meaning there.
prev_depth: usize,
depth: usize,
width: usize,
// Selects the renderer; defaults to `TreeFmtVisitorDisplay::DisplayText`.
pub(crate) display: TreeFmtVisitorDisplay,
}

impl Visitor for TreeFmtVisitor {
Expand Down Expand Up @@ -868,18 +876,64 @@ impl fmt::Display for Canvas {
}
}

/// Writes the visitor's levels to `f` by converting them to a `TreeView`,
/// then to a `Canvas`, and formatting that canvas with its `Display` impl.
fn tree_fmt_text(tree: &TreeFmtVisitor, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
    let view: TreeView<'_> = tree.levels.as_slice().into();
    let rendered: Canvas = view.into();
    write!(f, "{rendered}")
}

/// GraphViz output: writes the visited tree to `f` as an undirected DOT `graph`.
///
/// Every non-empty cell of the tree view becomes a node whose label is the
/// cell's text lines joined by newlines. Each node is connected to its
/// children in the next row.
fn tree_fmt_dot(tree: &TreeFmtVisitor, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
    let tree_view: TreeView<'_> = tree.levels.as_slice().into();
    let mut relations: Vec<String> = Vec::new();

    // Non-empty cells (nodes) and their connections (edges)
    for (i, row) in tree_view.matrix.iter().enumerate() {
        for (j, cell) in row.iter().enumerate() {
            if cell.text.is_empty() {
                continue;
            }

            // Escape backslashes and double quotes so that arbitrary expression
            // text cannot terminate or corrupt the quoted DOT label.
            let node_label = cell
                .text
                .join("\n")
                .replace('\\', "\\\\")
                .replace('"', "\\\"");
            // Row and column are separated by `_`: without a separator
            // (row 1, column 11) and (row 11, column 1) would share the id `n111`.
            relations.push(format!(
                "n{i}_{j} [label=\"{node_label}\", ordering=\"out\"]"
            ));

            // Add child edges; `i + 1 < len` cannot underflow, unlike `i < len - 1`.
            let next_row = i + 1;
            if next_row < tree_view.rows.len() {
                // Iterate in reversed order to undo the reversed child order
                // produced when iterating expressions.
                for child_col in cell.children_columns.iter().rev() {
                    relations.push(format!("n{i}_{j} -- n{next_row}_{child_col}"));
                }
            }
        }
    }

    let graph_str = relations.join("\n ");
    write!(f, "graph {{\n {graph_str}\n}}")
}

fn tree_fmt(tree: &TreeFmtVisitor, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match tree.display {
TreeFmtVisitorDisplay::DisplayText => tree_fmt_text(tree, f),
TreeFmtVisitorDisplay::DisplayDot => tree_fmt_dot(tree, f),
}
}

impl fmt::Display for TreeFmtVisitor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
fmt::Debug::fmt(self, f)
tree_fmt(self, f)
}
}

impl fmt::Debug for TreeFmtVisitor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
let tree_view: TreeView<'_> = self.levels.as_slice().into();
let canvas: Canvas = tree_view.into();
write!(f, "{canvas}")?;

Ok(())
tree_fmt(self, f)
}
}
12 changes: 10 additions & 2 deletions crates/polars-python/src/expr/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,21 @@ impl PyExpr {
self.inner.clone().meta()._into_selector().into()
}

fn meta_tree_format(&self) -> PyResult<String> {
fn compute_tree_format(&self, display_as_dot: bool) -> Result<String, PyErr> {
let e = self
.inner
.clone()
.meta()
.into_tree_formatter()
.into_tree_formatter(display_as_dot)
.map_err(PyPolarsErr::from)?;
Ok(format!("{e}"))
}

/// Returns the formatted expression tree, calling `compute_tree_format`
/// with `display_as_dot = false` (no DOT output requested).
fn meta_tree_format(&self) -> PyResult<String> {
self.compute_tree_format(false)
}

/// Returns the expression tree as a string, calling `compute_tree_format`
/// with `display_as_dot = true` (DOT output requested).
fn meta_show_graph(&self) -> PyResult<String> {
self.compute_tree_format(true)
}
}
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/meta.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The following methods are available under the `expr.meta` attribute.
Expr.meta.pop
Expr.meta.root_names
Expr.meta.serialize
Expr.meta.show_graph
Expr.meta.tree_format
Expr.meta.undo_aliases
Expr.meta.write_json
54 changes: 53 additions & 1 deletion py-polars/polars/_utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Sized,
)
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand All @@ -37,7 +38,7 @@
Time,
)
from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES
from polars.dependencies import _check_for_numpy
from polars.dependencies import _check_for_numpy, import_optional, subprocess
from polars.dependencies import numpy as np

if TYPE_CHECKING:
Expand Down Expand Up @@ -629,3 +630,54 @@ def re_escape(s: str) -> str:
# escapes _only_ those metachars with meaning to the rust regex crate
re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-"
return re.sub(f"([{re_rust_metachars}])", r"\\\1", s)


def display_dot_graph(
    *,
    dot: str,
    show: bool = True,
    output_path: str | Path | None = None,
    raw_output: bool = False,
    figsize: tuple[float, float] = (16.0, 12.0),
) -> str | None:
    """
    Render a DOT graph description with the `dot` binary, or return it unchanged.

    Parameters
    ----------
    dot
        Graph description in DOT syntax.
    show
        Display the rendered figure (inline SVG in a notebook, otherwise a
        matplotlib window).
    output_path
        If given, write the rendered image bytes to this path.
    raw_output
        Return `dot` as-is. Nothing is rendered, shown or written in that case.
    figsize
        Figure size passed to matplotlib when the figure is shown outside a
        notebook.

    Returns
    -------
    str or None
        The DOT source when `raw_output` is set; otherwise whatever the
        display call returns (`None` outside a notebook).

    Raises
    ------
    ImportError
        If the `dot` binary cannot be executed.
    """
    if raw_output:
        # we do not show a graph, nor save a graph to disk
        return dot

    in_notebook = _in_notebook()
    output_type = "svg" if in_notebook else "png"

    try:
        graph = subprocess.check_output(
            ["dot", "-Nshape=box", "-T" + output_type], input=dot.encode()
        )
    except (ImportError, FileNotFoundError):
        # The two literals are concatenated, so the first must end with a space.
        msg = (
            "The graphviz `dot` binary should be on your PATH. "
            "(If not installed you can download here: https://graphviz.org/download/)"
        )
        raise ImportError(msg) from None

    if output_path:
        Path(output_path).write_bytes(graph)

    if not show:
        return None

    if in_notebook:
        from IPython.display import SVG, display

        return display(SVG(graph))

    import_optional(
        "matplotlib",
        err_prefix="",
        err_suffix="should be installed to show graphs",
    )
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    img = mpimg.imread(BytesIO(graph))
    plt.imshow(img)
    plt.show()
    return None
37 changes: 37 additions & 0 deletions py-polars/polars/expr/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from polars._utils.deprecation import deprecate_renamed_function
from polars._utils.serde import serialize_polars_object
from polars._utils.various import display_dot_graph
from polars._utils.wrap import wrap_expr
from polars.exceptions import ComputeError

Expand Down Expand Up @@ -364,3 +365,39 @@ def tree_format(self, *, return_as_string: bool = False) -> str | None:
else:
print(s)
return None

def show_graph(
    self,
    *,
    show: bool = True,
    output_path: str | Path | None = None,
    raw_output: bool = False,
    figsize: tuple[float, float] = (16.0, 12.0),
) -> str | None:
    """
    Render the expression tree as a Graphviz graph.

    Parameters
    ----------
    show
        Show the figure.
    output_path
        Write the figure to disk.
    raw_output
        Return the DOT source instead of rendering it. This cannot be
        combined with `show` and/or `output_path`.
    figsize
        Passed to matplotlib if `show` is True.

    Examples
    --------
    >>> e = (pl.col("foo") * pl.col("bar")).sum().over(pl.col("ham")) / 2
    >>> e.meta.show_graph()  # doctest: +SKIP
    """
    return display_dot_graph(
        dot=self._pyexpr.meta_show_graph(),
        show=show,
        output_path=output_path,
        raw_output=raw_output,
        figsize=figsize,
    )
51 changes: 8 additions & 43 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
from polars._utils.slice import LazyPolarsSlice
from polars._utils.unstable import issue_unstable_warning, unstable
from polars._utils.various import (
_in_notebook,
_is_generator,
display_dot_graph,
extend_bool,
find_stacklevel,
is_bool_sequence,
Expand Down Expand Up @@ -1202,48 +1202,13 @@ def show_graph(
)

dot = _ldf.to_dot(optimized)

if raw_output:
# we do not show a graph, nor save a graph to disk
return dot

output_type = "svg" if _in_notebook() else "png"

try:
graph = subprocess.check_output(
["dot", "-Nshape=box", "-T" + output_type], input=f"{dot}".encode()
)
except (ImportError, FileNotFoundError):
msg = (
"The graphviz `dot` binary should be on your PATH."
"(If not installed you can download here: https://graphviz.org/download/)"
)
raise ImportError(msg) from None

if output_path:
Path(output_path).write_bytes(graph)

if not show:
return None

if _in_notebook():
from IPython.display import SVG, display

return display(SVG(graph))
else:
import_optional(
"matplotlib",
err_prefix="",
err_suffix="should be installed to show graphs",
)
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

plt.figure(figsize=figsize)
img = mpimg.imread(BytesIO(graph))
plt.imshow(img)
plt.show()
return None
return display_dot_graph(
dot=dot,
show=show,
output_path=output_path,
raw_output=raw_output,
figsize=figsize,
)

def inspect(self, fmt: str = "{}") -> LazyFrame:
"""
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/operations/namespaces/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ def test_meta_tree_format(namespace_files_path: Path) -> None:
assert result.strip() == tree_fmt.strip()


def test_meta_show_graph(namespace_files_path: Path) -> None:
    """Check that `show_graph(raw_output=True)` returns a non-empty DOT string."""
    e = (pl.col("foo") * pl.col("bar")).sum().over(pl.col("ham")) / 2
    dot = e.meta.show_graph(show=False, raw_output=True)
    # `show_graph` is annotated `str | None`; narrow before calling `len`
    assert dot is not None
    assert len(dot) > 0
    # Don't check output contents since this creates a maintenance burden
    # Assume output check in test_meta_tree_format is enough


def test_literal_output_name() -> None:
e = pl.lit(1)
assert e.meta.output_name() == "literal"
Expand Down
Loading