Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize ParamResolver.value_of #6341

Merged
merged 9 commits into from
Nov 15, 2023
Merged

Optimize ParamResolver.value_of #6341

merged 9 commits into from
Nov 15, 2023

Conversation

maffoo
Copy link
Contributor

@maffoo maffoo commented Nov 9, 2023

This cleans up the logic in ParamResolver.value_of and makes a few optimizations:

  • Throughout ParamResolver, refer to self._param_dict instead of the self.param_dict property.
  • In ParamResolver.value_of, handle the cases of resolving a string or symbol in a more straightforward way by checking for their presence in the param dict (when resolving a string we check for both a str and symbol, and similarly when resolving a symbol). If the str or symbol is not found, we can immediately return the symbol without calling through to sympy substitution logic, which was being invoked previously. If the value is found, we check whether it is a "pass through" resolvable value and whether it needs to be resolved recursively and then return.

These changes ensure that sympy substitution logic is not invoked unless trying to resolve a complex expression, which fixes test_value_of_substituted_types so that it doesn't actually invoke substitution on basic symbols (the test was claiming to assert this while actually asserting the opposite), and also greatly improves performance when resolving missing symbols. On master, getting the value of a missing symbol invokes sympy and is quite slow:

In [1]: import cirq, sympy
In [2]: %time cirq.ParamResolver({"a": 1, "b": 2}).value_of(sympy.Symbol("foo"))
CPU times: user 843 µs, sys: 691 µs, total: 1.53 ms
Wall time: 1.55 ms

while on this branch the same operation is more than 10x faster:

In [1]: import cirq, sympy
In [2]: %time cirq.ParamResolver({"a": 1, "b": 2}).value_of(sympy.Symbol("foo"))
CPU times: user 69 µs, sys: 56 µs, total: 125 µs
Wall time: 130 µs

Note that I also slightly changed the behavior of resolving a type with an explicit _resolved_value_ method that returns NotImplemented. Previously this would raise an obscure sympy error, but here I've changed things so that the behavior of such an "explicitly not implemented" method is the same as if the method were omitted entirely (actually not implemented). In the latter case we were previously returning the original symbol, and now we do that for both cases. I think it's debatable whether we should in fact return the symbol in such cases, but this at least makes things more consistent.

@maffoo maffoo requested review from vtomole, cduck and a team as code owners November 9, 2023 05:03
@maffoo maffoo requested a review from pavoljuhas November 9, 2023 05:03
param_value = self._param_dict.get(symbol, _NotFound)
if param_value is _NotFound:
# Symbol or string that is not in the param_dict cannot be resolved futher; return as symbol.
return symbol
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why we decided to always return a symbol here even if resolving a string. My preference would be to return the given value unchanged if the resolver doesn't need to change it, but callers could rely on this behavior so we should be careful if we want to change it.

@@ -100,7 +100,7 @@ def test_sampler_multiple_jobs():
results = sampler.sample(
program=circuit,
repetitions=4,
params=[cirq.ParamResolver({x: '0.5'}), cirq.ParamResolver({x: '0.6'})],
params=[cirq.ParamResolver({x: 0.5}), cirq.ParamResolver({x: 0.6})],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was coincidentally working before because the value_of code was calling float on the result after looking up in the param dict. But I think that behavior was unintentional because otherwise there would be no way to distinguish between strings used as symbols and strings containing floats. I can't think of any other place where we allow floats to be specified as strings, outside of serialization, and it certainly wasn't the intent to do so for parameter resolution.

@andbe91
Copy link
Collaborator

andbe91 commented Nov 14, 2023

Any chance of getting this into master this week? Would be great for accelerating some internal efforts :)

Copy link

codecov bot commented Nov 15, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (0e288a7) 97.84% compared to head (68631ed) 97.84%.
Report is 1 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #6341      +/-   ##
==========================================
- Coverage   97.84%   97.84%   -0.01%     
==========================================
  Files        1110     1110              
  Lines       96597    96656      +59     
==========================================
+ Hits        94516    94572      +56     
- Misses       2081     2084       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -36,6 +36,9 @@
ParamResolverOrSimilarType, """Something that can be used to turn parameters into values."""
)

# Used to mark values that are not found in a dict.
_NotFound = object()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't global variables be named like _NOT_FOUND according to the style guide? Same with RecursionFlag below if you want to clean it up.
go/pystyle#guidelines-derived-from-guidos-recommendations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to _NOT_FOUND and _RECURSION_FLAG.

@maffoo maffoo enabled auto-merge (squash) November 15, 2023 03:33
@maffoo maffoo merged commit 8d07cab into master Nov 15, 2023
35 checks passed
@maffoo maffoo deleted the u/maffoo/resolver branch November 15, 2023 04:00
harry-phasecraft pushed a commit to PhaseCraft/Cirq that referenced this pull request Oct 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants