Skip to content

Commit

Permalink
improve hashing in Kernel/Discriminator
Browse files Browse the repository at this point in the history
  • Loading branch information
junnaka51 committed Jun 30, 2023
1 parent d44dd11 commit 0afd18d
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
24 changes: 22 additions & 2 deletions qiskit/pulse/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ def _assert_nested_dict_equal(a, b):
return True


def _get_hash(obj: dict):
obj_ = obj.copy()
for k, v in obj.items():
if isinstance(v, dict):
obj_[k] = _get_hash(v)
elif isinstance(v, np.ndarray):
obj_[k] = v.tobytes()
elif isinstance(v, list):
v_ = []
for i in v:
if isinstance(i, dict):
v_.append(_get_hash(i))
else:
v_.append(i)
obj_[k] = tuple(v_)
else:
obj_[k] = v
return hash(tuple(obj_.items()))


class Kernel:
"""Settings for this Kernel, which is responsible for integrating time series (raw) data
into IQ points.
Expand Down Expand Up @@ -67,7 +87,7 @@ def __eq__(self, other):
return False

def __hash__(self):
return hash(repr(self))
return _get_hash(self.__dict__)


class Discriminator:
Expand Down Expand Up @@ -98,7 +118,7 @@ def __eq__(self, other):
return False

def __hash__(self):
return hash(repr(self))
return _get_hash(self.__dict__)


class LoRange:
Expand Down
42 changes: 42 additions & 0 deletions test/python/pulse/test_experiment_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,27 @@ def test_neq_nested_params(self):
)
self.assertFalse(kernel_a == kernel_b)

def test_hash(self):
"""Test if hash is implemented correctly."""
kernel_a = Kernel(
"kernel_test",
kernel={"real": np.zeros(10), "imag": np.zeros(10)},
bias=[0, 0],
)
kernel_b = Kernel(
"kernel_test",
kernel={"real": np.zeros(10), "imag": np.zeros(10)},
bias=[0, 0],
)
self.assertTrue(hash(kernel_a) == hash(kernel_b))

kernel_c = Kernel(
"kernel_test",
bias=[0, 0],
kernel={"real": np.zeros(10), "imag": np.zeros(10)},
)
self.assertFalse(hash(kernel_a) == hash(kernel_c))


class TestDiscriminator(QiskitTestCase):
"""Test Discriminator."""
Expand Down Expand Up @@ -199,6 +220,27 @@ def test_neq_params(self):
)
self.assertFalse(discriminator_a == discriminator_b)

def test_hash(self):
"""Test if hash is implemented correctly."""
discriminator_a = Discriminator(
"discriminator_test",
discriminator_type="linear",
neighborhoods=[{"qubits": 1, "channels": 1}],
)
discriminator_b = Discriminator(
"discriminator_test",
discriminator_type="linear",
neighborhoods=[{"qubits": 1, "channels": 1}],
)
self.assertTrue(hash(discriminator_a) == hash(discriminator_b))

discriminator_c = Discriminator(
"discriminator_test",
discriminator_type="linear",
neighborhoods=[{"channels": 1, "qubits": 1}],
)
self.assertFalse(hash(discriminator_a) == hash(discriminator_c))


if __name__ == "__main__":
unittest.main()

0 comments on commit 0afd18d

Please sign in to comment.