Skip to content

Commit

Permalink
Merge pull request #181 from Quantomatic/feat/tikz-proof
Browse files Browse the repository at this point in the history
Add proof export to tikz
  • Loading branch information
mark-koch authored Nov 13, 2023
2 parents 1fa641c + 52ce7b4 commit a0729d3
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 11 deletions.
14 changes: 10 additions & 4 deletions zxlive/dialogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def get_file_path_and_format(parent: QWidget, filter: str) -> Optional[tuple[str

return file_path, selected_format

def export_diagram_dialog(graph: GraphT, parent: QWidget) -> Optional[tuple[str, FileFormat]]:
def save_diagram_dialog(graph: GraphT, parent: QWidget) -> Optional[tuple[str, FileFormat]]:
file_path_and_format = get_file_path_and_format(parent, ";;".join([f.filter for f in FileFormat if f != FileFormat.ZXProof]))
if file_path_and_format is None or not file_path_and_format[0]:
return None
Expand All @@ -215,7 +215,7 @@ def export_diagram_dialog(graph: GraphT, parent: QWidget) -> Optional[tuple[str,

return file_path, selected_format

def export_proof_dialog(proof_model: ProofModel, parent: QWidget) -> Optional[tuple[str, FileFormat]]:
def safe_proof_dialog(proof_model: ProofModel, parent: QWidget) -> Optional[tuple[str, FileFormat]]:
file_path_and_format = get_file_path_and_format(parent, FileFormat.ZXProof.filter)
if file_path_and_format is None or not file_path_and_format[0]:
return None
Expand All @@ -225,7 +225,7 @@ def export_proof_dialog(proof_model: ProofModel, parent: QWidget) -> Optional[tu
return None
return file_path, selected_format

def export_rule_dialog(rule: CustomRule, parent: QWidget) -> Optional[tuple[str, FileFormat]]:
def safe_rule_dialog(rule: CustomRule, parent: QWidget) -> Optional[tuple[str, FileFormat]]:
file_path_and_format = get_file_path_and_format(parent, FileFormat.ZXRule.filter)
if file_path_and_format is None or not file_path_and_format[0]:
return None
Expand All @@ -235,6 +235,12 @@ def export_rule_dialog(rule: CustomRule, parent: QWidget) -> Optional[tuple[str,
return None
return file_path, selected_format

def export_proof_dialog(parent: QWidget) -> Optional[str]:
file_path_and_format = get_file_path_and_format(parent, FileFormat.TikZ.filter)
if file_path_and_format is None or not file_path_and_format[0]:
return None
return file_path_and_format[0]

def get_lemma_name_and_description(parent: MainWindow) -> tuple[Optional[str], Optional[str]]:
dialog = QDialog()
parent.rewrite_form = QFormLayout(dialog)
Expand Down Expand Up @@ -283,7 +289,7 @@ def add_rewrite() -> None:
return
rule = CustomRule(parent.left_graph, parent.right_graph, name.text(), description.toPlainText())
check_rule(rule, show_error=True)
if export_rule_dialog(rule, parent):
if safe_rule_dialog(rule, parent):
dialog.accept()
button_box.accepted.connect(add_rewrite)
button_box.rejected.connect(dialog.reject)
Expand Down
27 changes: 20 additions & 7 deletions zxlive/mainwindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,17 @@
from .custom_rule import CustomRule, check_rule
from .dialogs import (FileFormat, ImportGraphOutput, ImportProofOutput,
ImportRuleOutput, create_new_rewrite,
export_diagram_dialog, export_proof_dialog,
export_rule_dialog, get_lemma_name_and_description,
import_diagram_dialog, import_diagram_from_file, show_error_msg)
save_diagram_dialog, safe_proof_dialog,
safe_rule_dialog, get_lemma_name_and_description,
import_diagram_dialog, import_diagram_from_file, show_error_msg,
export_proof_dialog)
from zxlive.settings_dialog import open_settings_dialog

from .editor_base_panel import EditorBasePanel
from .edit_panel import GraphEditPanel
from .proof_panel import ProofPanel
from .rule_panel import RulePanel
from .tikz import proof_to_tikz


class MainWindow(QMainWindow):
Expand Down Expand Up @@ -103,6 +105,8 @@ def __init__(self) -> None:
"Save the diagram by overwriting the previous loaded file.")
self.save_as = self._new_action("Save &as...", self.handle_save_as_action, QKeySequence.StandardKey.SaveAs,
"Opens a file-picker dialog to save the diagram in a chosen file format")
self.export_tikz_proof = self._new_action("Export to tikz", self.handle_export_tikz_proof_action, None,
"Exports the proof to tikz")

file_menu = menu.addMenu("&File")
file_menu.addAction(new_graph)
Expand All @@ -111,6 +115,7 @@ def __init__(self) -> None:
file_menu.addAction(self.close_action)
file_menu.addAction(self.save_file)
file_menu.addAction(self.save_as)
file_menu.addAction(self.export_tikz_proof)

self.undo_action = self._new_action("Undo", self.undo, QKeySequence.StandardKey.Undo,
"Undoes the last action", "undo.svg")
Expand Down Expand Up @@ -363,12 +368,12 @@ def handle_save_file_action(self) -> bool:
def handle_save_as_action(self) -> bool:
assert self.active_panel is not None
if isinstance(self.active_panel, ProofPanel):
out = export_proof_dialog(self.active_panel.proof_model, self)
out = safe_proof_dialog(self.active_panel.proof_model, self)
elif isinstance(self.active_panel, RulePanel):
check_rule(self.active_panel.get_rule(), show_error=True)
out = export_rule_dialog(self.active_panel.get_rule(), self)
out = safe_rule_dialog(self.active_panel.get_rule(), self)
else:
out = export_diagram_dialog(self.active_panel.graph_scene.g, self)
out = save_diagram_dialog(self.active_panel.graph_scene.g, self)
if out is None: return False
file_path, file_type = out
self.active_panel.file_path = file_path
Expand All @@ -379,6 +384,14 @@ def handle_save_as_action(self) -> bool:
self.tab_widget.setTabText(i,name)
return True

def handle_export_tikz_proof_action(self) -> bool:
assert isinstance(self.active_panel, ProofPanel)
path = export_proof_dialog(self)
if path is None:
show_error_msg("Export failed", "Invalid path")
return False
with open(path, "w") as f:
f.write(proof_to_tikz(self.active_panel.proof_model))

def cut_graph(self) -> None:
assert self.active_panel is not None
Expand Down Expand Up @@ -518,7 +531,7 @@ def proof_as_lemma(self) -> None:
lhs_graph = self.active_panel.proof_model.graphs[0]
rhs_graph = self.active_panel.proof_model.graphs[-1]
rule = CustomRule(lhs_graph, rhs_graph, name, description)
export_rule_dialog(rule, self)
safe_rule_dialog(rule, self)

def update_colors(self) -> None:
if self.active_panel is not None:
Expand Down
67 changes: 67 additions & 0 deletions zxlive/settings_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@
"tikz/edge-import": ", ".join(pyzx.tikz.synonyms_edge),
"tikz/edge-H-import": ", ".join(pyzx.tikz.synonyms_hedge),
"tikz/edge-W-import": ", ".join(pyzx.tikz.synonyms_wedge),

"tikz/layout/hspace": 2,
"tikz/layout/vspace": 2,
"tikz/layout/max-width": 10,

"tikz/names/fuse spiders": "f",
"tikz/names/bialgebra": "b",
"tikz/names/change color to Z": "cc",
"tikz/names/change color to X": "cc",
"tikz/names/remove identity": "id",
"tikz/names/Add Z identity": "id",
"tikz/names/copy 0/pi spider": "cp",
"tikz/names/push Pauli": "pi",
"tikz/names/decompose hadamard": "eu",
}

color_schemes = {
Expand All @@ -68,6 +82,14 @@
'gidney': "Gidney's Black & White",
}


# Initialise settings
settings = QSettings("zxlive", "zxlive")
for key, value in defaults.items():
if not settings.contains(key):
settings.setValue(key, value)


class SettingsDialog(QDialog):
def __init__(self, main_window: MainWindow) -> None:
super().__init__(main_window)
Expand Down Expand Up @@ -156,6 +178,51 @@ def __init__(self, main_window: MainWindow) -> None:
self.add_setting(form_import, "tikz/z-box-import", "Z box", 'str')
self.add_setting(form_import, "tikz/edge-W-import", "W io edge", 'str')

##### Tikz Layout settings #####
panel_tikz_layout = QWidget()
vlayout = QVBoxLayout()
panel_tikz_layout.setLayout(vlayout)
tab_widget.addTab(panel_tikz_layout, "Tikz layout")

vlayout.addWidget(QLabel("Tikz layout settings"))

form_layout = QFormLayout()
w = QWidget()
w.setLayout(form_layout)
vlayout.addWidget(w)
vlayout.addStretch()

self.add_setting(form_layout, "tikz/layout/hspace", "Horizontal spacing", "float")
self.add_setting(form_layout, "tikz/layout/vspace", "Vertical spacing", "float")
self.add_setting(form_layout, "tikz/layout/max-width", "Maximum width", 'float')


##### Tikz rule name settings #####
panel_tikz_names = QWidget()
vlayout = QVBoxLayout()
panel_tikz_names.setLayout(vlayout)
tab_widget.addTab(panel_tikz_names, "Tikz rule names")

vlayout.addWidget(QLabel("Tikz rule name settings"))
vlayout.addWidget(QLabel("Mapping of pyzx rule names to tikz display strings"))

form_names = QFormLayout()
w = QWidget()
w.setLayout(form_names)
vlayout.addWidget(w)
vlayout.addStretch()

self.add_setting(form_names, "tikz/names/fuse spiders", "fuse spiders", "str")
self.add_setting(form_names, "tikz/names/bialgebra", "bialgebra", "str")
self.add_setting(form_names, "tikz/names/change color to Z", "change color to Z", "str")
self.add_setting(form_names, "tikz/names/change color to X", "change color to X", "str")
self.add_setting(form_names, "tikz/names/remove identity", "remove identity", "str")
self.add_setting(form_names, "tikz/names/Add Z identity", "add Z identity", "str")
self.add_setting(form_names, "tikz/names/copy 0/pi spider", "copy 0/pi spider", "str")
self.add_setting(form_names, "tikz/names/push Pauli", "push Pauli", "str")
self.add_setting(form_names, "tikz/names/decompose hadamard", "decompose hadamard", "str")



##### Okay/Cancel Buttons #####
w= QWidget()
Expand Down
49 changes: 49 additions & 0 deletions zxlive/tikz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from PySide6.QtCore import QSettings
from pyzx.tikz import TIKZ_BASE, _to_tikz

from zxlive.proof import ProofModel


def proof_to_tikz(proof: ProofModel) -> str:
settings = QSettings("zxlive", "zxlive")
vspace = settings.value("tikz/layout/vspace")
hspace = settings.value("tikz/layout/hspace")
max_width = settings.value("tikz/layout/max-width")
draw_scalar = False

xoffset = -max_width
yoffset = -10
idoffset = 0
total_verts, total_edges = [], []
for i, g in enumerate(proof.graphs):
# Compute graph dimensions
width = max(g.row(v) for v in g.vertices()) - min(g.row(v) for v in g.vertices())
height = max(g.qubit(v) for v in g.vertices()) - min(g.qubit(v) for v in g.vertices())

# Translate graph so that the first vertex starts at 0
min_x = min(g.row(v) for v in g.vertices())
g = g.translate(-min_x, 0)

if i > 0:
rewrite = proof.steps[i-1]
# Try to look up name in settings
name = settings.value(f"tikz/names/{rewrite.rule}") if settings.contains(f"tikz/names/{rewrite.rule}") else rewrite.rule
eq = f"\\node [style=none] ({idoffset}) at ({xoffset - hspace/2:.2f}, {-yoffset - height/2:.2f}) {{$\\overset{{\\mathit{{{name}}}}}{{=}}$}};"
total_verts.append(eq)
idoffset += 1

verts, edges = _to_tikz(g, draw_scalar, xoffset, yoffset, idoffset)
total_verts.extend(verts)
total_edges.extend(edges)

if xoffset + hspace > max_width:
xoffset = -max_width
yoffset += height + vspace
else:
xoffset += width + hspace

max_index = max(g.vertices()) + 2 * g.num_inputs() + 2
idoffset += max_index

return TIKZ_BASE.format(vertices="\n".join(total_verts), edges="\n".join(total_edges))

0 comments on commit a0729d3

Please sign in to comment.