From bf5d99a054b5521f8b30cd8d26783aecc5ad41c0 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Thu, 5 Dec 2024 22:10:29 +0000 Subject: [PATCH 01/10] fix the ipython_display side effect --- mrmustard/lab_dev/circuit_components.py | 4 ++++ mrmustard/lab_dev/states/dm.py | 4 ++++ mrmustard/lab_dev/states/ket.py | 4 ++++ mrmustard/physics/ansatz/array_ansatz.py | 4 ++++ mrmustard/physics/ansatz/polyexp_ansatz.py | 4 ++++ mrmustard/physics/wires.py | 4 ++++ 6 files changed, 24 insertions(+) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index d0e1297df..1185f74c4 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -28,6 +28,7 @@ import numpy as np from numpy.typing import ArrayLike import ipywidgets as widgets +from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import settings, math, widgets as mmwidgets @@ -724,6 +725,9 @@ def __truediv__(self, other: Scalar) -> CircuitComponent: return self._from_attributes(Representation(self.ansatz / other, self.wires), self.name) def _ipython_display_(self): + if isinstance(get_ipython(), InteractiveShell): + print(self) + return # both reps might return None rep_fn = mmwidgets.fock if isinstance(self.ansatz, ArrayAnsatz) else mmwidgets.bargmann rep_widget = rep_fn(self.ansatz) diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index 0b969540c..00daf15f8 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -22,6 +22,7 @@ from itertools import product import warnings import numpy as np +from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import math, settings, widgets @@ -375,6 +376,9 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex return self.quadrature(quad, phi) def _ipython_display_(self): # pragma: no cover + if isinstance(get_ipython(), InteractiveShell): + print(self) + return is_fock = isinstance(self.ansatz, ArrayAnsatz) display(widgets.state(self, is_ket=False, is_fock=is_fock)) diff --git a/mrmustard/lab_dev/states/ket.py b/mrmustard/lab_dev/states/ket.py index 9ea9151b1..20b7ee041 100644 --- a/mrmustard/lab_dev/states/ket.py +++ b/mrmustard/lab_dev/states/ket.py @@ -22,6 +22,7 @@ from itertools import product import warnings import numpy as np +from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import math, settings, widgets @@ -335,6 +336,9 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex return math.abs(self.quadrature(quad, phi)) ** 2 def _ipython_display_(self): # pragma: no cover + if isinstance(get_ipython(), InteractiveShell): + print(self) + return is_fock = isinstance(self.ansatz, ArrayAnsatz) display(widgets.state(self, is_ket=True, is_fock=is_fock)) diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index 7455a0af1..70fd331ca 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -24,6 +24,7 @@ import numpy as np from numpy.typing import ArrayLike +from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import math, widgets @@ -207,6 +208,9 @@ def _generate_ansatz(self): self.array = [self._fn(**self._kwargs)] def _ipython_display_(self): + if isinstance(get_ipython(), InteractiveShell): + print(self) + return w = widgets.fock(self) if w is None: print(repr(self)) diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index aabb3fcba..fc947cc38 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -29,6 +29,7 @@ from matplotlib import colors import matplotlib.pyplot as plt +from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard.utils.typing import ( @@ -568,6 +569,9 @@ def _generate_ansatz(self): self.c = c def _ipython_display_(self): + if isinstance(get_ipython(), InteractiveShell): + print(self) + return display(widgets.bargmann(self)) def _order_batch(self): diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index db14a9f48..9f7137b34 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -18,6 +18,7 @@ from functools import cached_property import numpy as np +from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import widgets @@ -546,4 +547,7 @@ def __repr__(self) -> str: return f"Wires{self.args}" def _ipython_display_(self): + if isinstance(get_ipython(), InteractiveShell): + print(self) + return display(widgets.wires(self)) From c0fc9acd230c1eb8f55db15828819bb832b796b6 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Fri, 6 Dec 2024 17:34:46 +0000 Subject: [PATCH 02/10] add tests --- tests/test_lab_dev/test_circuit_components.py | 10 ++++++++++ tests/test_physics/test_ansatz/test_array_ansatz.py | 10 ++++++++++ tests/test_physics/test_ansatz/test_polyexp_ansatz.py | 10 ++++++++++ tests/test_physics/test_wires.py | 10 ++++++++++ 4 files changed, 40 insertions(+) diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index 200479614..d8e16e417 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -20,6 +20,7 @@ import numpy as np import pytest +from IPython import InteractiveShell from ipywidgets import HTML, Box, HBox, VBox from mrmustard import math, settings @@ -559,6 +560,15 @@ def test_ipython_repr_invalid_obj(self, mock_display): assert isinstance(title_widget, HTML) assert isinstance(wires_widget, HTML) + @patch("mrmustard.lab_dev.circuit_components.get_ipython") + def test_ipython_repr_interactive(self, mock_ipython, capsys): + """Test the IPython repr function.""" + mock_ipython.return_value = InteractiveShell() + dgate = Dgate([1, 2], x=0.1, y=0.1).to_fock() + dgate._ipython_display_() + captured = capsys.readouterr() + assert captured.out.rstrip() == repr(dgate) + def test_serialize_default_behaviour(self): """Test the default serializer.""" name = "my_component" diff --git a/tests/test_physics/test_ansatz/test_array_ansatz.py b/tests/test_physics/test_ansatz/test_array_ansatz.py index 0879ef0fc..a9f130ec3 100644 --- a/tests/test_physics/test_ansatz/test_array_ansatz.py +++ b/tests/test_physics/test_ansatz/test_array_ansatz.py @@ -20,6 +20,7 @@ import numpy as np import pytest +from IPython import InteractiveShell from ipywidgets import HTML, HBox, Tab, VBox from plotly.graph_objs import FigureWidget @@ -261,3 +262,12 @@ def test_ipython_repr_expects_3_dims_or_less(self, mock_display): rep = ArrayAnsatz(np.random.random((1, 4, 4, 4)), batched=True) rep._ipython_display_() mock_display.assert_not_called() + + @patch("mrmustard.physics.ansatz.array_ansatz.get_ipython") + def test_ipython_repr_interactive(self, mock_ipython, capsys): + """Test the IPython repr function.""" + mock_ipython.return_value = InteractiveShell() + rep = ArrayAnsatz(np.random.random((1, 8)), batched=True) + rep._ipython_display_() + captured = capsys.readouterr() + assert captured.out.rstrip() == repr(rep) diff --git a/tests/test_physics/test_ansatz/test_polyexp_ansatz.py b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py index a5a9f5f28..ed0b605e5 100644 --- a/tests/test_physics/test_ansatz/test_polyexp_ansatz.py +++ b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py @@ -20,6 +20,7 @@ import numpy as np import pytest +from IPython import InteractiveShell from ipywidgets import HTML, Box, IntSlider, IntText, Stack, VBox from plotly.graph_objs import FigureWidget @@ -321,6 +322,15 @@ def test_ipython_repr_batched(self, mock_display): assert len(stack.children) == 2 assert all(box.layout.max_width == "50%" for box in stack.children) + @patch("mrmustard.physics.ansatz.polyexp_ansatz.get_ipython") + def test_ipython_repr_interactive(self, mock_ipython, capsys): + """Test the IPython repr function.""" + mock_ipython.return_value = InteractiveShell() + rep = PolyExpAnsatz(*Abc_triple(2)) + rep._ipython_display_() + captured = capsys.readouterr() + assert captured.out.rstrip() == repr(rep) + def test_matmul_barg_barg(self): triple1 = Abc_triple(3) triple2 = Abc_triple(3) diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index e0a537de0..e7dcab326 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -19,6 +19,7 @@ from unittest.mock import patch import pytest +from IPython import InteractiveShell from ipywidgets import HTML from mrmustard.physics.wires import Wires @@ -235,3 +236,12 @@ def test_ipython_repr(self, mock_display): wires._ipython_display_() [widget] = mock_display.call_args.args assert isinstance(widget, HTML) + + @patch("mrmustard.physics.wires.get_ipython") + def test_ipython_repr_interactive(self, mock_ipython, capsys): + """Test the IPython repr function.""" + mock_ipython.return_value = InteractiveShell() + wires = Wires({0}, {}, {3}, {3, 4}) + wires._ipython_display_() + captured = capsys.readouterr() + assert captured.out.rstrip() == repr(wires) From aba8bddb27d99ab764f1c91d7c6005a9c1a10b7c Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Fri, 6 Dec 2024 18:49:07 +0000 Subject: [PATCH 03/10] codefactor, avoid code injection --- .github/workflows/tests_numpy.yml | 4 +++- .github/workflows/tests_tensorflow.yml | 4 +++- tests/test_physics/test_wires.py | 4 ++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests_numpy.yml b/.github/workflows/tests_numpy.yml index ee794f78b..7a63cb844 100644 --- a/.github/workflows/tests_numpy.yml +++ b/.github/workflows/tests_numpy.yml @@ -59,6 +59,8 @@ jobs: - name: Copy durations to s3 if: github.event_name == 'push' + env: + REF_NAME: ${{ github.ref_name }} run: | grep ' call ' durations.txt | awk '{print $3,$1}' > ${{ steps.record_file.outputs.filename }} - aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/numpy_tests/${{ github.ref_name }}/ + aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/numpy_tests/${{ env.REF_NAME }}/ diff --git a/.github/workflows/tests_tensorflow.yml b/.github/workflows/tests_tensorflow.yml index 959ec7c3f..9c55ed9e9 100644 --- a/.github/workflows/tests_tensorflow.yml +++ b/.github/workflows/tests_tensorflow.yml @@ -66,6 +66,8 @@ jobs: - name: Copy durations to s3 if: github.event_name == 'push' + env: + REF_NAME: ${{ github.ref_name }} run: | grep ' call ' durations.txt | awk '{print $3,$1}' > ${{ steps.record_file.outputs.filename }} - aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/tf_tests/${{ github.ref_name }}/ + aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/tf_tests/${{ env.REF_NAME }}/ diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index e7dcab326..8b70745e6 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -229,6 +229,10 @@ def test_matmul_error(self): with pytest.raises(ValueError): u @ v # pylint: disable=pointless-statement + +class TestWiresDisplay: + """Test the wires _ipython_display_ functionality.""" + @patch("mrmustard.physics.wires.display") def test_ipython_repr(self, mock_display): """Test the IPython repr function.""" From d7824a1429f17da331d20cd32127a3ebc9f6b9f6 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Fri, 6 Dec 2024 18:50:40 +0000 Subject: [PATCH 04/10] changelog --- .github/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index c13ba6135..25a85281f 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -49,6 +49,9 @@ * Fix the TensorFlow issue with the expected number of gradients in `custom_gradient`. [(#506)](https://github.com/XanaduAI/MrMustard/pull/506) +* Use the default repr when in interactive IPython. + [(#531)](https://github.com/XanaduAI/MrMustard/pull/531) + ### Documentation ### Tests From 922a0e3dd07cd9a46fee17e34ccc37d772a09047 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Fri, 6 Dec 2024 21:26:49 +0000 Subject: [PATCH 05/10] Revert "fix the ipython_display side effect" This reverts commit bf5d99a054b5521f8b30cd8d26783aecc5ad41c0. --- mrmustard/lab_dev/circuit_components.py | 4 ---- mrmustard/lab_dev/states/dm.py | 4 ---- mrmustard/lab_dev/states/ket.py | 4 ---- mrmustard/physics/ansatz/array_ansatz.py | 4 ---- mrmustard/physics/ansatz/polyexp_ansatz.py | 4 ---- mrmustard/physics/wires.py | 4 ---- 6 files changed, 24 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 1185f74c4..d0e1297df 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -28,7 +28,6 @@ import numpy as np from numpy.typing import ArrayLike import ipywidgets as widgets -from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import settings, math, widgets as mmwidgets @@ -725,9 +724,6 @@ def __truediv__(self, other: Scalar) -> CircuitComponent: return self._from_attributes(Representation(self.ansatz / other, self.wires), self.name) def _ipython_display_(self): - if isinstance(get_ipython(), InteractiveShell): - print(self) - return # both reps might return None rep_fn = mmwidgets.fock if isinstance(self.ansatz, ArrayAnsatz) else mmwidgets.bargmann rep_widget = rep_fn(self.ansatz) diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index 00daf15f8..0b969540c 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -22,7 +22,6 @@ from itertools import product import warnings import numpy as np -from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import math, settings, widgets @@ -376,9 +375,6 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex return self.quadrature(quad, phi) def _ipython_display_(self): # pragma: no cover - if isinstance(get_ipython(), InteractiveShell): - print(self) - return is_fock = isinstance(self.ansatz, ArrayAnsatz) display(widgets.state(self, is_ket=False, is_fock=is_fock)) diff --git a/mrmustard/lab_dev/states/ket.py b/mrmustard/lab_dev/states/ket.py index 20b7ee041..9ea9151b1 100644 --- a/mrmustard/lab_dev/states/ket.py +++ b/mrmustard/lab_dev/states/ket.py @@ -22,7 +22,6 @@ from itertools import product import warnings import numpy as np -from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import math, settings, widgets @@ -336,9 +335,6 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex return math.abs(self.quadrature(quad, phi)) ** 2 def _ipython_display_(self): # pragma: no cover - if isinstance(get_ipython(), InteractiveShell): - print(self) - return is_fock = isinstance(self.ansatz, ArrayAnsatz) display(widgets.state(self, is_ket=True, is_fock=is_fock)) diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index 70fd331ca..7455a0af1 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -24,7 +24,6 @@ import numpy as np from numpy.typing import ArrayLike -from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import math, widgets @@ -208,9 +207,6 @@ def _generate_ansatz(self): self.array = [self._fn(**self._kwargs)] def _ipython_display_(self): - if isinstance(get_ipython(), InteractiveShell): - print(self) - return w = widgets.fock(self) if w is None: print(repr(self)) diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index fc947cc38..aabb3fcba 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -29,7 +29,6 @@ from matplotlib import colors import matplotlib.pyplot as plt -from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard.utils.typing import ( @@ -569,9 +568,6 @@ def _generate_ansatz(self): self.c = c def _ipython_display_(self): - if isinstance(get_ipython(), InteractiveShell): - print(self) - return display(widgets.bargmann(self)) def _order_batch(self): diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index 9f7137b34..db14a9f48 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -18,7 +18,6 @@ from functools import cached_property import numpy as np -from IPython import get_ipython, InteractiveShell from IPython.display import display from mrmustard import widgets @@ -547,7 +546,4 @@ def __repr__(self) -> str: return f"Wires{self.args}" def _ipython_display_(self): - if isinstance(get_ipython(), InteractiveShell): - print(self) - return display(widgets.wires(self)) From 177fded6ab0d0c2af40320fd24744e7f0dbf071d Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Fri, 6 Dec 2024 21:47:49 +0000 Subject: [PATCH 06/10] make IN_INTERACTIVE_SHELL global --- mrmustard/lab_dev/circuit_components.py | 2 ++ mrmustard/lab_dev/states/dm.py | 2 ++ mrmustard/lab_dev/states/ket.py | 2 ++ mrmustard/physics/ansatz/array_ansatz.py | 8 ++++---- mrmustard/physics/ansatz/polyexp_ansatz.py | 2 ++ mrmustard/physics/wires.py | 2 ++ mrmustard/widgets/__init__.py | 3 +++ tests/test_lab_dev/test_circuit_components.py | 6 ++---- tests/test_physics/test_ansatz/test_array_ansatz.py | 6 ++---- tests/test_physics/test_ansatz/test_polyexp_ansatz.py | 6 ++---- tests/test_physics/test_wires.py | 6 ++---- 11 files changed, 25 insertions(+), 20 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index d0e1297df..da446f743 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -724,6 +724,8 @@ def __truediv__(self, other: Scalar) -> CircuitComponent: return self._from_attributes(Representation(self.ansatz / other, self.wires), self.name) def _ipython_display_(self): + if mmwidgets.IN_INTERACTIVE_SHELL: + return print(self) # both reps might return None rep_fn = mmwidgets.fock if isinstance(self.ansatz, ArrayAnsatz) else mmwidgets.bargmann rep_widget = rep_fn(self.ansatz) diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index 0b969540c..0b5695c5d 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -375,6 +375,8 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex return self.quadrature(quad, phi) def _ipython_display_(self): # pragma: no cover + if widgets.IN_INTERACTIVE_SHELL: + return print(self) is_fock = isinstance(self.ansatz, ArrayAnsatz) display(widgets.state(self, is_ket=False, is_fock=is_fock)) diff --git a/mrmustard/lab_dev/states/ket.py b/mrmustard/lab_dev/states/ket.py index 9ea9151b1..24aa3ee81 100644 --- a/mrmustard/lab_dev/states/ket.py +++ b/mrmustard/lab_dev/states/ket.py @@ -335,6 +335,8 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex return math.abs(self.quadrature(quad, phi)) ** 2 def _ipython_display_(self): # pragma: no cover + if widgets.IN_INTERACTIVE_SHELL: + return print(self) is_fock = isinstance(self.ansatz, ArrayAnsatz) display(widgets.state(self, is_ket=True, is_fock=is_fock)) diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index 7455a0af1..0b7846a70 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -207,10 +207,10 @@ def _generate_ansatz(self): self.array = [self._fn(**self._kwargs)] def _ipython_display_(self): - w = widgets.fock(self) - if w is None: - print(repr(self)) - return + if widgets.IN_INTERACTIVE_SHELL: + return print(self) + if (w := widgets.fock(self)) is None: + return print(repr(self)) display(w) def __add__(self, other: ArrayAnsatz) -> ArrayAnsatz: diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index aabb3fcba..360636ea1 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -568,6 +568,8 @@ def _generate_ansatz(self): self.c = c def _ipython_display_(self): + if widgets.IN_INTERACTIVE_SHELL: + return print(self) display(widgets.bargmann(self)) def _order_batch(self): diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index db14a9f48..98a4681a2 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -546,4 +546,6 @@ def __repr__(self) -> str: return f"Wires{self.args}" def _ipython_display_(self): + if widgets.IN_INTERACTIVE_SHELL: + return print(self) display(widgets.wires(self)) diff --git a/mrmustard/widgets/__init__.py b/mrmustard/widgets/__init__.py index 4c9026dd0..e10940374 100644 --- a/mrmustard/widgets/__init__.py +++ b/mrmustard/widgets/__init__.py @@ -17,11 +17,14 @@ import numpy as np import ipywidgets as widgets import plotly.graph_objs as go +from IPython import get_ipython +from IPython.terminal.interactiveshell import TerminalInteractiveShell from .css import FOCK, WIRES, TABLE, STATE NO_MARGIN = {"l": 0, "r": 0, "t": 0, "b": 0} +IN_INTERACTIVE_SHELL = isinstance(get_ipython(), TerminalInteractiveShell) def _batch_widget(obj, batch_size, widget_fn, *widget_args): diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index d8e16e417..aa3bbbc97 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -20,7 +20,6 @@ import numpy as np import pytest -from IPython import InteractiveShell from ipywidgets import HTML, Box, HBox, VBox from mrmustard import math, settings @@ -560,10 +559,9 @@ def test_ipython_repr_invalid_obj(self, mock_display): assert isinstance(title_widget, HTML) assert isinstance(wires_widget, HTML) - @patch("mrmustard.lab_dev.circuit_components.get_ipython") - def test_ipython_repr_interactive(self, mock_ipython, capsys): + @patch("mrmustard.widgets.IN_INTERACTIVE_SHELL", True) + def test_ipython_repr_interactive(self, capsys): """Test the IPython repr function.""" - mock_ipython.return_value = InteractiveShell() dgate = Dgate([1, 2], x=0.1, y=0.1).to_fock() dgate._ipython_display_() captured = capsys.readouterr() diff --git a/tests/test_physics/test_ansatz/test_array_ansatz.py b/tests/test_physics/test_ansatz/test_array_ansatz.py index a9f130ec3..410435d40 100644 --- a/tests/test_physics/test_ansatz/test_array_ansatz.py +++ b/tests/test_physics/test_ansatz/test_array_ansatz.py @@ -20,7 +20,6 @@ import numpy as np import pytest -from IPython import InteractiveShell from ipywidgets import HTML, HBox, Tab, VBox from plotly.graph_objs import FigureWidget @@ -263,10 +262,9 @@ def test_ipython_repr_expects_3_dims_or_less(self, mock_display): rep._ipython_display_() mock_display.assert_not_called() - @patch("mrmustard.physics.ansatz.array_ansatz.get_ipython") - def test_ipython_repr_interactive(self, mock_ipython, capsys): + @patch("mrmustard.widgets.IN_INTERACTIVE_SHELL", True) + def test_ipython_repr_interactive(self, capsys): """Test the IPython repr function.""" - mock_ipython.return_value = InteractiveShell() rep = ArrayAnsatz(np.random.random((1, 8)), batched=True) rep._ipython_display_() captured = capsys.readouterr() diff --git a/tests/test_physics/test_ansatz/test_polyexp_ansatz.py b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py index ed0b605e5..b432c7374 100644 --- a/tests/test_physics/test_ansatz/test_polyexp_ansatz.py +++ b/tests/test_physics/test_ansatz/test_polyexp_ansatz.py @@ -20,7 +20,6 @@ import numpy as np import pytest -from IPython import InteractiveShell from ipywidgets import HTML, Box, IntSlider, IntText, Stack, VBox from plotly.graph_objs import FigureWidget @@ -322,10 +321,9 @@ def test_ipython_repr_batched(self, mock_display): assert len(stack.children) == 2 assert all(box.layout.max_width == "50%" for box in stack.children) - @patch("mrmustard.physics.ansatz.polyexp_ansatz.get_ipython") - def test_ipython_repr_interactive(self, mock_ipython, capsys): + @patch("mrmustard.widgets.IN_INTERACTIVE_SHELL", True) + def test_ipython_repr_interactive(self, capsys): """Test the IPython repr function.""" - mock_ipython.return_value = InteractiveShell() rep = PolyExpAnsatz(*Abc_triple(2)) rep._ipython_display_() captured = capsys.readouterr() diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index 8b70745e6..b488483df 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -19,7 +19,6 @@ from unittest.mock import patch import pytest -from IPython import InteractiveShell from ipywidgets import HTML from mrmustard.physics.wires import Wires @@ -241,10 +240,9 @@ def test_ipython_repr(self, mock_display): [widget] = mock_display.call_args.args assert isinstance(widget, HTML) - @patch("mrmustard.physics.wires.get_ipython") - def test_ipython_repr_interactive(self, mock_ipython, capsys): + @patch("mrmustard.widgets.IN_INTERACTIVE_SHELL", True) + def test_ipython_repr_interactive(self, capsys): """Test the IPython repr function.""" - mock_ipython.return_value = InteractiveShell() wires = Wires({0}, {}, {3}, {3, 4}) wires._ipython_display_() captured = capsys.readouterr() From f446499f60a3bbac1f37fc594a70f309ee2a3916 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Fri, 6 Dec 2024 21:52:09 +0000 Subject: [PATCH 07/10] do not use templating --- .github/workflows/tests_numpy.yml | 2 +- .github/workflows/tests_tensorflow.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests_numpy.yml b/.github/workflows/tests_numpy.yml index 7a63cb844..57dcb14eb 100644 --- a/.github/workflows/tests_numpy.yml +++ b/.github/workflows/tests_numpy.yml @@ -63,4 +63,4 @@ jobs: REF_NAME: ${{ github.ref_name }} run: | grep ' call ' durations.txt | awk '{print $3,$1}' > ${{ steps.record_file.outputs.filename }} - aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/numpy_tests/${{ env.REF_NAME }}/ + aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/numpy_tests/$REF_NAME/ diff --git a/.github/workflows/tests_tensorflow.yml b/.github/workflows/tests_tensorflow.yml index 9c55ed9e9..7e68c2a72 100644 --- a/.github/workflows/tests_tensorflow.yml +++ b/.github/workflows/tests_tensorflow.yml @@ -70,4 +70,4 @@ jobs: REF_NAME: ${{ github.ref_name }} run: | grep ' call ' durations.txt | awk '{print $3,$1}' > ${{ steps.record_file.outputs.filename }} - aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/tf_tests/${{ env.REF_NAME }}/ + aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/tf_tests/$REF_NAME/ From 9ee120935a476b4fc46c58828ee79c23b2c55911 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Fri, 6 Dec 2024 21:57:00 +0000 Subject: [PATCH 08/10] codefactor --- .pylintrc | 2 +- mrmustard/lab_dev/circuit_components.py | 3 ++- mrmustard/lab_dev/states/dm.py | 3 ++- mrmustard/lab_dev/states/ket.py | 3 ++- mrmustard/physics/ansatz/array_ansatz.py | 7 +++---- mrmustard/physics/ansatz/polyexp_ansatz.py | 3 ++- mrmustard/physics/wires.py | 3 ++- 7 files changed, 14 insertions(+), 10 deletions(-) diff --git a/.pylintrc b/.pylintrc index 47a5c2046..33b59d5ad 100644 --- a/.pylintrc +++ b/.pylintrc @@ -28,4 +28,4 @@ ignored-classes=numpy,tensorflow,scipy,networkx,strawberryfields,thewalrus # can either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). -disable=fixme,no-member,line-too-long,invalid-name,too-many-lines,redefined-builtin,too-many-locals,duplicate-code,too-many-arguments,too-few-public-methods,no-else-return,isinstance-second-argument-not-valid-type,no-self-argument, arguments-differ, protected-access +disable=fixme,no-member,line-too-long,invalid-name,too-many-lines,redefined-builtin,too-many-locals,duplicate-code,too-many-arguments,too-many-positional-arguments,too-few-public-methods,no-else-return,isinstance-second-argument-not-valid-type,no-self-argument, arguments-differ, protected-access diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index da446f743..fd3031572 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -725,7 +725,8 @@ def __truediv__(self, other: Scalar) -> CircuitComponent: def _ipython_display_(self): if mmwidgets.IN_INTERACTIVE_SHELL: - return print(self) + print(self) + return # both reps might return None rep_fn = mmwidgets.fock if isinstance(self.ansatz, ArrayAnsatz) else mmwidgets.bargmann rep_widget = rep_fn(self.ansatz) diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index 0b5695c5d..5b97a0553 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -376,7 +376,8 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex def _ipython_display_(self): # pragma: no cover if widgets.IN_INTERACTIVE_SHELL: - return print(self) + print(self) + return is_fock = isinstance(self.ansatz, ArrayAnsatz) display(widgets.state(self, is_ket=False, is_fock=is_fock)) diff --git a/mrmustard/lab_dev/states/ket.py b/mrmustard/lab_dev/states/ket.py index 24aa3ee81..497eaee7c 100644 --- a/mrmustard/lab_dev/states/ket.py +++ b/mrmustard/lab_dev/states/ket.py @@ -336,7 +336,8 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex def _ipython_display_(self): # pragma: no cover if widgets.IN_INTERACTIVE_SHELL: - return print(self) + print(self) + return is_fock = isinstance(self.ansatz, ArrayAnsatz) display(widgets.state(self, is_ket=True, is_fock=is_fock)) diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index 0b7846a70..d792957ef 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -207,10 +207,9 @@ def _generate_ansatz(self): self.array = [self._fn(**self._kwargs)] def _ipython_display_(self): - if widgets.IN_INTERACTIVE_SHELL: - return print(self) - if (w := widgets.fock(self)) is None: - return print(repr(self)) + if widgets.IN_INTERACTIVE_SHELL or (w := widgets.fock(self)) is None: + print(self) + return display(w) def __add__(self, other: ArrayAnsatz) -> ArrayAnsatz: diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index 360636ea1..8966716ce 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -569,7 +569,8 @@ def _generate_ansatz(self): def _ipython_display_(self): if widgets.IN_INTERACTIVE_SHELL: - return print(self) + print(self) + return display(widgets.bargmann(self)) def _order_batch(self): diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index 98a4681a2..b25e92cdd 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -547,5 +547,6 @@ def __repr__(self) -> str: def _ipython_display_(self): if widgets.IN_INTERACTIVE_SHELL: - return print(self) + print(self) + return display(widgets.wires(self)) From 39732f360db207291f730b1932dc3e1cd7b49149 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Fri, 6 Dec 2024 22:03:13 +0000 Subject: [PATCH 09/10] trigger ci From 1857c6bb99c34556f105fb77c0257ca34ca595a7 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Fri, 6 Dec 2024 22:24:10 +0000 Subject: [PATCH 10/10] try to disable manually --- mrmustard/lab_dev/circuit_components.py | 4 ++-- mrmustard/physics/ansatz/polyexp_ansatz.py | 2 +- mrmustard/physics/wires.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index fd3031572..944b6e5e6 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -222,7 +222,7 @@ def from_bargmann( modes_out_ket: Sequence[int] = (), modes_in_ket: Sequence[int] = (), name: str | None = None, - ) -> CircuitComponent: + ) -> CircuitComponent: # pylint:disable=too-many-positional-arguments r""" Initializes a ``CircuitComponent`` object from its Bargmann (A,b,c) parametrization. @@ -251,7 +251,7 @@ def from_quadrature( triple: tuple, phi: float = 0.0, name: str | None = None, - ) -> CircuitComponent: + ) -> CircuitComponent: # pylint:disable=too-many-positional-arguments r""" Returns a circuit component from the given triple (A,b,c) that parametrizes the quadrature wavefunction of this component in the form :math:`c * exp(1/2 x^T A x + b^T x)`. diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index 8966716ce..937f76fc9 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -16,7 +16,7 @@ This module contains the PolyExp ansatz. """ -# pylint: disable=too-many-instance-attributes +# pylint: disable=too-many-instance-attributes,too-many-positional-arguments from __future__ import annotations diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index b25e92cdd..2ebf75320 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -169,7 +169,7 @@ def __init__( modes_in_ket: set[int] | None = None, classical_out: set[int] | None = None, classical_in: set[int] | None = None, - ) -> None: + ) -> None: # pylint:disable=too-many-positional-arguments self.args: tuple[set, ...] = ( modes_out_bra or set(), modes_in_bra or set(),