diff --git a/releasenotes/notes/fix-graphviz-draw-tooltip-3f697d71c4b79e60.yaml b/releasenotes/notes/fix-graphviz-draw-tooltip-3f697d71c4b79e60.yaml new file mode 100644 index 000000000..0d7b184c4 --- /dev/null +++ b/releasenotes/notes/fix-graphviz-draw-tooltip-3f697d71c4b79e60.yaml @@ -0,0 +1,14 @@ +--- +fixes: + - | + :func:`.graphviz_draw` can now handle special characters + + .. jupyter-execute:: + + import rustworkx as rx + from rustworkx.visualization import graphviz_draw + + graphviz_draw( + rx.generators.path_graph(2), + node_attr_fn=lambda x: {"label": "the\nlabel", "tooltip": "the\ntooltip"}, + ) diff --git a/src/dot_utils.rs b/src/dot_utils.rs index 3411c2f3e..af1f708eb 100644 --- a/src/dot_utils.rs +++ b/src/dot_utils.rs @@ -82,17 +82,18 @@ fn attr_map_to_string<'a>( if attrs.is_empty() { return Ok("".to_string()); } - let attr_string = attrs .iter() .map(|(key, value)| { - if key == "label" { - format!("{}=\"{}\"", key, value) - } else { - format!("{}={}", key, value) - } + let escaped_value = serde_json::to_string(value).map_err(|_err| { + pyo3::exceptions::PyValueError::new_err("could not escape character") + })?; + let escaped_value = &escaped_value.get(1..escaped_value.len() - 1).ok_or( + pyo3::exceptions::PyValueError::new_err("could not escape character"), + )?; + Ok(format!("{}=\"{}\"", key, escaped_value)) }) - .collect::>() + .collect::>>()? .join(", "); Ok(format!("[{}]", attr_string)) } diff --git a/tests/digraph/test_dot.py b/tests/digraph/test_dot.py index 3d893a892..8d3b8917f 100644 --- a/tests/digraph/test_dot.py +++ b/tests/digraph/test_dot.py @@ -43,9 +43,9 @@ def test_digraph_to_dot_to_file(self): ) graph.add_edge(0, 1, dict(label="1", name="1")) expected = ( - 'digraph {\n0 [color=black, fillcolor=green, label="a", ' - 'style=filled];\n1 [color=black, fillcolor=red, label="a", ' - 'style=filled];\n0 -> 1 [label="1", name=1];\n}\n' + 'digraph {\n0 [color="black", fillcolor="green", label="a", ' + 'style="filled"];\n1 [color="black", fillcolor="red", label="a", ' + 'style="filled"];\n0 -> 1 [label="1", name="1"];\n}\n' ) res = graph.to_dot(lambda node: node, lambda edge: edge, filename=self.path) self.addCleanup(os.remove, self.path) diff --git a/tests/graph/test_dot.py b/tests/graph/test_dot.py index 63c98804e..c2fb275e4 100644 --- a/tests/graph/test_dot.py +++ b/tests/graph/test_dot.py @@ -43,9 +43,9 @@ def test_graph_to_dot(self): ) graph.add_edge(0, 1, dict(label="1", name="1")) expected = ( - 'graph {\n0 [color=black, fillcolor=green, label="a", style=filled' - '];\n1 [color=black, fillcolor=red, label="a", style=filled];' - '\n0 -- 1 [label="1", name=1];\n}\n' + 'graph {\n0 [color="black", fillcolor="green", label="a", style="filled"' + '];\n1 [color="black", fillcolor="red", label="a", style="filled"];' + '\n0 -- 1 [label="1", name="1"];\n}\n' ) res = graph.to_dot(lambda node: node, lambda edge: edge) self.assertEqual(expected, res) @@ -70,9 +70,9 @@ def test_digraph_to_dot(self): ) graph.add_edge(0, 1, dict(label="1", name="1")) expected = ( - 'digraph {\n0 [color=black, fillcolor=green, label="a", ' - 'style=filled];\n1 [color=black, fillcolor=red, label="a", ' - 'style=filled];\n0 -> 1 [label="1", name=1];\n}\n' + 'digraph {\n0 [color="black", fillcolor="green", label="a", ' + 'style="filled"];\n1 [color="black", fillcolor="red", label="a", ' + 'style="filled"];\n0 -> 1 [label="1", name="1"];\n}\n' ) res = graph.to_dot(lambda node: node, lambda edge: edge) self.assertEqual(expected, res) @@ -97,9 +97,9 @@ def test_graph_to_dot_to_file(self): ) graph.add_edge(0, 1, dict(label="1", name="1")) expected = ( - 'graph {\n0 [color=black, fillcolor=green, label="a", ' - 'style=filled];\n1 [color=black, fillcolor=red, label="a", ' - 'style=filled];\n0 -- 1 [label="1", name=1];\n}\n' + 'graph {\n0 [color="black", fillcolor="green", label="a", ' + 'style="filled"];\n1 [color="black", fillcolor="red", label="a", ' + 'style="filled"];\n0 -- 1 [label="1", name="1"];\n}\n' ) res = graph.to_dot(lambda node: node, lambda edge: edge, filename=self.path) self.addCleanup(os.remove, self.path) diff --git a/tests/visualization/test_graphviz.py b/tests/visualization/test_graphviz.py index ca1aa4765..4a6a472dc 100644 --- a/tests/visualization/test_graphviz.py +++ b/tests/visualization/test_graphviz.py @@ -14,7 +14,6 @@ import subprocess import tempfile import unittest - import rustworkx from rustworkx.visualization import graphviz_draw @@ -150,3 +149,36 @@ def test_filename(self): self.assertTrue(os.path.isfile("test_graphviz_filename.svg")) if not SAVE_IMAGES: self.addCleanup(os.remove, "test_graphviz_filename.svg") + + def test_escape_sequences(self): + # Create a simple graph + graph = rustworkx.generators.path_graph(2) + + escape_sequences = { + "\\n": "\n", # Newline + "\\t": "\t", # Horizontal tab + "\\'": "'", # Single quote + '\\"': '"', # Double quote + "\\\\": "\\", # Backslash + "\\r": "\r", # Carriage return + "\\b": "\b", # Backspace + "\\f": "\f", # Form feed + } + + for escaped_seq, raw_seq in escape_sequences.items(): + + def node_attr(node): + """ + Define node attributes including escape sequences for labels and tooltips. + """ + label = f"label{escaped_seq}" + tooltip = f"tooltip{escaped_seq}" + return {"label": label, "tooltip": tooltip} + + # Draw the graph using graphviz_draw + dot_str = graph.to_dot(node_attr) + + # Assert that the escape sequences are correctly placed and escaped in the dot string + self.assertIn( + escaped_seq, dot_str, f"Escape sequence {escaped_seq} not found in dot output" + )