Skip to content

Commit

Permalink
Reject formulas as keys of ParamResolvers (#5384)
Browse files Browse the repository at this point in the history
* Reject formulas as keys of ParamResolvers

- A ParamResolver resolves variables into values.
- Having non-trivial formulas as keys allows a significant
complexity and ambiguity into ParamResolvers, since it is
unclear how much is supported.  Prevent this case altogether
by raising an error if non-symbol formulas are used in ParamResolvers.

Fixes: #3550
  • Loading branch information
dstrain115 authored May 20, 2022
1 parent 1b7a800 commit 95bebae
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 24 deletions.
6 changes: 6 additions & 0 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class ParamResolver:
Attributes:
param_dict: A dictionary from the ParameterValue key (str) to its
assigned value.
Raises:
TypeError if formulas are passed as keys.
"""

def __new__(cls, param_dict: 'cirq.ParamResolverOrSimilarType' = None):
Expand All @@ -68,6 +71,9 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None

self._param_hash: Optional[int] = None
self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
for key in self.param_dict:
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
self._deep_eval_map: ParamDictType = {}

def value_of(
Expand Down
20 changes: 3 additions & 17 deletions cirq-core/cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,13 @@ def test_param_dict_iter():


def test_formulas_in_param_dict():
"""Test formulas in a `param_dict`.
Param dicts are allowed to have str or sympy.Symbol as keys and
floats or sympy.Symbol as values. This should not be a common use case,
but this tests makes sure something reasonable is returned when
mixing these types and using formulas in ParamResolvers.
Note that sympy orders expressions for deterministic resolution, so
depending on the operands sent to sub(), the expression may not fully
resolve if it needs to take several iterations of resolution.
"""
"""Tests that formula keys are rejected in a `param_dict`."""
a = sympy.Symbol('a')
b = sympy.Symbol('b')
c = sympy.Symbol('c')
e = sympy.Symbol('e')
r = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})
assert sympy.Eq(r.value_of('a'), 3)
assert sympy.Eq(r.value_of('b'), 2)
assert sympy.Eq(r.value_of(b + c), 101)
assert sympy.Eq(r.value_of('c'), c)
assert sympy.Eq(r.value_of('d'), 2 * e)
with pytest.raises(TypeError, match='formula'):
_ = cirq.ParamResolver({a: b + 1, b: 2, b + c: 101, 'd': 2 * e})


def test_recursive_evaluation():
Expand Down
6 changes: 2 additions & 4 deletions cirq-google/cirq_google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,9 @@ def sweep_to_proto(
sweep_dict: Dict[str, List[float]] = {}
for param_resolver in sweep:
for key in param_resolver:
if isinstance(key, sympy.Expr):
raise ValueError(f'cannot convert to v2 Sweep proto: {sweep}')
if key not in sweep_dict:
sweep_dict[key] = []
sweep_dict[key].append(cast(float, param_resolver.value_of(key)))
sweep_dict[cast(str, key)] = []
sweep_dict[cast(str, key)].append(cast(float, param_resolver.value_of(key)))
out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
for key in sweep_dict:
sweep_to_proto(cirq.Points(key, sweep_dict[key]), out=out.sweep_function.sweeps.add())
Expand Down
17 changes: 14 additions & 3 deletions cirq-google/cirq_google/api/v2/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,20 @@ def test_sweep_to_proto_linspace():


def test_list_sweep_bad_expression():
sweep = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a') + sympy.Symbol('b'): 4.0})])
with pytest.raises(ValueError, match='cannot convert'):
v2.sweep_to_proto(sweep)
with pytest.raises(TypeError, match='formula'):
_ = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a') + sympy.Symbol('b'): 4.0})])


def test_symbol_to_string_conversion():
sweep = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a'): 4.0})])
proto = v2.sweep_to_proto(sweep)
assert isinstance(proto, v2.run_context_pb2.Sweep)
expected = v2.run_context_pb2.Sweep()
expected.sweep_function.function_type = v2.run_context_pb2.SweepFunction.ZIP
p1 = expected.sweep_function.sweeps.add()
p1.single_sweep.parameter_key = 'a'
p1.single_sweep.points.points.extend([4.0])
assert proto == expected


def test_sweep_to_proto_points():
Expand Down

0 comments on commit 95bebae

Please sign in to comment.