Skip to content

Commit

Permalink
Merge branch 'master' into device_config
Browse files Browse the repository at this point in the history
  • Loading branch information
Jose Urruticoechea committed Aug 28, 2023
2 parents 3218afe + 83609eb commit 873b522
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 17 deletions.
16 changes: 16 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def remove_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().remove_qubits(qubits)
if ret is not NotImplemented:
return ret
extracted, remainder = self.factor(qubits, inplace=True)
remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0]
return remainder

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down
5 changes: 2 additions & 3 deletions cirq-core/cirq/sim/simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,8 @@ def test_delegating_gate_channel(exp):
control_circuit = cirq.Circuit(cirq.H(q))
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))

with pytest.raises(TypeError, match="DensityMatrixSimulator doesn't support"):
# TODO: This test should pass once we extend support to DensityMatrixSimulator.
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
Expand Down
4 changes: 2 additions & 2 deletions cirq-google/cirq_google/api/v2/run_context.proto
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ message DeviceParameter {
repeated string path = 1;

// If the value is an array, the index of the array to change.
int64 idx = 2;
optional int64 idx = 2;

// String representation of the units, if any.
// Examples: "GHz", "ns", etc.
string units = 3;
optional string units = 3;

// Note that the device parameter values will be populated
// by the sweep values themselves.
Expand Down
16 changes: 8 additions & 8 deletions cirq-google/cirq_google/api/v2/run_context_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 8 additions & 3 deletions cirq-google/cirq_google/api/v2/run_context_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion cirq-google/cirq_google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,13 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
key = msg.single_sweep.parameter_key
if msg.single_sweep.HasField("parameter"):
metadata = DeviceParameter(
path=msg.single_sweep.parameter.path, idx=msg.single_sweep.parameter.idx
path=msg.single_sweep.parameter.path,
idx=msg.single_sweep.parameter.idx
if msg.single_sweep.parameter.HasField("idx")
else None,
units=msg.single_sweep.parameter.units
if msg.single_sweep.parameter.HasField("units")
else None,
)
else:
metadata = None
Expand Down
7 changes: 7 additions & 0 deletions cirq-google/cirq_google/api/v2/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def _values(self) -> Iterator[float]:
[1, 1.5, 2, 2.5, 3],
metadata=DeviceParameter(path=['path', 'to', 'parameter'], idx=2, units='GHz'),
),
cirq.Points(
'b',
[1, 1.5, 2, 2.5, 3],
metadata=DeviceParameter(path=['path', 'to', 'parameter'], idx=None),
),
cirq.Linspace('a', 0, 1, 5) * cirq.Linspace('b', 0, 1, 5),
cirq.Points('a', [1, 2, 3]) + cirq.Linspace('b', 0, 1, 3),
(
Expand All @@ -69,6 +74,8 @@ def test_sweep_to_proto_roundtrip(sweep):
msg = v2.sweep_to_proto(sweep)
deserialized = v2.sweep_from_proto(msg)
assert deserialized == sweep
# Check that metadata is the same, if it exists.
assert getattr(deserialized, 'metadata', None) == getattr(sweep, 'metadata', None)


def test_sweep_to_proto_linspace():
Expand Down

0 comments on commit 873b522

Please sign in to comment.