Commit bdd803d

fix: update

BrunoLiegiBastonLiegi committed Sep 10, 2024
1 parent 6bd798b commit bdd803d
Showing 4 changed files with 55 additions and 10 deletions.
12 changes: 11 additions & 1 deletion src/qiboml/models/abstract.py
@@ -32,6 +32,12 @@ def forward(self, x): # pragma: no cover
     def __call__(self, x):
         return self.forward(x)
 
+    @property
+    def has_parameters(self):
+        if len(self.parameters) > 0:
+            return True
+        return False
+
     @property
     def parameters(self) -> ndarray:
         return self.backend.cast(self.circuit.get_parameters(), self.backend.np.float64)

@@ -47,7 +53,11 @@ def circuit(self) -> Circuit:
         return self._circuit
 
 
-def _run_layers(x: ndarray, layers: list[QuantumCircuitLayer]):
+def _run_layers(x: ndarray, layers: list[QuantumCircuitLayer], parameters):
+    index = 0
     for layer in layers:
+        if layer.has_parameters:
+            layer.parameters = parameters[index]
+            index += 1
         x = layer.forward(x)
     return x
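For reference, a minimal sketch of how the updated _run_layers signature threads one parameter group per parameterized layer. The toy EncodingLayer/TrainableLayer classes below are illustrative assumptions, not qiboml code; the real layers use a has_parameters property on QuantumCircuitLayer.

import numpy as np

class EncodingLayer:
    # hypothetical stand-in: no trainable parameters
    has_parameters = False

    def forward(self, x):
        return x

class TrainableLayer:
    # hypothetical stand-in: holds a trainable parameter vector
    has_parameters = True

    def __init__(self, parameters):
        self.parameters = np.asarray(parameters)

    def forward(self, x):
        return x + self.parameters.sum()

def run_layers(x, layers, parameters):
    # mirrors the updated _run_layers: consume one parameter group
    # per layer that reports has_parameters, then run its forward
    index = 0
    for layer in layers:
        if layer.has_parameters:
            layer.parameters = parameters[index]
            index += 1
        x = layer.forward(x)
    return x

layers = [EncodingLayer(), TrainableLayer([0.1, 0.2]), TrainableLayer([0.3])]
out = run_layers(np.zeros(2), layers, [np.array([1.0, 2.0]), np.array([3.0])])
print(out)  # each parameterized layer was re-parameterized before its forward pass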
12 changes: 7 additions & 5 deletions src/qiboml/models/pytorch.py
@@ -55,7 +55,7 @@ def forward(self, x: torch.Tensor):
                 x, self.layers, self.backend, self.differentiation, *self.parameters()
             )
         else:
-            x = _run_layers(x, self.layers)
+            x = _run_layers(x, self.layers, self.parameters)
         return x
 
     @property
@@ -82,19 +82,21 @@ def forward(
         differentiation,
         *parameters,
     ):
-        ctx.save_for_backward(x)
+        ctx.save_for_backward(x, *parameters)
         ctx.layers = layers
         ctx.differentiation = differentiation
         ctx.backend = backend
         x_clone = x.clone().detach().numpy()
         x_clone = backend.cast(x_clone, dtype=x_clone.dtype)
-        x_clone = torch.as_tensor(np.array(_run_layers(x_clone, layers)))
+        x_clone = torch.as_tensor(np.array(_run_layers(x_clone, layers, parameters)))
         x_clone.requires_grad = True
         return x_clone
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
-        (x,) = ctx.saved_tensors
+        (
+            x,
+            *parameters,
+        ) = ctx.saved_tensors
         gradients = [
             torch.as_tensor(grad)
             for grad in ctx.differentiation.evaluate(x, ctx.layers)
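As an aside, a self-contained sketch of the ctx.save_for_backward(x, *parameters) / starred-unpacking pattern that this hunk switches to. ScaleByParams is a made-up toy function, not qiboml's QuantumModelAutoGrad; it only shows how variadic parameters saved in forward are recovered and given gradients in backward.

import torch

class ScaleByParams(torch.autograd.Function):
    # toy: y = x * prod(parameters)

    @staticmethod
    def forward(ctx, x, *parameters):
        # stash the input together with every parameter tensor
        ctx.save_for_backward(x, *parameters)
        scale = torch.stack(parameters).prod()
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        (
            x,
            *parameters,
        ) = ctx.saved_tensors
        scale = torch.stack(parameters).prod()
        grad_x = grad_output * scale
        # one gradient entry per forward() input, parameters included
        grad_params = [grad_output.mul(x).sum() * scale / p for p in parameters]
        return (grad_x, *grad_params)

x = torch.randn(3, requires_grad=True)
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
ScaleByParams.apply(x, a, b).sum().backward()
print(x.grad, a.grad, b.grad)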
34 changes: 33 additions & 1 deletion src/qiboml/operations/differentiation.py
@@ -1,3 +1,5 @@
+import jax
+import jax.numpy as jnp
 import numpy as np
 from qibo import parameter
 from qibo.backends import construct_backend
@@ -38,7 +40,7 @@ def _evaluate_for_parameter(self, x, layers, layer, index, parameters_bkup):
         outputs = []
         for shift in self._shift_parameters(layer.parameters, index, self.epsilon):
             layer.parameters = shift
-            outputs.append(_run_layers(x, layers))
+            outputs.append(_run_layers(x, layers, [l.parameters for l in layers]))
         layer.parameters = parameters_bkup
         return (outputs[0] - outputs[1]) * self.scale_factor
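The (outputs[0] - outputs[1]) * self.scale_factor line is the usual parameter-shift combination. Assuming the textbook shift of pi/2 and scale of 1/2 for a Pauli-rotation gate (the actual epsilon and scale_factor used by qiboml are not asserted here), a quick numerical check on f(theta) = cos(theta), the Z expectation after RX(theta) on |0>, looks like:

import numpy as np

def expectation(theta):
    # <0| RX(theta)^dagger Z RX(theta) |0> = cos(theta)
    return np.cos(theta)

theta, shift, scale = 0.7, np.pi / 2, 0.5
psr_gradient = (expectation(theta + shift) - expectation(theta - shift)) * scale
print(psr_gradient, -np.sin(theta))  # the shift rule is exact for this gate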

@@ -51,6 +53,36 @@ def _shift_parameters(parameters: ndarray, index: int, epsilon: float):
         return forward, backward
 
 
+class Jax:
+
+    def __init__(self):
+        self._input = None
+        self._layers = None
+
+    def evaluate(self, x: ndarray, layers: list[QuantumCircuitLayer]):
+        self._input = x
+        self.layers = layers
+        parameters = []
+        for layer in layers:
+            if layer.has_parameters:
+                parameters.extend(layer.parameters.ravel())
+        parameters = jnp.asarray(parameters)
+        breakpoint()
+        return jax.jacfwd(self._run)(parameters)
+
+    def _run(self, parameters):
+        breakpoint()
+        grouped_parameters = []
+        left_index = right_index = 0
+        for layer in self.layers:
+            if layer.has_parameters:
+                right_index += len(layer.parameters)
+                grouped_parameters.append(parameters[left_index:right_index])
+                left_index = right_index
+        breakpoint()
+        return _run_layers(self._input, self.layers, grouped_parameters)
+
+
 def parameter_shift(
     hamiltonian,
     circuit,
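For intuition, here is a self-contained sketch of the idea the new Jax class relies on: flatten the per-layer parameters into one vector, let jax.jacfwd differentiate a function of that vector, and re-split the vector into per-layer groups inside the traced function. The run_model/run_flat functions and layer_sizes below are toy stand-ins, not the qiboml layers.

import jax
import jax.numpy as jnp

# hypothetical per-layer parameter counts (two parameterized layers)
layer_sizes = [2, 3]

def run_model(x, grouped_parameters):
    # toy stand-in for _run_layers: each "layer" mixes its parameters into x
    for group in grouped_parameters:
        x = x * jnp.cos(group).prod() + jnp.sin(group).sum()
    return x

def run_flat(flat_parameters, x):
    # re-split the flat vector into per-layer groups, as Jax._run does
    groups, left = [], 0
    for size in layer_sizes:
        groups.append(flat_parameters[left : left + size])
        left += size
    return run_model(x, groups)

flat = jnp.asarray([0.1, 0.2, 0.3, 0.4, 0.5])
x = jnp.asarray([1.0, -1.0])
jacobian = jax.jacfwd(run_flat)(flat, x)  # d(output) / d(flat parameters)
print(jacobian.shape)  # (2, 5): one row per output entry, one column per parameter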
7 changes: 4 additions & 3 deletions tests/test_backprop.py
@@ -29,13 +29,14 @@
         encoding_layer,
         training_layer,
         decoding_layer,
-    ]
+    ],
+    differentiation="Jax",
 )
 print(list(q_model.parameters()))
 data = torch.randn(1, 5)
-data.requires_grad = True
+# data.requires_grad = True
 out = q_model(data)
 print(out.requires_grad)
 loss = (out - 1.0) ** 2
 print(loss.requires_grad)
 loss.backward()
 print(loss)
