Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symbolical differentiation using TensorFlow #19

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
2de2807
Merge branch 'backends' into jax_custom_grad
MatteoRobbiati May 6, 2024
2a5155d
wip on jax custom diff rules
MatteoRobbiati May 6, 2024
e3c6a79
Add jax diff rule
MatteoRobbiati May 6, 2024
f8af7fa
frontend -> qibo_backend
MatteoRobbiati May 6, 2024
db9f20c
fix: cleaning jax tracing
MatteoRobbiati May 7, 2024
ee7a82d
making jax great again
MatteoRobbiati May 7, 2024
b92a335
tutorials: jax training procedure vs TF
MatteoRobbiati May 7, 2024
6613d34
refactor: passing the backend as argument
MatteoRobbiati May 8, 2024
bf75105
fix: default numpy
MatteoRobbiati May 13, 2024
22b89e8
feat: draft symbolical differentiation using tf
MatteoRobbiati May 13, 2024
cddafb4
refactor: gradients are computed all together in symbolical and param…
MatteoRobbiati May 15, 2024
cffb935
refactor: adapt expectation to new differentiation rules
MatteoRobbiati May 15, 2024
afbdcb3
temporary: add a couple of scripts to test the diff mechanism
MatteoRobbiati May 15, 2024
c85dcfe
chore: update dependencies
MatteoRobbiati May 15, 2024
44d0a1a
refactor: moving codecov instructions
MatteoRobbiati May 15, 2024
e3f2087
refactor: renaming gradients into gradient, removing defaults
MatteoRobbiati May 15, 2024
80d5832
fix: remove scale factors and do not expose one_parameter_shift
MatteoRobbiati May 15, 2024
8ccfe57
refactor: remove defaults from _one_parameter_shift arguments
MatteoRobbiati May 15, 2024
1fc6e2d
fix: conflicts after merging main
MatteoRobbiati May 15, 2024
26f6d8e
Apply suggestions from code review
MatteoRobbiati May 17, 2024
a582cd3
fix: resolve conflicts with main
MatteoRobbiati May 17, 2024
b97edae
Merge branch 'sym' of github.com:qiboteam/qiboml into sym
MatteoRobbiati May 17, 2024
e9c0a6c
feat: symbolical differentiation with jax
MatteoRobbiati May 17, 2024
5114182
fix: resolve conflicts with main
MatteoRobbiati May 17, 2024
a724b05
fix: solving conflicts with other jax work branch
MatteoRobbiati May 17, 2024
720cdb6
feat: add jax symbolical gradients and fix all to all execution
MatteoRobbiati May 17, 2024
d23da01
test: playing with a test script
MatteoRobbiati May 17, 2024
933bf36
chore: move qibojit to test deps
MatteoRobbiati May 17, 2024
b8f59bf
fix: remove personal test scripts
MatteoRobbiati May 17, 2024
13a29ee
fix: diffrule = None in expectation
MatteoRobbiati Jun 14, 2024
9fe21d5
temporary: test script for triads
MatteoRobbiati Jun 14, 2024
d09d1c2
chore: rm some deps
MatteoRobbiati Jun 14, 2024
401f603
working on jax automatic differentiation
MatteoRobbiati Jun 14, 2024
a524105
refactor: removing _with_jax
MatteoRobbiati Jun 17, 2024
db6df70
feat: expectation _with_torch
MatteoRobbiati Jun 18, 2024
6f3666d
refactor: expectation as function of parameters
MatteoRobbiati Jun 18, 2024
7f243c3
fix: args in _With_torch backward method
MatteoRobbiati Jun 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/.codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Reference docs at:
# https://docs.codecov.io/docs/codecovyml-reference

codecov:
require_ci_to_pass: yes
coverage:
precision: 2
round: down
range: "70...100"
status:
project:
default:
threshold: 1%

parsers:
gcov:
branch_detection:
conditional: yes
loop: yes
method: no
macro: no

comment:
layout: "reach,diff,flags,tree"
behavior: default
require_changes: no

github_checks:
annotations: false
463 changes: 200 additions & 263 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@ packages = [{ include = "qiboml", from = "src" }]
[tool.poetry.dependencies]
python = ">=3.9,<3.12"
numpy = "^1.26.4"
numba = "^0.59.0"
tensorflow = { version = "^2.16.1", markers = "sys_platform == 'linux' or sys_platform == 'darwin'" }
# TODO: the marker is a temporary solution due to the lack of the tensorflow-io 0.32.0's wheels for Windows, this package is one of
# the tensorflow requirements
torch = "^2.2.0"
jax = "^0.4.25"
qibo = "^0.2.6"
qibo = "^0.2.8"

[tool.poetry.group.dev]
optional = true
Expand All @@ -35,6 +34,7 @@ pytest-cov = "^3.0.0"
pytest-env = "^0.8.1"
pytest-benchmark = { version = "^4.0.0", extras = ["histogram"] }


[tool.poe.tasks]
bench = "pytest benchmarks/"
test = "pytest"
Expand Down
4 changes: 3 additions & 1 deletion src/qiboml/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from qiboml.backends.tensorflow import TensorflowBackend
from qiboml.backends.jax import JaxBackend
from qiboml.backends.pytorch import PyTorchBackend
from qiboml.backends.tensorflow import TensorflowBackend
43 changes: 40 additions & 3 deletions src/qiboml/backends/jax.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,50 @@
from qibo import __version__
from qibo.backends import einsum_utils
from qibo.backends.npmatrices import NumpyMatrices
from qibo.backends.numpy import NumpyBackend
from qibo.config import raise_error


class JaxMatrices(NumpyMatrices):
def __init__(self, dtype):
super().__init__(dtype)
import jax # pylint: disable=import-error
import jax.numpy as jnp # pylint: disable=import-error

self.jax = jax
self.np = jnp

def _cast(self, x, dtype):
return self.np.array(x, dtype=dtype)

def Unitary(self, u):
return self._cast(u, dtype=self.dtype)


class JaxBackend(NumpyBackend):
def __init__(self):
super().__init__()
self.name = "jax"

import jax
import jax.numpy as jnp # pylint: disable=import-error
import numpy
import numpy as np

jax.config.update("jax_enable_x64", True)

self.jax = jax
self.numpy = numpy
self.numpy = np

self.np = jnp
self.tensor_types = (jnp.ndarray, numpy.ndarray)

self.versions = {
"qibo": __version__,
"numpy": np.__version__,
"tensorflow": jax.__version__,
}

self.matrices = JaxMatrices(self.dtype)
self.tensor_types = (jnp.ndarray, np.ndarray)

def set_precision(self, precision):
if precision != self.precision:
Expand Down Expand Up @@ -72,6 +97,18 @@ def plus_density_matrix(self, nqubits):
state /= 2**nqubits
return state

def matrix(self, gate):
npmatrix = super().matrix(gate)
return self.np.array(npmatrix, dtype=self.dtype)

def matrix_parametrized(self, gate):
npmatrix = super().matrix_parametrized(gate)
return self.np.array(npmatrix, dtype=self.dtype)

def matrix_fused(self, gate):
npmatrix = super().matrix_fused(gate)
return self.np.array(npmatrix, dtype=self.dtype)

def update_frequencies(self, frequencies, probabilities, nsamples):
samples = self.sample_shots(probabilities, nsamples)
res, counts = self.np.unique(samples, return_counts=True)
Expand Down
4 changes: 4 additions & 0 deletions src/qiboml/backends/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ def __init__(self):
super().__init__()
self.name = "tensorflow"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(TF_LOG_LEVEL)

import tensorflow as tf # pylint: disable=import-error
import tensorflow.experimental.numpy as tnp # pylint: disable=import-error

if TF_LOG_LEVEL >= 2:
tf.get_logger().setLevel("ERROR")

tnp.experimental_enable_numpy_behavior()
self.tf = tf
self.np = tnp
Expand Down
Loading