-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize ParamResolver.value_of #6341
Conversation
param_value = self._param_dict.get(symbol, _NotFound) | ||
if param_value is _NotFound: | ||
# Symbol or string that is not in the param_dict cannot be resolved futher; return as symbol. | ||
return symbol |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure why we decided to always return a symbol here even if resolving a string. My preference would be to return the given value unchanged if the resolver doesn't need to change it, but callers could rely on this behavior so we should be careful if we want to change it.
@@ -100,7 +100,7 @@ def test_sampler_multiple_jobs(): | |||
results = sampler.sample( | |||
program=circuit, | |||
repetitions=4, | |||
params=[cirq.ParamResolver({x: '0.5'}), cirq.ParamResolver({x: '0.6'})], | |||
params=[cirq.ParamResolver({x: 0.5}), cirq.ParamResolver({x: 0.6})], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was coincidentally working before because the value_of
code was calling float
on the result after looking up in the param dict. But I think that behavior was unintentional because otherwise there would be no way to distinguish between strings used as symbols and strings containing floats. I can't think of any other place where we allow floats to be specified as strings, outside of serialization, and it certainly wasn't the intent to do so for parameter resolution.
Any chance of getting this into master this week? Would be great for accelerating some internal efforts :) |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6341 +/- ##
==========================================
- Coverage 97.84% 97.84% -0.01%
==========================================
Files 1110 1110
Lines 96597 96656 +59
==========================================
+ Hits 94516 94572 +56
- Misses 2081 2084 +3 ☔ View full report in Codecov by Sentry. |
cirq-core/cirq/study/resolver.py
Outdated
@@ -36,6 +36,9 @@ | |||
ParamResolverOrSimilarType, """Something that can be used to turn parameters into values.""" | |||
) | |||
|
|||
# Used to mark values that are not found in a dict. | |||
_NotFound = object() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't global variables be named like _NOT_FOUND
according to the style guide? Same with RecursionFlag below if you want to clean it up.
go/pystyle#guidelines-derived-from-guidos-recommendations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to _NOT_FOUND
and _RECURSION_FLAG
.
This cleans up the logic in
ParamResolver.value_of
and makes a few optimizations:ParamResolver
, refer toself._param_dict
instead of theself.param_dict
property.ParamResolver.value_of
, handle the cases of resolving a string or symbol in a more straightforward way by checking for their presence in the param dict (when resolving a string we check for both a str and symbol, and similarly when resolving a symbol). If the str or symbol is not found, we can immediately return the symbol without calling through to sympy substitution logic, which was being invoked previously. If the value is found, we check whether it is a "pass through" resolvable value and whether it needs to be resolved recursively and then return.These changes ensure that sympy substitution logic is not invoked unless trying to resolve a complex expression, which fixes
test_value_of_substituted_types
so that it doesn't actually invoke substitution on basic symbols (the test was claiming to assert this while actually asserting the opposite), and also greatly improves performance when resolving missing symbols. On master, getting the value of a missing symbol invokes sympy and is quite slow:while on this branch the same operation is more than 10x faster:
Note that I also slightly changed the behavior of resolving a type with an explicit
_resolved_value_
method that returnsNotImplemented
. Previously this would raise an obscure sympy error, but here I've changed things so that the behavior of such an "explicitly not implemented" method is the same as if the method were omitted entirely (actually not implemented). In the latter case we were previously returning the original symbol, and now we do that for both cases. I think it's debatable whether we should in fact return the symbol in such cases, but this at least makes things more consistent.