diff --git a/releasenotes/notes/fix-graphviz-escaping-39a2d1cebb586eca.yaml b/releasenotes/notes/fix-graphviz-escaping-39a2d1cebb586eca.yaml new file mode 100644 index 000000000..13d1e3bbf --- /dev/null +++ b/releasenotes/notes/fix-graphviz-escaping-39a2d1cebb586eca.yaml @@ -0,0 +1,22 @@ +--- +fixes: + - | + Fixed an issue in the :func:`~.graphviz_draw`, :meth:`.PyGraph.to_dot`, and + :meth:`.PyDiGraph.to_dot` which was incorrectly escaping strings when + upgrading to 0.15.0. In earlier versions of rustworkx if you manually + placed quotes in a string for an attr callback get that to pass through to + the output dot file this was incorrectly being converted in rustworkx 0.15.0 + to duplicate the quotes and escape them. For example, if you defined a + callback like:: + + def color_node(_node): + return { + "color": '"#422952"' + } + + to set the color attribute in the output dot file with the string + `"#422952"` (with the quotes) this was incorrectly being converted to + `\"#422952\"`. This no longer occurs, in rustworkx 0.16.0 there will likely + be additional options exposed in :func:`~.graphviz_draw`, + :meth:`.PyGraph.to_dot`, and :meth:`.PyDiGraph.to_dot` to expose further + options around this. diff --git a/src/dot_utils.rs b/src/dot_utils.rs index af1f708eb..07136f592 100644 --- a/src/dot_utils.rs +++ b/src/dot_utils.rs @@ -64,6 +64,8 @@ where Ok(()) } +static ATTRS_TO_ESCAPE: [&str; 2] = ["label", "tooltip"]; + /// Convert an attr map to an output string fn attr_map_to_string<'a>( py: Python, @@ -85,15 +87,13 @@ fn attr_map_to_string<'a>( let attr_string = attrs .iter() .map(|(key, value)| { - let escaped_value = serde_json::to_string(value).map_err(|_err| { - pyo3::exceptions::PyValueError::new_err("could not escape character") - })?; - let escaped_value = &escaped_value.get(1..escaped_value.len() - 1).ok_or( - pyo3::exceptions::PyValueError::new_err("could not escape character"), - )?; - Ok(format!("{}=\"{}\"", key, escaped_value)) + if ATTRS_TO_ESCAPE.contains(&key.as_str()) { + format!("{}=\"{}\"", key, value) + } else { + format!("{}={}", key, value) + } }) - .collect::>>()? + .collect::>() .join(", "); Ok(format!("[{}]", attr_string)) } diff --git a/tests/digraph/test_dot.py b/tests/digraph/test_dot.py index 8d3b8917f..3d893a892 100644 --- a/tests/digraph/test_dot.py +++ b/tests/digraph/test_dot.py @@ -43,9 +43,9 @@ def test_digraph_to_dot_to_file(self): ) graph.add_edge(0, 1, dict(label="1", name="1")) expected = ( - 'digraph {\n0 [color="black", fillcolor="green", label="a", ' - 'style="filled"];\n1 [color="black", fillcolor="red", label="a", ' - 'style="filled"];\n0 -> 1 [label="1", name="1"];\n}\n' + 'digraph {\n0 [color=black, fillcolor=green, label="a", ' + 'style=filled];\n1 [color=black, fillcolor=red, label="a", ' + 'style=filled];\n0 -> 1 [label="1", name=1];\n}\n' ) res = graph.to_dot(lambda node: node, lambda edge: edge, filename=self.path) self.addCleanup(os.remove, self.path) diff --git a/tests/graph/test_dot.py b/tests/graph/test_dot.py index c2fb275e4..63c98804e 100644 --- a/tests/graph/test_dot.py +++ b/tests/graph/test_dot.py @@ -43,9 +43,9 @@ def test_graph_to_dot(self): ) graph.add_edge(0, 1, dict(label="1", name="1")) expected = ( - 'graph {\n0 [color="black", fillcolor="green", label="a", style="filled"' - '];\n1 [color="black", fillcolor="red", label="a", style="filled"];' - '\n0 -- 1 [label="1", name="1"];\n}\n' + 'graph {\n0 [color=black, fillcolor=green, label="a", style=filled' + '];\n1 [color=black, fillcolor=red, label="a", style=filled];' + '\n0 -- 1 [label="1", name=1];\n}\n' ) res = graph.to_dot(lambda node: node, lambda edge: edge) self.assertEqual(expected, res) @@ -70,9 +70,9 @@ def test_digraph_to_dot(self): ) graph.add_edge(0, 1, dict(label="1", name="1")) expected = ( - 'digraph {\n0 [color="black", fillcolor="green", label="a", ' - 'style="filled"];\n1 [color="black", fillcolor="red", label="a", ' - 'style="filled"];\n0 -> 1 [label="1", name="1"];\n}\n' + 'digraph {\n0 [color=black, fillcolor=green, label="a", ' + 'style=filled];\n1 [color=black, fillcolor=red, label="a", ' + 'style=filled];\n0 -> 1 [label="1", name=1];\n}\n' ) res = graph.to_dot(lambda node: node, lambda edge: edge) self.assertEqual(expected, res) @@ -97,9 +97,9 @@ def test_graph_to_dot_to_file(self): ) graph.add_edge(0, 1, dict(label="1", name="1")) expected = ( - 'graph {\n0 [color="black", fillcolor="green", label="a", ' - 'style="filled"];\n1 [color="black", fillcolor="red", label="a", ' - 'style="filled"];\n0 -- 1 [label="1", name="1"];\n}\n' + 'graph {\n0 [color=black, fillcolor=green, label="a", ' + 'style=filled];\n1 [color=black, fillcolor=red, label="a", ' + 'style=filled];\n0 -- 1 [label="1", name=1];\n}\n' ) res = graph.to_dot(lambda node: node, lambda edge: edge, filename=self.path) self.addCleanup(os.remove, self.path) diff --git a/tests/visualization/test_graphviz.py b/tests/visualization/test_graphviz.py index e722d7e33..7303df789 100644 --- a/tests/visualization/test_graphviz.py +++ b/tests/visualization/test_graphviz.py @@ -150,6 +150,55 @@ def test_filename(self): if not SAVE_IMAGES: self.addCleanup(os.remove, "test_graphviz_filename.svg") + def test_qiskit_style_visualization(self): + """This test is to test visualizations like qiskit performs which regressed in 0.15.0.""" + graph = rustworkx.generators.cycle_graph(4) + colors = ["#422952", "#492d58", "#4f305c", "#5e3767"] + edge_colors = ["#4d2f5b", "#693d6f", "#995a88", "#382449"] + pos = [(0, 0), (0, 1), (1, 0), (1, 1)] + for node in graph.node_indices(): + graph[node] = node + + for edge in graph.edge_indices(): + graph.update_edge_by_index(edge, edge) + + def color_node(node): + out_dict = { + "label": str(node), + "color": f'"{colors[node]}"', + "pos": f'"{pos[node][0]}, {pos[node][1]}"', + "fontname": '"DejaVu Sans"', + "pin": "True", + "shape": "circle", + "style": "filled", + "fillcolor": f'"{colors[node]}"', + "fontcolor": "white", + "fontsize": "10", + "height": "0.322", + "fixedsize": "True", + } + return out_dict + + def color_edge(edge): + out_dict = { + "color": f'"{edge_colors[edge]}"', + "fillcolor": f'"{edge_colors[edge]}"', + "penwidth": str(5), + } + return out_dict + + graphviz_draw( + graph, + node_attr_fn=color_node, + edge_attr_fn=color_edge, + filename="test_qiskit_style_visualization.png", + image_type="png", + method="neato", + ) + self.assertTrue(os.path.isfile("test_qiskit_style_visualization.png")) + if not SAVE_IMAGES: + self.addCleanup(os.remove, "test_qiskit_style_visualization.png") + def test_escape_sequences(self): # Create a simple graph graph = rustworkx.generators.path_graph(2)