Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sympy error #5930

Merged
merged 9 commits into from
Oct 30, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class ParamResolver:

Raises:
TypeError if formulas are passed as keys.
ValueError if the resulting value cannot be interpreted.
"""

def __new__(cls, param_dict: 'cirq.ParamResolverOrSimilarType' = None):
Expand Down Expand Up @@ -179,7 +180,13 @@ def value_of(
if not recursive:
# Resolves one step at a time. For example:
# a.subs({a: b, b: c}) == b
v = value.subs(self.param_dict, simultaneous=True)
try:
v = value.subs(self.param_dict, simultaneous=True)
except sympy.SympifyError as e: # coverage: ignore
# Lines will be covered in sympy 1.12+
raise ValueError(
f'Could not resolve parameter {value}, underlying error {e}'
) # coverage: ignore
if v.free_symbols:
return v
elif sympy.im(v):
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,11 @@ def _resolved_value_(self):
c = sympy.Symbol('c')
r = cirq.ParamResolver({a: foo, b: bar, c: baz})
assert r.value_of(a) is foo
assert r.value_of(b) is b
assert r.value_of(c) == 'Baz'

with pytest.raises(ValueError, match='Could not resolve parameter b'):
_ = r.value_of(b)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should move this to a separate test function marked with a strict xfail, for example,

@pytest.mark.xfail(reason='this test requires sympy 1.12', strict=True)
def test_custom_value_not_implemented():
    ...

When the new sympy is released, the strict xfail will produce a CI failure
after which we can remove the xfail mark.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea done.



def test_compose():
"""Tests that cirq.resolve_parameters on a ParamResolver composes."""
Expand Down