Symbolical differentiation using TensorFlow #19
base: main
Conversation
Partial review
Complete review
    with respect to circuit's variational parameters.
    """
    import tensorflow as tf  # pylint: disable=import-error
tensorflow seems to be a pretty much mandatory dependency at the current stage, so I would import it at the top level in any case.
With the new push, I think this import should be fine.
Sorry, I'm not sure why, also because I'm not sure which push you're referring to.
    return gradient


def symbolical(
I might be missing something here, but this seems more like a combination of forward + backward, rather than purely the calculation of the gradient with respect to the model's parameters. What I mean is that, to compute the gradient, here you are re-executing the circuit, which is supposed to have happened in the forward pass already, exactly as in the PSR (but there you have no other choice). I would expect the expectation function to take care of the 'taping', and this one to just take as input what was obtained there.
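A minimal sketch of the split I have in mind (purely illustrative: the names, the signature, and the use of expectation._exact are my own assumptions, and it only works if the execution backend is TensorFlow, so the tape can actually trace the circuit execution):

```python
import tensorflow as tf

def expectation_with_tape(hamiltonian, circuit, initial_state, exec_backend):
    # hypothetical variant of `expectation`: it runs the forward pass once,
    # recording it on a tape, and returns the tape alongside the value
    params = tf.Variable(circuit.get_parameters())  # assuming they pack into a tensor
    with tf.GradientTape() as tape:
        circuit.set_parameters(params)
        expval = expectation._exact(hamiltonian, circuit, initial_state, exec_backend)
    return expval, tape, params

def symbolical_from_tape(expval, tape, params):
    # no re-execution of the circuit here: the forward pass was already traced
    return tape.gradient(expval, params)
```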
I had a look at the tf.autodiff module, but it seems that the only thing other than GradientTape is a ForwardAccumulator (so, unexpectedly, TensorFlow can also differentiate in forward mode, though it does not advertise it that much...).
However, it seems that the only way to trace the graph for the gradient is to evaluate the function with a not-so-dummy value (i.e. an actual value), which is used at the same time as a placeholder and as the value at which the actual function is computed. I haven't found a way to decouple the two things.
The only thing we could save is the circuit.get_parameters() call, whose values could all be replaced with 0s. But we need to know how many of them there are... (and possibly the whole layout, since it's using a list of tuples...)
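For reference, a toy sketch of the two entry points on a plain function (not the circuit code), just to show that both modes trace the graph by evaluating at a concrete value:

```python
import tensorflow as tf

x = tf.Variable([0.1, 0.2, 0.3])

# reverse mode: the graph is traced by actually evaluating the function
with tf.GradientTape() as tape:
    y = tf.reduce_sum(tf.sin(x))
grad = tape.gradient(y, x)  # full gradient, shape (3,)

# forward mode: also needs a concrete evaluation, and yields one directional
# derivative (here along the first variable) per pass
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant([1.0, 0.0, 0.0])) as acc:
    y = tf.reduce_sum(tf.sin(x))
jvp = acc.jvp(y)
```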
> The only thing we could save is the circuit.get_parameters() call, whose values could all be replaced with 0s. But we need to know how many of them there are... (and possibly the whole layout, since it's using a list of tuples...)

I hope I'll be able to improve on this in qibo-core, qiboteam/qibo-core#22. But I'm just advertising that issue to trigger your feedback; the benefit of not using the actual parameters here is mostly conceptual. It would be nice to avoid doing the work twice, but we also want a TensorFlow-independent way to evaluate the function in expectation.
So, for the time being, we are paying the price of evaluating twice...
Maybe, in this case, JAX/autograd is truly superior. Though it depends on the evaluation overhead, of course.
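Just to make the idea concrete, a hypothetical sketch (the per-gate layout is not exposed today without reading the values; getting at it separately is exactly what I'd hope for from qibo-core):

```python
# hypothetical: suppose the number of parameters of each parametrized gate
# were known without calling circuit.get_parameters()
layout = [1, 2, 1]  # stand-in values, one entry per parametrized gate
placeholder = [tuple(0.0 for _ in range(n)) for n in layout]
# circuit.set_parameters(placeholder)  # trace the graph with dummy values
```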
I just added the JAX version of this function. Even though it should ideally be better, right now TF is faster.
It's quite possible that, even ideally, the speedup is rather limited. Evaluating the function could be cheaper than evaluating the gradient (because the gradient has an extra dimension over the derivation variables). Still, it's pretty unexpected that one is slower than the other...
Are you comparing TF-derivation-with-TF-backend, i.e. (TF, sym, TF), with (JAX, sym, JAX)? It could be that the slowdown is just the same as the JAX backend execution, and nothing new.
Co-authored-by: BrunoLiegiBastonLiegi <45011234+BrunoLiegiBastonLiegi@users.noreply.github.com>
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main      #19      +/-   ##
==========================================
- Coverage   16.62%   15.66%   -0.97%
==========================================
  Files           8        8
  Lines         439      466      +27
==========================================
  Hits           73       73
- Misses        366      393      +27

Flags with carried forward coverage won't be shown.
    def _expectation(params):
        params = jax.numpy.array(params)
        circuit.set_parameters(params)
        return expectation._exact(hamiltonian, circuit, initial_state, exec_backend)
This is the function I was telling you about yesterday: if it was already available, you wouldn't have had to write it here, but could just pass it as input; then this whole function would be just jax.grad(), plus the manipulations you need to make it compatible.
I.e. all the inputs of symbolical_with_jax are used in a single place, so the only input you need is the result of the operation in that place.
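Roughly what I mean, as a sketch (the name and signature here are illustrative, not a concrete proposal):

```python
import jax
import jax.numpy as jnp

def symbolical_with_jax(expectation_fn, params):
    # expectation_fn: params -> scalar expectation value, assumed to be
    # JAX-traceable and built once wherever the forward pass lives
    return jnp.asarray(jax.grad(expectation_fn)(jnp.asarray(params)))
```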
        return expectation._exact(hamiltonian, circuit, initial_state, exec_backend)

    return jax.numpy.array(
        [g[0].item() for g in jax.grad(_expectation)(circuit.get_parameters())]
Why do you need all this gymnastics with [0].item()? If you know
It's the way I managed to make it compatible with the expectation interface. Maybe a simpler solution can be found.
But why is jax.grad(_expectation)(circuit.get_parameters()) incompatible?
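If it's only about the output structure, here is a toy example of what I suspect is going on (not the PR's code): jax.grad returns gradients with the same pytree structure as its input, so a list of one-element tuples comes back as a list of one-element tuples.

```python
import jax
import jax.numpy as jnp

def f(params):
    # params mimics circuit.get_parameters(): a list of one-element tuples
    return sum(jnp.sin(p[0]) for p in params)

grads = jax.grad(f)([(0.1,), (0.2,), (0.3,)])
# grads has the same nesting as the input: [(Array,), (Array,), (Array,)],
# which would explain the g[0].item() unpacking above
flat = jnp.array([g[0].item() for g in grads])
```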
Add symbolical differentiation to the operations.differentiation collection.