Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the ipython_display side effect #531

Merged
merged 11 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
* Fix the TensorFlow issue with the expected number of gradients in `custom_gradient`.
[(#506)](https://github.com/XanaduAI/MrMustard/pull/506)

* Use the default repr when in interactive IPython.
[(#531)](https://github.com/XanaduAI/MrMustard/pull/531)

### Documentation

### Tests
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/tests_numpy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ jobs:

- name: Copy durations to s3
if: github.event_name == 'push'
env:
REF_NAME: ${{ github.ref_name }}
run: |
grep ' call ' durations.txt | awk '{print $3,$1}' > ${{ steps.record_file.outputs.filename }}
aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/numpy_tests/${{ github.ref_name }}/
aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/numpy_tests/$REF_NAME/
4 changes: 3 additions & 1 deletion .github/workflows/tests_tensorflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ jobs:

- name: Copy durations to s3
if: github.event_name == 'push'
env:
REF_NAME: ${{ github.ref_name }}
run: |
grep ' call ' durations.txt | awk '{print $3,$1}' > ${{ steps.record_file.outputs.filename }}
aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/tf_tests/${{ github.ref_name }}/
aws s3 cp ./${{ steps.record_file.outputs.filename }} s3://${{ secrets.AWS_TIMINGS_BUCKET }}/tf_tests/$REF_NAME/
7 changes: 5 additions & 2 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def from_bargmann(
modes_out_ket: Sequence[int] = (),
modes_in_ket: Sequence[int] = (),
name: str | None = None,
) -> CircuitComponent:
) -> CircuitComponent: # pylint:disable=too-many-positional-arguments
r"""
Initializes a ``CircuitComponent`` object from its Bargmann (A,b,c) parametrization.

Expand Down Expand Up @@ -251,7 +251,7 @@ def from_quadrature(
triple: tuple,
phi: float = 0.0,
name: str | None = None,
) -> CircuitComponent:
) -> CircuitComponent: # pylint:disable=too-many-positional-arguments
r"""
Returns a circuit component from the given triple (A,b,c) that parametrizes the
quadrature wavefunction of this component in the form :math:`c * exp(1/2 x^T A x + b^T x)`.
Expand Down Expand Up @@ -724,6 +724,9 @@ def __truediv__(self, other: Scalar) -> CircuitComponent:
return self._from_attributes(Representation(self.ansatz / other, self.wires), self.name)

def _ipython_display_(self):
if mmwidgets.IN_INTERACTIVE_SHELL:
print(self)
return
# both reps might return None
rep_fn = mmwidgets.fock if isinstance(self.ansatz, ArrayAnsatz) else mmwidgets.bargmann
rep_widget = rep_fn(self.ansatz)
Expand Down
3 changes: 3 additions & 0 deletions mrmustard/lab_dev/states/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex
return self.quadrature(quad, phi)

def _ipython_display_(self): # pragma: no cover
if widgets.IN_INTERACTIVE_SHELL:
print(self)
return
is_fock = isinstance(self.ansatz, ArrayAnsatz)
display(widgets.state(self, is_ket=False, is_fock=is_fock))

Expand Down
3 changes: 3 additions & 0 deletions mrmustard/lab_dev/states/ket.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,9 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex
return math.abs(self.quadrature(quad, phi)) ** 2

def _ipython_display_(self): # pragma: no cover
if widgets.IN_INTERACTIVE_SHELL:
print(self)
return
is_fock = isinstance(self.ansatz, ArrayAnsatz)
display(widgets.state(self, is_ket=True, is_fock=is_fock))

Expand Down
5 changes: 2 additions & 3 deletions mrmustard/physics/ansatz/array_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,8 @@ def _generate_ansatz(self):
self.array = [self._fn(**self._kwargs)]

def _ipython_display_(self):
w = widgets.fock(self)
if w is None:
print(repr(self))
if widgets.IN_INTERACTIVE_SHELL or (w := widgets.fock(self)) is None:
print(self)
return
display(w)

Expand Down
5 changes: 4 additions & 1 deletion mrmustard/physics/ansatz/polyexp_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
This module contains the PolyExp ansatz.
"""

# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-instance-attributes,too-many-positional-arguments

from __future__ import annotations

Expand Down Expand Up @@ -568,6 +568,9 @@ def _generate_ansatz(self):
self.c = c

def _ipython_display_(self):
if widgets.IN_INTERACTIVE_SHELL:
print(self)
return
display(widgets.bargmann(self))

def _order_batch(self):
Expand Down
5 changes: 4 additions & 1 deletion mrmustard/physics/wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __init__(
modes_in_ket: set[int] | None = None,
classical_out: set[int] | None = None,
classical_in: set[int] | None = None,
) -> None:
) -> None: # pylint:disable=too-many-positional-arguments
self.args: tuple[set, ...] = (
modes_out_bra or set(),
modes_in_bra or set(),
Expand Down Expand Up @@ -546,4 +546,7 @@ def __repr__(self) -> str:
return f"Wires{self.args}"

def _ipython_display_(self):
if widgets.IN_INTERACTIVE_SHELL:
print(self)
return
display(widgets.wires(self))
3 changes: 3 additions & 0 deletions mrmustard/widgets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
import numpy as np
import ipywidgets as widgets
import plotly.graph_objs as go
from IPython import get_ipython
from IPython.terminal.interactiveshell import TerminalInteractiveShell

from .css import FOCK, WIRES, TABLE, STATE


NO_MARGIN = {"l": 0, "r": 0, "t": 0, "b": 0}
IN_INTERACTIVE_SHELL = isinstance(get_ipython(), TerminalInteractiveShell)


def _batch_widget(obj, batch_size, widget_fn, *widget_args):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_lab_dev/test_circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,14 @@ def test_ipython_repr_invalid_obj(self, mock_display):
assert isinstance(title_widget, HTML)
assert isinstance(wires_widget, HTML)

@patch("mrmustard.widgets.IN_INTERACTIVE_SHELL", True)
def test_ipython_repr_interactive(self, capsys):
"""Test the IPython repr function."""
dgate = Dgate([1, 2], x=0.1, y=0.1).to_fock()
dgate._ipython_display_()
captured = capsys.readouterr()
assert captured.out.rstrip() == repr(dgate)

def test_serialize_default_behaviour(self):
"""Test the default serializer."""
name = "my_component"
Expand Down
8 changes: 8 additions & 0 deletions tests/test_physics/test_ansatz/test_array_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,11 @@ def test_ipython_repr_expects_3_dims_or_less(self, mock_display):
rep = ArrayAnsatz(np.random.random((1, 4, 4, 4)), batched=True)
rep._ipython_display_()
mock_display.assert_not_called()

@patch("mrmustard.widgets.IN_INTERACTIVE_SHELL", True)
def test_ipython_repr_interactive(self, capsys):
"""Test the IPython repr function."""
rep = ArrayAnsatz(np.random.random((1, 8)), batched=True)
rep._ipython_display_()
captured = capsys.readouterr()
assert captured.out.rstrip() == repr(rep)
8 changes: 8 additions & 0 deletions tests/test_physics/test_ansatz/test_polyexp_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,14 @@ def test_ipython_repr_batched(self, mock_display):
assert len(stack.children) == 2
assert all(box.layout.max_width == "50%" for box in stack.children)

@patch("mrmustard.widgets.IN_INTERACTIVE_SHELL", True)
def test_ipython_repr_interactive(self, capsys):
"""Test the IPython repr function."""
rep = PolyExpAnsatz(*Abc_triple(2))
rep._ipython_display_()
captured = capsys.readouterr()
assert captured.out.rstrip() == repr(rep)

def test_matmul_barg_barg(self):
triple1 = Abc_triple(3)
triple2 = Abc_triple(3)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_physics/test_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,22 @@ def test_matmul_error(self):
with pytest.raises(ValueError):
u @ v # pylint: disable=pointless-statement


class TestWiresDisplay:
"""Test the wires _ipython_display_ functionality."""

@patch("mrmustard.physics.wires.display")
def test_ipython_repr(self, mock_display):
"""Test the IPython repr function."""
wires = Wires({0}, {}, {3}, {3, 4})
wires._ipython_display_()
[widget] = mock_display.call_args.args
assert isinstance(widget, HTML)

@patch("mrmustard.widgets.IN_INTERACTIVE_SHELL", True)
def test_ipython_repr_interactive(self, capsys):
"""Test the IPython repr function."""
wires = Wires({0}, {}, {3}, {3, 4})
wires._ipython_display_()
captured = capsys.readouterr()
assert captured.out.rstrip() == repr(wires)
Loading