Skip to content

Commit

Permalink
Add pretty repr to devices and result types (#4649)
Browse files Browse the repository at this point in the history
Implements `_repr_pretty_` which pretty prints in iPython (jupyter notebooks)

Addresses #682
  • Loading branch information
dabacon authored Nov 11, 2021
1 parent 5428606 commit dc6f569
Show file tree
Hide file tree
Showing 17 changed files with 284 additions and 26 deletions.
12 changes: 12 additions & 0 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ def __str__(self) -> str:
final = self._final_simulator_state
return f'measurements: {samples}\noutput state: {final}'

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
if cycle:
# There should never be a cycle. This is just in case.
p.text('cirq.MPSTrialResult(...)')
else:
p.text(str(self))


class MPSSimulatorStepResult(simulator_base.StepResultBase['MPSState', 'MPSState']):
"""A `StepResult` that can perform measurements."""
Expand Down Expand Up @@ -201,6 +209,10 @@ def bitstring(vals):

return f'{measurements}{final}'

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("cirq.MPSSimulatorStepResult(...)" if cycle else self.__str__())

def _simulator_state(self):
return self.state

Expand Down
38 changes: 38 additions & 0 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import cirq
import cirq.contrib.quimb as ccq
import cirq.experiments.google_v2_supremacy_circuit as supremacy_v2
import cirq.testing
from cirq import value


Expand Down Expand Up @@ -275,6 +276,29 @@ def test_trial_result_str():
)


def test_trial_result_repr_pretty():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.StepResult)
final_step_result._simulator_state.return_value = ccq.mps_simulator.MPSState(
qubits=(q0,),
prng=value.parse_random_state(0),
simulation_options=ccq.mps_simulator.MPSOptions(),
)
result = ccq.mps_simulator.MPSTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_step_result=final_step_result,
)
cirq.testing.assert_repr_pretty(
result,
"""measurements: m=1
output state: TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=set()),
])""",
)
cirq.testing.assert_repr_pretty(result, "cirq.MPSTrialResult(...)", cycle=True)


def test_empty_step_result():
q0 = cirq.LineQubit(0)
sim = ccq.mps_simulator.MPSSimulator()
Expand All @@ -288,6 +312,20 @@ def test_empty_step_result():
)


def test_step_result_repr_pretty():
q0 = cirq.LineQubit(0)
sim = ccq.mps_simulator.MPSSimulator()
step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0))))
cirq.testing.assert_repr_pretty(
step_result,
"""0=0
TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=set()),
])""",
)
cirq.testing.assert_repr_pretty(step_result, "cirq.MPSSimulatorStepResult(...)", cycle=True)


def test_state_equal():
q0, q1 = cirq.LineQubit.range(2)
state0 = ccq.mps_simulator.MPSState(
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/ion/ion_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def __str__(self) -> str:

return diagram.render(horizontal_spacing=3, vertical_spacing=2, use_unicode_characters=True)

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("IonDevice(...)" if cycle else self.__str__())

def _value_equality_values_(self) -> Any:
return (
self._measurement_duration,
Expand Down
13 changes: 7 additions & 6 deletions cirq-core/cirq/ion/ion_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import cirq
import cirq.ion as ci
import cirq.testing


def ion_device(chain_length: int, use_timedelta=False) -> ci.IonDevice:
Expand Down Expand Up @@ -183,12 +184,12 @@ def test_validate_circuit_repeat_measurement_keys():


def test_ion_device_str():
assert (
str(ion_device(3)).strip()
== """
0───1───2
""".strip()
)
assert str(ion_device(3)) == "0───1───2"


def test_ion_device_pretty_repr():
cirq.testing.assert_repr_pretty(ion_device(3), "0───1───2")
cirq.testing.assert_repr_pretty(ion_device(3), "IonDevice(...)", cycle=True)


def test_at():
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/neutral_atoms/neutral_atom_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,7 @@ def __str__(self) -> str:
diagram.grid_line(q.col, q.row, q2.col, q2.row)

return diagram.render(horizontal_spacing=3, vertical_spacing=2, use_unicode_characters=True)

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("cirq.NeutralAtomDevice(...)" if cycle else self.__str__())
14 changes: 14 additions & 0 deletions cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import cirq
import cirq.neutral_atoms as neutral_atoms
import cirq.testing


def square_device(
Expand Down Expand Up @@ -266,5 +267,18 @@ def test_str():
)


def test_repr_pretty():
cirq.testing.assert_repr_pretty(
square_device(2, 2),
"""
(0, 0)───(0, 1)
│ │
│ │
(1, 0)───(1, 1)
""".strip(),
)
cirq.testing.assert_repr_pretty(square_device(2, 2), "cirq.NeutralAtomDevice(...)", cycle=True)


def test_qubit_set():
assert square_device(2, 2).qubit_set() == frozenset(cirq.GridQubit.square(2, 0, 0))
8 changes: 8 additions & 0 deletions cirq-core/cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def __str__(self) -> str:
final = self._final_simulator_state
return f'measurements: {samples}\noutput state: {final}'

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("cirq.CliffordTrialResult(...)" if cycle else self.__str__())


class CliffordSimulatorStepResult(
simulator_base.StepResultBase['clifford.CliffordState', 'clifford.ActOnStabilizerCHFormArgs']
Expand Down Expand Up @@ -168,6 +172,10 @@ def bitstring(vals):

return f'{measurements}{final}'

def _repr_pretty_(self, p, cycle):
"""iPython (Jupyter) pretty print."""
p.text("cirq.CliffordSimulatorStateResult(...)" if cycle else self.__str__())

@property
def state(self):
if self._clifford_state is None:
Expand Down
23 changes: 23 additions & 0 deletions cirq-core/cirq/sim/clifford/clifford_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,20 @@ def test_clifford_trial_result_str():
)


def test_clifford_trial_result_repr_pretty():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult)
final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0})
result = cirq.CliffordTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_step_result=final_step_result,
)

cirq.testing.assert_repr_pretty(result, "measurements: m=1\n" "output state: |0⟩")
cirq.testing.assert_repr_pretty(result, "cirq.CliffordTrialResult(...)", cycle=True)


def test_clifford_step_result_str():
q0 = cirq.LineQubit(0)
result = next(
Expand All @@ -252,6 +266,15 @@ def test_clifford_step_result_str():
assert str(result) == "m=0\n" "|0⟩"


def test_clifford_step_result_repr_pretty():
q0 = cirq.LineQubit(0)
result = next(
cirq.CliffordSimulator().simulate_moment_steps(cirq.Circuit(cirq.measure(q0, key='m')))
)
cirq.testing.assert_repr_pretty(result, "m=0\n" "|0⟩")
cirq.testing.assert_repr_pretty(result, "cirq.CliffordSimulatorStateResult(...)", cycle=True)


def test_clifford_step_result_no_measurements_str():
q0 = cirq.LineQubit(0)
result = next(cirq.CliffordSimulator().simulate_moment_steps(cirq.Circuit(cirq.I(q0))))
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,3 +465,7 @@ def __repr__(self) -> str:
f'params={self.params!r}, measurements={self.measurements!r}, '
f'final_simulator_state={self._final_simulator_state!r})'
)

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("cirq.DensityMatrixTrialResult(...)" if cycle else self.__str__())
23 changes: 23 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import sympy

import cirq
import cirq.testing


class PlusGate(cirq.Gate):
Expand Down Expand Up @@ -1188,6 +1189,28 @@ def test_density_matrix_trial_result_str():
)


def test_density_matrix_trial_result_repr_pretty():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.StepResult)
final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState(
density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}
)
result = cirq.DensityMatrixTrialResult(
params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result
)

fake_printer = cirq.testing.FakePrinter()
result._repr_pretty_(fake_printer, cycle=False)
# numpy varies whitespace in its representation for different versions
# Eliminate whitespace to harden tests against this variation
result_no_whitespace = fake_printer.text_pretty.replace('\n', '').replace(' ', '')
assert result_no_whitespace == (
'measurements:(nomeasurements)finaldensitymatrix:[[0.50.5][0.50.5]]'
)

cirq.testing.assert_repr_pretty(result, "cirq.DensityMatrixTrialResult(...)", cycle=True)


def test_run_sweep_parameters_not_resolved():
a = cirq.LineQubit(0)
simulator = cirq.DensityMatrixSimulator()
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def __str__(self) -> str:
state_vector = str(final)
return f'measurements: {samples}\noutput vector: {state_vector}'

def _repr_pretty_(self, p: Any, cycle: bool) -> None:
"""Text output in Jupyter."""
def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
if cycle:
# There should never be a cycle. This is just in case.
p.text('StateVectorTrialResult(...)')
Expand Down
41 changes: 25 additions & 16 deletions cirq-core/cirq/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@
assert_all_implemented_act_on_effects_match_unitary,
)

from cirq.testing.consistent_phase_by import (
assert_phase_by_is_consistent_with_unitary,
)

from cirq.testing.consistent_controlled_gate_op import (
assert_controlled_and_controlled_by_identical,
)
Expand All @@ -44,6 +40,10 @@
assert_pauli_expansion_is_consistent_with_unitary,
)

from cirq.testing.consistent_phase_by import (
assert_phase_by_is_consistent_with_unitary,
)

from cirq.testing.consistent_protocols import (
assert_eigengate_implements_consistent_protocols,
assert_has_consistent_trace_distance_bound,
Expand All @@ -63,6 +63,10 @@
assert_specifies_has_unitary_if_unitary,
)

from cirq.testing.deprecation import (
assert_deprecated,
)

from cirq.testing.devices import (
ValidatingTestDevice,
)
Expand All @@ -71,11 +75,18 @@
EqualsTester,
)

from cirq.testing.equivalent_basis_map import (
assert_equivalent_computational_basis_map,
)

from cirq.testing.equivalent_repr_eval import (
assert_equivalent_repr,
)

from cirq.testing.equivalent_basis_map import assert_equivalent_computational_basis_map
from cirq.testing.gate_features import (
TwoQubitGate,
ThreeQubitGate,
)

from cirq.testing.json import (
assert_json_roundtrip_works,
Expand All @@ -95,15 +106,14 @@
assert_logs,
)

from cirq.testing.gate_features import (
TwoQubitGate,
ThreeQubitGate,
)

from cirq.testing.no_identifier_qubit import (
NoIdentifierQubit,
)

from cirq.testing.op_tree import (
assert_equivalent_op_tree,
)

from cirq.testing.order_tester import (
OrderTester,
)
Expand All @@ -114,12 +124,11 @@
random_two_qubit_circuit_with_czs,
)

from cirq.testing.sample_circuits import (
nonoptimal_toffoli_circuit,
from cirq.testing.repr_pretty_tester import (
assert_repr_pretty,
FakePrinter,
)

from cirq.testing.deprecation import (
assert_deprecated,
from cirq.testing.sample_circuits import (
nonoptimal_toffoli_circuit,
)

from cirq.testing.op_tree import assert_equivalent_op_tree
5 changes: 3 additions & 2 deletions cirq-core/cirq/testing/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import os
from typing import Optional

from cirq.testing import assert_logs

ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST'


Expand All @@ -42,6 +40,9 @@ def __enter__(self):
os.environ.get(ALLOW_DEPRECATION_IN_TEST, None),
)
os.environ[ALLOW_DEPRECATION_IN_TEST] = 'True'
# Avoid circular import.
from cirq.testing import assert_logs

self.assert_logs = assert_logs(
*(msgs + (deadline,)),
min_level=logging.WARNING,
Expand Down
Loading

0 comments on commit dc6f569

Please sign in to comment.