diff --git a/cirq-core/cirq/study/resolver.py b/cirq-core/cirq/study/resolver.py index e2f82cce1c7..35c3889348c 100644 --- a/cirq-core/cirq/study/resolver.py +++ b/cirq-core/cirq/study/resolver.py @@ -55,6 +55,9 @@ class ParamResolver: Attributes: param_dict: A dictionary from the ParameterValue key (str) to its assigned value. + + Raises: + TypeError if formulas are passed as keys. """ def __new__(cls, param_dict: 'cirq.ParamResolverOrSimilarType' = None): @@ -68,6 +71,9 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None self._param_hash: Optional[int] = None self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict) + for key in self.param_dict: + if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol): + raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})') self._deep_eval_map: ParamDictType = {} def value_of( diff --git a/cirq-core/cirq/study/resolver_test.py b/cirq-core/cirq/study/resolver_test.py index c612448b7c9..8e26abd9d7e 100644 --- a/cirq-core/cirq/study/resolver_test.py +++ b/cirq-core/cirq/study/resolver_test.py @@ -156,27 +156,13 @@ def test_param_dict_iter(): def test_formulas_in_param_dict(): - """Test formulas in a `param_dict`. - - Param dicts are allowed to have str or sympy.Symbol as keys and - floats or sympy.Symbol as values. This should not be a common use case, - but this tests makes sure something reasonable is returned when - mixing these types and using formulas in ParamResolvers. - - Note that sympy orders expressions for deterministic resolution, so - depending on the operands sent to sub(), the expression may not fully - resolve if it needs to take several iterations of resolution. - """ + """Tests that formula keys are rejected in a `param_dict`.""" a = sympy.Symbol('a') b = sympy.Symbol('b') c = sympy.Symbol('c') e = sympy.Symbol('e') - r = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e}) - assert sympy.Eq(r.value_of('a'), 3) - assert sympy.Eq(r.value_of('b'), 2) - assert sympy.Eq(r.value_of(b + c), 101) - assert sympy.Eq(r.value_of('c'), c) - assert sympy.Eq(r.value_of('d'), 2 * e) + with pytest.raises(TypeError, match='formula'): + _ = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e}) def test_recursive_evaluation(): diff --git a/cirq-google/cirq_google/api/v2/sweeps.py b/cirq-google/cirq_google/api/v2/sweeps.py index da540da4b74..83210e86315 100644 --- a/cirq-google/cirq_google/api/v2/sweeps.py +++ b/cirq-google/cirq_google/api/v2/sweeps.py @@ -61,11 +61,9 @@ def sweep_to_proto( sweep_dict: Dict[str, List[float]] = {} for param_resolver in sweep: for key in param_resolver: - if isinstance(key, sympy.Expr): - raise ValueError(f'cannot convert to v2 Sweep proto: {sweep}') if key not in sweep_dict: - sweep_dict[key] = [] - sweep_dict[key].append(cast(float, param_resolver.value_of(key))) + sweep_dict[cast(str, key)] = [] + sweep_dict[cast(str, key)].append(cast(float, param_resolver.value_of(key))) out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP for key in sweep_dict: sweep_to_proto(cirq.Points(key, sweep_dict[key]), out=out.sweep_function.sweeps.add()) diff --git a/cirq-google/cirq_google/api/v2/sweeps_test.py b/cirq-google/cirq_google/api/v2/sweeps_test.py index 6acc3264d07..d4d1c2b8cbe 100644 --- a/cirq-google/cirq_google/api/v2/sweeps_test.py +++ b/cirq-google/cirq_google/api/v2/sweeps_test.py @@ -73,9 +73,20 @@ def test_sweep_to_proto_linspace(): def test_list_sweep_bad_expression(): - sweep = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a') + sympy.Symbol('b'): 4.0})]) - with pytest.raises(ValueError, match='cannot convert'): - v2.sweep_to_proto(sweep) + with pytest.raises(TypeError, match='formula'): + _ = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a') + sympy.Symbol('b'): 4.0})]) + + +def test_symbol_to_string_conversion(): + sweep = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a'): 4.0})]) + proto = v2.sweep_to_proto(sweep) + assert isinstance(proto, v2.run_context_pb2.Sweep) + expected = v2.run_context_pb2.Sweep() + expected.sweep_function.function_type = v2.run_context_pb2.SweepFunction.ZIP + p1 = expected.sweep_function.sweeps.add() + p1.single_sweep.parameter_key = 'a' + p1.single_sweep.points.points.extend([4.0]) + assert proto == expected def test_sweep_to_proto_points():