Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding __matmul__ to CircuitComponents #347

Merged
merged 34 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
774681f
Silencing TensorFlow warning (#332)
Feb 1, 2024
96ea206
Fixing a Scipy error (#337)
Feb 2, 2024
429b16c
Removing TF warnings on import with `os` instead of `logging` (#339)
Feb 6, 2024
db1f4d5
Remove warnings from state plotter (#343)
Feb 8, 2024
ab1f82d
Fix `math.Categorical` (#342)
Feb 8, 2024
312d7e2
Wires final (merge second) (#330)
ziofil Feb 9, 2024
87ee2dc
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
SamFerracin Feb 9, 2024
3e7a91a
fixed
SamFerracin Feb 9, 2024
bf86dfd
empty
SamFerracin Feb 12, 2024
6e97ca4
wires >> to @
SamFerracin Feb 13, 2024
c8dfa6a
eq and neq for wires
SamFerracin Feb 13, 2024
95d0602
add triples.py file in the physics (#338)
sylviemonet Feb 13, 2024
f6292cf
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
SamFerracin Feb 13, 2024
edb9675
Merge branch 'merge-dev' of https://github.com/XanaduAI/MrMustard int…
SamFerracin Feb 13, 2024
0050424
fixed matmul
SamFerracin Feb 14, 2024
b69efc7
added tests
SamFerracin Feb 14, 2024
6f3492a
conflics
SamFerracin Feb 14, 2024
790f90c
oops
SamFerracin Feb 14, 2024
2dac56c
fix the c of the attenuator triples needs to be 1 (#346)
sylviemonet Feb 14, 2024
aceffca
clean up
SamFerracin Feb 14, 2024
d4071ce
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
SamFerracin Feb 14, 2024
4c81fac
black
SamFerracin Feb 14, 2024
10d2390
fixed math.gather
SamFerracin Feb 14, 2024
4f112dd
fixed tests and docs of math.gather
SamFerracin Feb 14, 2024
b6c8ac4
error fixed elsewhere
SamFerracin Feb 14, 2024
96fa067
Docs and tests for Fock and ArrayAnsatz, and a few minor changes to R…
Feb 15, 2024
650ec58
add global phase tests for two dgates
sylviemonet Feb 20, 2024
47b76d9
Fix the small bugs (#349)
sylviemonet Feb 21, 2024
bea2138
Fix the length of b Vector for all the objects and the corresponding …
sylviemonet Feb 21, 2024
d7ae0be
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
SamFerracin Feb 21, 2024
851fa77
CR
SamFerracin Feb 21, 2024
99b34b3
Merge branch 'matmul-cc' of https://github.com/XanaduAI/MrMustard int…
SamFerracin Feb 21, 2024
e9571b5
Merge branch 'matmul2-cc' into matmul-cc
SamFerracin Feb 21, 2024
66aca97
black
SamFerracin Feb 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mrmustard/lab_dev/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .circuit_components import CircuitComponent
from .circuits import Circuit
from .simulator import Simulator
from .states import *
SamFerracin marked this conversation as resolved.
Show resolved Hide resolved
from .transformations import *
from .wires import Wires
56 changes: 54 additions & 2 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from typing import Optional, Sequence, Union

from ..physics.representations import Bargmann, Representation
from ..physics.representations import Bargmann, Fock, Representation
from ..math.parameter_set import ParameterSet
from ..math.parameters import Constant, Variable
from ..utils.typing import Batch, ComplexMatrix, ComplexTensor, ComplexVector, Mode
Expand Down Expand Up @@ -140,6 +140,7 @@ def wires(self) -> Wires:
"""
return self._wires

@property
def adjoint(self) -> CircuitComponent:
r"""
Light-copies this component, then returns the adjoint of it, obtained by taking the
Expand All @@ -151,6 +152,7 @@ def adjoint(self) -> CircuitComponent:
representation = ret.representation.conj()
return CircuitComponent.from_attributes(name, wires, representation)

@property
def dual(self) -> CircuitComponent:
r"""
Light-copies this component, then returns the dual of it, obtained by taking the
Expand All @@ -172,6 +174,14 @@ def light_copy(self) -> CircuitComponent:
instance.__dict__["_wires"] = self.wires.copy()
return instance

def __eq__(self, other) -> bool:
r"""
Whether this component is equal to another component.

Compares representations and wires, but not the other attributes (including name and parameter set).
"""
return self.representation == other.representation and self.wires == other.wires

def __getitem__(self, idx: Union[Mode, Sequence[Mode]]):
r"""
Returns a slice of this component for the given modes.
Expand All @@ -181,6 +191,48 @@ def __getitem__(self, idx: Union[Mode, Sequence[Mode]]):
ret._parameter_set = self.parameter_set
return ret

def __matmul__(self, other: CircuitComponent) -> CircuitComponent:
r"""
Contracts ``self`` and ``other``, without adding adjoints.
"""
# initialized the ``Wires`` of the returned component
wires_ret = self.wires @ other.wires

# find the indices of the wires being contracted on the bra side
ziofil marked this conversation as resolved.
Show resolved Hide resolved
bra_modes = set(self.wires.bra.output.modes).intersection(other.wires.bra.input.modes)
idx_z = self.wires[bra_modes].bra.output.indices
idx_zconj = other.wires[bra_modes].bra.input.indices

# find the indices of the wires being contracted on the ket side
ket_modes = set(self.wires.ket.output.modes).intersection(other.wires.ket.input.modes)
idx_z += self.wires[ket_modes].ket.output.indices
idx_zconj += other.wires[ket_modes].ket.input.indices

# convert Bargmann -> Fock if needed
LEFT = self.representation
RIGHT = other.representation
if isinstance(LEFT, Bargmann) and isinstance(RIGHT, Fock):
raise ValueError("Cannot contract objects with different representations.")
# shape = [s if i in idx_z else None for i, s in enumerate(other.representation.shape)]
SamFerracin marked this conversation as resolved.
Show resolved Hide resolved
# LEFT = Fock(self.fock(shape=shape), batched=False)
elif isinstance(LEFT, Fock) and isinstance(RIGHT, Bargmann):
raise ValueError("Cannot contract objects with different representations.")
# shape = [s if i in idx_zconj else None for i, s in enumerate(self.representation.shape)]
# RIGHT = Fock(other.fock(shape=shape), batched=False)

# calculate the representation of the returned component
representation_ret = LEFT[idx_z] @ RIGHT[idx_zconj]

# reorder the representation
contracted_idx = [self.wires.ids[i] for i in range(len(self.wires.ids)) if i not in idx_z]
contracted_idx += [
other.wires.ids[i] for i in range(len(other.wires.ids)) if i not in idx_zconj
]
order = [contracted_idx.index(id) for id in wires_ret.ids]
representation_ret = representation_ret.reorder(order)

return CircuitComponent.from_attributes("", wires_ret, representation_ret)


def connect(components: Sequence[CircuitComponent]) -> Sequence[CircuitComponent]:
r"""
Expand Down Expand Up @@ -227,6 +279,6 @@ def add_bra(components: Sequence[CircuitComponent]) -> Sequence[CircuitComponent
for component in components:
ret.append(component.light_copy())
if not component.wires.bra:
ret.append(component.adjoint())
ret.append(component.adjoint)

return ret
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _contract(self, components: Sequence[CircuitComponent]) -> CircuitComponent:
# calculate the ``Wires`` of the returned component, alongside its substrings
wires_out = components[0].wires
for c in components[1:]:
wires_out >>= c.wires
wires_out @= c.wires
subs_out = "".join([ids_to_subs[id] for id in wires_out.ids])

# grab the representation that remains in ``subs_to_rep``
Expand Down
10 changes: 5 additions & 5 deletions mrmustard/lab_dev/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,15 @@ def representation(self) -> Bargmann:
A = math.cast(
np.array(
[
[0, np.sqrt(e), 1 - e, 0],
[np.sqrt(e), 0, 0, 0],
[1 - e, 0, 0, np.sqrt(e)],
[0, 0, np.sqrt(e), 0],
[0, np.sqrt(e), 0, 0],
[np.sqrt(e), 0, 0, 1 - e],
[0, 0, 0, np.sqrt(e)],
[0, 1 - e, np.sqrt(e), 0],
],
),
math.complex128,
)
B = math.cast([0.0, 0.0, 0.0, 0.0], math.complex128)
C = math.cast(np.sqrt(e), math.complex128)
C = 1

return Bargmann(A, B, C)
38 changes: 30 additions & 8 deletions mrmustard/lab_dev/wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,26 @@ def __bool__(self) -> bool:
"""
return len(self.ids) > 0

def __eq__(self, other) -> bool:
r"""
Returns ``True`` if this ``Wires`` acts on the same modes as ``other``, ``False`` otherwise.
"""
if self.output.bra.modes != other.output.bra.modes:
return False
if self.input.bra.modes != other.input.bra.modes:
return False
if self.output.ket.modes != other.output.ket.modes:
return False
if self.input.ket.modes != other.input.ket.modes:
return False
return True

def __neq__(self, other) -> bool:
SamFerracin marked this conversation as resolved.
Show resolved Hide resolved
r"""
Returns ``False`` if this ``Wires`` is equal to ``other``, ``True`` otherwise.
"""
return not self == other

def __getitem__(self, modes: Iterable[int] | int) -> Wires:
r"""
A view of this Wires object with wires only on the given modes.
Expand All @@ -308,9 +328,6 @@ def __getitem__(self, modes: Iterable[int] | int) -> Wires:
idxs = tuple(list(self._modes).index(m) for m in set(self._modes).difference(modes))
return self._view(masked_rows=idxs)

def __lshift__(self, other: Wires) -> Wires:
return (other.dual >> self.dual).dual # how cool is this

@staticmethod
def _outin(self_in: int, self_out: int, other_in: int, other_out: int) -> np.ndarray:
r"""
Expand All @@ -331,12 +348,17 @@ def _outin(self_in: int, self_out: int, other_in: int, other_out: int) -> np.nda
else: # no wires on other
return np.array([self_out, self_in], dtype=np.int64)

def __rshift__(self, other: Wires) -> Wires:
def __matmul__(self, other: Wires) -> Wires:
r"""
A new Wires object with the wires of ``self`` and ``other`` combined as two
components in a circuit: the output of self connects to the input of other wherever
they match. All surviving wires are arranged in the standard order.
A ValueError is raised if there are any surviving wires that overlap.
A new ``Wires`` object with the wires of ``self`` and ``other`` combined.

The output of ``self`` connects to the input of ``other`` wherever they match. All
surviving wires are arranged in the standard order.

This function does not add missing adjoints.

Raises:
ValueError: If there are any surviving wires that overlap.
"""
all_modes = sorted(set(self.modes) | set(other.modes))
new_id_array = np.zeros((len(all_modes), 4), dtype=np.int64)
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/physics/triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def attenuator_Abc(eta: Union[Scalar, Iterable]) -> Union[Matrix, Vector, Scalar

A = A[reshape_list, :][:, reshape_list]
b = _vacuum_B_vector(n_modes * 2)
c = np.prod(eta)
c = 1.0
SamFerracin marked this conversation as resolved.
Show resolved Hide resolved

return A, b, c

Expand Down
79 changes: 78 additions & 1 deletion tests/test_lab_dev/test_circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
Tests for circuit components.
"""

import numpy as np

from mrmustard.physics.representations import Bargmann
from mrmustard.lab_dev.circuit_components import connect, add_bra, CircuitComponent
from mrmustard.lab_dev.states import Vacuum
from mrmustard.lab_dev.transformations import Dgate, Attenuator
from mrmustard.lab_dev.wires import Wires


class TestCircuitComponent:
Expand All @@ -37,6 +41,79 @@ def test_light_copy(self):
assert d.y is d_copy.y
assert d.wires is not d_copy.wires

def test_matmul_one_mode(self):
sylviemonet marked this conversation as resolved.
Show resolved Hide resolved
r"""
Tests that ``__matmul__`` produces the correct outputs for one-mode components.
"""
vac0 = Vacuum(1)
d0 = Dgate(1, modes=[0])
a0 = Attenuator(0.9, modes=[0])

result1 = vac0 @ d0
result1 = (result1 @ result1.adjoint) @ a0

assert result1.wires == Wires(modes_out_bra=[0], modes_out_ket=[0])
assert np.allclose(result1.representation.A, 0)
assert np.allclose(result1.representation.b, [0.9486833, 0.9486833])
assert np.allclose(result1.representation.c, 0.40656966)

result2 = result1 @ vac0.dual @ vac0.dual.adjoint
assert not result2.wires
assert np.allclose(result2.representation.A, 0)
assert np.allclose(result2.representation.b, 0)
assert np.allclose(result2.representation.c, 0.40656966)

def test_matmul_multi_modes(self):
r"""
Tests that ``__matmul__`` produces the correct outputs for multi-mode components.
"""
vac012 = Vacuum(3)
d0 = Dgate(0.1, 0.1, modes=[0])
d1 = Dgate(0.1, 0.1, modes=[1])
d2 = Dgate(0.1, 0.1, modes=[2])
a0 = Attenuator(0.8, modes=[0])
a1 = Attenuator(0.8, modes=[1])
a2 = Attenuator(0.7, modes=[2])

result = vac012 @ d0 @ d1 @ d2
result = result @ result.adjoint @ a0 @ a1 @ a2

assert result.wires == Wires(modes_out_bra=[0, 1, 2], modes_out_ket=[0, 1, 2])
assert np.allclose(result.representation.A, 0)
assert np.allclose(
SamFerracin marked this conversation as resolved.
Show resolved Hide resolved
result.representation.b,
[
0.08944272 - 0.08944272j,
0.08944272 - 0.08944272j,
0.083666 - 0.083666j,
0.08944272 + 0.08944272j,
0.08944272 + 0.08944272j,
0.083666 + 0.083666j,
],
)
assert np.allclose(result.representation.c, 0.95504196)

def test_matmul_is_associative(self):
r"""
Tests that ``__matmul__`` is associative, meaning ``a @ (b @ c) == (a @ b) @ c``.
"""
vac012 = Vacuum(3)
d0 = Dgate(0.1, 0.1, modes=[0])
d1 = Dgate(0.1, 0.1, modes=[1])
d2 = Dgate(0.1, 0.1, modes=[2])
a0 = Attenuator(0.8, modes=[0])
a1 = Attenuator(0.8, modes=[1])
a2 = Attenuator(0.7, modes=[2])

result1 = vac012 @ d0 @ d1 @ a0 @ a1 @ a2 @ d2
result2 = (vac012 @ d0) @ (d1 @ a0) @ a1 @ (a2 @ d2)
result3 = vac012 @ (d0 @ (d1 @ a0 @ a1) @ a2 @ d2)
result4 = vac012 @ (d0 @ (d1 @ (a0 @ (a1 @ (a2 @ d2)))))

assert result1 == result2
assert result1 == result3
assert result1 == result4


class TestConnect:
r"""
Expand Down Expand Up @@ -82,7 +159,7 @@ def test_ket_and_bra(self):
Tests the ``connect`` function with components with kets and bras.
"""
d1 = Dgate(1, modes=[0, 8, 9])
d1_adj = d1.adjoint()
d1_adj = d1.adjoint
a1 = Attenuator(0.1, modes=[8])

components = connect([d1, d1_adj, a1])
Expand Down
23 changes: 19 additions & 4 deletions tests/test_lab_dev/test_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,30 @@ def test_getitem(self):
assert w1.modes == [1]
assert w1.ids == [w.ids[1], w.ids[3]]

def test_rshift(self):
def test_eq_neq(self):
w1 = Wires([0, 1], [2, 3], [4, 5], [6, 7])
w2 = Wires([0, 1], [2, 3], [4, 5], [6, 7])
w3 = Wires([], [2, 3], [4, 5], [6, 7])
w4 = Wires([0, 1], [], [4, 5], [6, 7])
w5 = Wires([0, 1], [2, 3], [], [6, 7])
w6 = Wires([0, 1], [2, 3], [4, 5], [])

assert w1 == w1
assert w1 == w2
assert w1 != w3
assert w1 != w4
assert w1 != w5
assert w1 != w6

def test_matmul(self):
# contracts 1,1 on bra side
# contracts 3,3 and 13,13 on ket side (note order doesn't matter)
u = Wires([1, 5], [2, 6, 15], [3, 7, 13], [4, 8])
v = Wires([0, 9, 14], [1, 10], [2, 11], [13, 3, 12])
assert (u >> v)._args() == ((0, 5, 9, 14), (2, 6, 10, 15), (2, 7, 11), (4, 8, 12))
assert (u @ v)._args() == ((0, 5, 9, 14), (2, 6, 10, 15), (2, 7, 11), (4, 8, 12))

def test_rshift_error(self):
def test_matmul_error(self):
u = Wires([], [], [0], []) # only output wire
v = Wires([], [], [0], []) # only output wire
with pytest.raises(ValueError):
u >> v # pylint: disable=pointless-statement
u @ v # pylint: disable=pointless-statement
4 changes: 2 additions & 2 deletions tests/test_physics/test_triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def test_attenuator_Abc(self):
e = 0.31622777
assert np.allclose(A1, [[0, e, 0, 0], [e, 0, 0, 0.9], [0, 0, 0, e], [0, 0.9, e, 0]])
assert np.allclose(b1, 0)
assert np.allclose(c1, 0.1)
assert np.allclose(c1, 1.0)

A2, b2, c2 = triples.attenuator_Abc([0.1, 1])
e = 0.31622777
Expand All @@ -238,7 +238,7 @@ def test_attenuator_Abc(self):
],
)
assert np.allclose(b2, 0)
assert np.allclose(c2, 0.1)
assert np.allclose(c2, 1.0)

def test_attenuator_Abc_error(self):
with pytest.raises(ValueError, match="in the interval"):
Expand Down
Loading