Skip to content

Commit

Permalink
Merge pull request #306 from QunaSys/pauli_label_cache
Browse files Browse the repository at this point in the history
Cache PauliLabel instances in a `weakref` dictionary
  • Loading branch information
dchung0741 authored Jan 23, 2024
2 parents cc1c684 + 29e8f58 commit 07cd1a6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
14 changes: 13 additions & 1 deletion packages/core/quri_parts/core/operator/pauli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections.abc import Collection, Iterable, Mapping, Sequence, Set
from enum import IntEnum
from typing import Optional, Protocol, Union, cast, runtime_checkable
from weakref import WeakValueDictionary

from typing_extensions import TypeAlias

Expand Down Expand Up @@ -40,6 +41,8 @@ class SinglePauli(IntEnum):
(SinglePauli.Z, SinglePauli.Z): None,
}

_pauli_cache: WeakValueDictionary[str, "PauliLabel"] = WeakValueDictionary()


def pauli_name(p: int) -> str:
"""Returns the name of Pauli matrix for a SinglePauli (int)"""
Expand Down Expand Up @@ -179,8 +182,17 @@ class PauliLabel(frozenset[tuple[int, int]]):
https://docs.python.org/3/library/collections.abc.html#collections.abc.Hashable
"""

def __new__(cls, arg: Iterable[tuple[int, int]] = ()) -> "PauliLabel":
instance = super().__new__(cls, arg) # type: ignore
pl_str = str(instance)
if pl_str in _pauli_cache:
return _pauli_cache[pl_str]
else:
_pauli_cache[pl_str] = instance
return instance

def __str__(self) -> str:
if self == PAULI_IDENTITY:
if len(self) == 0:
return "I"
return " ".join(
[SinglePauli(o).name + str(i) for i, o in sorted(self, key=lambda t: t[0])]
Expand Down
30 changes: 30 additions & 0 deletions packages/core/tests/core/operator/test_pauli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch
from weakref import WeakValueDictionary

import pytest

from quri_parts.core.operator import (
Expand Down Expand Up @@ -167,6 +170,33 @@ def test_index_and_pauli_id_list(self) -> None:
((2, SinglePauli.Y), (6, SinglePauli.Z), (4, SinglePauli.X))
)

def test_pauli_cache(self) -> None:
with patch( # type: ignore
"quri_parts.core.operator.pauli._pauli_cache", WeakValueDictionary()
) as pauli_cache:
# Checks if the cache works correctly.
pl_1 = pauli_label("X0 X1 Y2 Y3 Z4 Z5")
pl_2 = pauli_label("X0 X1 Y2 Y3 Z4 Z5")
assert id(pl_1) == id(pl_2)

cache_len = len(pauli_cache)
pl_2 = pauli_label("X0 X1 Y2 Y3 Z4 Z7")
assert id(pl_1) != id(pl_2)
assert len(pauli_cache) == cache_len + 1

pl_3 = pauli_label("X0 X1 Y2 Y3 Z4 Z6")
assert id(pl_1) != id(pl_3)

# Checks if the cache is cleared correctly.
cache_len = len(pauli_cache)
pl_3 = pauli_label("X0 X1 Y2 Y3 Z4 Z8")
assert str(pl_3) == "X0 X1 Y2 Y3 Z4 Z8"
assert id(pl_1) != id(pl_3)
assert len(pauli_cache) == cache_len

pl_3 = "QURI Parts" # type: ignore
assert len(pauli_cache) == cache_len - 1


class TestPauliProduct:
def test_pauli_product(self) -> None:
Expand Down

1 comment on commit 07cd1a6

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.