
Symbolical differentiation using TensorFlow #19

Open · MatteoRobbiati wants to merge 37 commits into main

Conversation

@MatteoRobbiati (Contributor) commented May 13, 2024

Add symbolical differentiation to the operations.differentiation collection.

@MatteoRobbiati marked this pull request as ready for review May 15, 2024 12:16
@alecandido (Member) left a comment:

Partial review

src/qiboml/operations/differentiation.py (4 outdated review threads, resolved)
@alecandido (Member) left a comment:

Complete review

src/qiboml/operations/differentiation.py (outdated review thread, resolved)
@MatteoRobbiati requested a review from alecandido May 15, 2024 13:19
pyproject.toml (outdated review thread, resolved)
src/qiboml/operations/differentiation.py (outdated review thread, resolved)
with respect to circuit's variational parameters.

"""
import tensorflow as tf # pylint: disable=import-error
Contributor:

TensorFlow seems to be a pretty much mandatory dependency at the current stage, so I would import it at the top level in any case.

@MatteoRobbiati (PR author):

With the new push, I think this import should be fine.

Member:

Sorry, I'm not sure why, also because I'm not sure which push you're referring to.

    return gradient


def symbolical(
Contributor:

I might be missing something here, but this seems more like a combination of forward + backward, rather than purely the calculation of the gradient with respect to the model's parameters. What I mean is that, to compute the gradient, here you are re-executing the circuit, which is supposed to have happened in the forward pass already, exactly as in the PSR (but there you have no other choice). I would expect the expectation function to take care of the 'taping' and this one to just take as input what is obtained there.
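For concreteness, a minimal sketch of that split (not the PR's code; it assumes a qibo-like hamiltonian/circuit API and the helper names are hypothetical): the expectation step executes the circuit once under a GradientTape, and the differentiation step only consumes what was recorded.

import tensorflow as tf

def expectation_with_tape(hamiltonian, circuit, parameters):
    # Hypothetical forward pass: execute the circuit once, recording on a tape.
    params = tf.Variable(parameters, dtype=tf.float64)
    with tf.GradientTape() as tape:
        circuit.set_parameters(params)
        value = hamiltonian.expectation(circuit.execute().state())
    return value, params, tape

def gradient_from_tape(value, params, tape):
    # Backward pass: no circuit re-execution, just read the gradient off the tape.
    return tape.gradient(value, params)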

@alecandido (Member) commented May 16, 2024:

I had a look at the tf.autodiff module, but it seems that the only thing other than GradientTape is a ForwardAccumulator (so, unexpectedly, TensorFlow can also differentiate in forward mode, though it does not advertise it that much...).

However, it seems that the only way to trace the graph to compute the gradient is to evaluate the function with a not-so-dummy value (i.e. an actual value), which is used at the same time as a placeholder and as the point at which the actual function is computed.

I haven't found a way to decouple the two things.

The only thing we could save is the circuit.get_parameters(), which could all be replaced with 0s. But we need to know how many of them there are... (and possibly the whole layout, since it's using a list of tuples...)
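As a quick illustration of the two tf.autodiff entry points (f and x are placeholders, not qiboml code): both modes require evaluating the function at a concrete value, which is exactly the coupling described above.

import tensorflow as tf

x = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64)
f = lambda v: tf.reduce_sum(tf.sin(v))  # stand-in for the expectation value

# Reverse mode: record the forward evaluation, then pull the gradient back.
with tf.GradientTape() as tape:
    y = f(x)
gradient = tape.gradient(y, x)

# Forward mode: push a tangent through the same forward evaluation.
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.ones_like(x)) as acc:
    y = f(x)
jvp = acc.jvp(y)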

Member:

> The only thing we could save is the circuit.get_parameters(), which could all be replaced with 0s. But we need to know how many of them there are... (and possibly the whole layout, since it's using a list of tuples...)

I hope I'll be able to improve on this in qibo-core, qiboteam/qibo-core#22.
But I'm just advertising that issue to trigger your feedback; the benefit of not using the actual parameters in this case is mostly conceptual. It would be nice to avoid doing the work twice, but we also want a TensorFlow-independent way to evaluate the function in expectation.
So, for the time being, we are paying the price of evaluating twice...

Maybe, in this case, JAX/autograd is truly superior. Though it depends on the evaluation overhead, of course.

@MatteoRobbiati (PR author):

I just added the JAX version of this function. Even if it is ideally better, right now TF is faster.

Member:

It's quite possible that, even ideally, the speedup is rather limited. Evaluating the function could be cheaper than evaluating the gradient (because the gradient has an extra dimension over the differentiation variables).

Still, it's pretty unexpected that one is slower than the other...

Are you comparing TF-derivation-with-TF-backend, i.e. (TF, sym, TF), with (JAX, sym, JAX)?
It could be that the slowdown is just the same as the JAX backend execution, and nothing new.
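If useful, a rough way to make that comparison explicit (a sketch only: the argument list is guessed from the diff above, and the backend objects are placeholders):

import time

def average_runtime(fn, *args, repeats=10):
    # Average over repeated calls, after one warm-up call to exclude tracing/compilation.
    fn(*args)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats

# (TF, sym, TF):   average_runtime(symbolical, hamiltonian, circuit, initial_state, tf_backend)
# (JAX, sym, JAX): average_runtime(symbolical_with_jax, hamiltonian, circuit, initial_state, jax_backend)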

Co-authored-by: BrunoLiegiBastonLiegi <45011234+BrunoLiegiBastonLiegi@users.noreply.github.com>
codecov bot commented May 17, 2024

Codecov Report

Attention: Patch coverage is 0% with 31 lines in your changes missing coverage. Please review.

Project coverage is 15.66%. Comparing base (a9e7f9e) to head (7f243c3).

Additional details and impacted files


@@            Coverage Diff             @@
##             main      #19      +/-   ##
==========================================
- Coverage   16.62%   15.66%   -0.97%     
==========================================
  Files           8        8              
  Lines         439      466      +27     
==========================================
  Hits           73       73              
- Misses        366      393      +27     
Flag        Coverage        Δ
unittests   15.66% <0.00%>  (-0.97%) ⬇️

Flags with carried forward coverage won't be shown.

Files                                Coverage        Δ
src/qiboml/backends/tensorflow.py    0.00% <0.00%>   (ø)
src/qiboml/backends/__init__.py      0.00% <0.00%>   (ø)
src/qiboml/backends/jax.py           0.00% <0.00%>   (ø)

Comment on lines +168 to +171
    def _expectation(params):
        params = jax.numpy.array(params)
        circuit.set_parameters(params)
        return expectation._exact(hamiltonian, circuit, initial_state, exec_backend)
Member:

This is the function I was telling you about yesterday: if this were already available, you wouldn't have had to write it here, but could just pass this function as input, and then this whole function would just be jax.grad(), plus the manipulations you need to make it compatible.

I.e. all the inputs of symbolical_with_jax are used in a single place, so the only input you need is the result of the operation in that place.
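A sketch of the reduction being suggested (the name symbolic_gradient is hypothetical; the callable would be built wherever the forward pass lives, so no circuit re-execution happens here):

import jax

def symbolic_gradient(expectation_callable, parameters):
    # expectation_callable: parameters -> expectation value, provided by the
    # expectation module; this routine is then just jax.grad plus array casting.
    params = jax.numpy.array(parameters)
    return jax.grad(expectation_callable)(params)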

        return expectation._exact(hamiltonian, circuit, initial_state, exec_backend)

    return jax.numpy.array(
        [g[0].item() for g in jax.grad(_expectation)(circuit.get_parameters())]
Member:

Why do you need all this gymnastics with [0].item()? If you know

@MatteoRobbiati (PR author):

It's the way I managed to make it work compatibly with the expectation interface. Maybe an easier solution can be found.

@alecandido (Member) commented Jun 3, 2024:

But why is jax.grad(_expectation)(circuit.get_parameters()) incompatible?
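A guess, based on the earlier remark that circuit.get_parameters() returns a list of tuples: jax.grad mirrors the structure of its input, so differentiating with respect to a list of one-element tuples gives back a list of one-element tuples, hence the [0].item() unpacking. Flattening the parameters first would avoid it; a sketch (assuming one value per tuple, and reusing the names from the diff above):

import jax
import jax.numpy as jnp

def gradient_flat(hamiltonian, circuit, initial_state, exec_backend):
    # Flatten the list-of-tuples layout into a plain array of parameters.
    flat = jnp.array([p[0] for p in circuit.get_parameters()])

    def _expectation(flat_params):
        # Rebuild the layout expected by circuit.set_parameters.
        circuit.set_parameters([(p,) for p in flat_params])
        # NOTE: `expectation` is the module already used in the diff above.
        return expectation._exact(hamiltonian, circuit, initial_state, exec_backend)

    # jax.grad now returns a plain array, no per-element unpacking needed.
    return jax.grad(_expectation)(flat)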

pyproject.toml (outdated review thread, resolved)