Re-enable forces for jax. Fixed a little typo.
PatReis committed Feb 24, 2024
1 parent 322c37e commit c54858f
Showing 1 changed file with 21 additions and 20 deletions.
kgcnn/models/force.py: 21 additions, 20 deletions
@@ -1,6 +1,5 @@
 import keras as ks
 import keras.saving
-from functools import partial
 from typing import Union
 from kgcnn.models.utils import get_model_class
 from keras.saving import deserialize_keras_object, serialize_keras_object
@@ -13,9 +12,10 @@
     import tensorflow as tf
 elif backend() == "torch":
     import torch
-# elif backend() == "jax":
-#     import jax.numpy as jnp
-#     import jax
+elif backend() == "jax":
+    import jax.numpy as jnp
+    import jax
+    from functools import partial
 else:
     raise NotImplementedError("Backend '%s' not supported for force model." % backend())
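Note that the re-enabled branch above is only reached when Keras 3 actually runs on the jax backend, which has to be selected before keras is first imported. A minimal sketch of that setup (not part of the diff; the environment-variable mechanism is standard Keras 3):

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before the first `import keras`

import keras as ks
from keras.backend import backend

print(backend())  # -> "jax", so the elif branch above imports jax and jax.numpy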

@@ -135,7 +135,7 @@ def _call_grad_tf(self, inputs, training=False, **kwargs):
             else:
                 x, splits = x_in, None
             tape.watch(x)
-            eng = self.energy_model(inputs, training=training, **kwargs)
+            eng = self.energy_model.call(inputs, training=training, **kwargs)
             eng_sum = tf.reduce_sum(eng, axis=0, keepdims=False)
             e_grad = [eng_sum[i] for i in range(eng_sum.shape[-1])]
             e_grad = [tf.expand_dims(tape.gradient(e_i, x), axis=-1) for e_i in e_grad]
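For context, the per-state gradient pattern in `_call_grad_tf` can be reproduced in isolation. Below is a self-contained sketch, with a hypothetical `toy_energy` standing in for `self.energy_model.call`: each state's batch-summed energy is differentiated separately on a persistent tape, and the gradients are stacked on a trailing state axis (illustrative only, not kgcnn's exact code):

import tensorflow as tf

x = tf.constant([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])  # (n_atoms, 3) coordinates

def toy_energy(coords):
    # Hypothetical stand-in for the energy model: returns (batch=1, states=2).
    r2 = tf.reduce_sum(coords ** 2, axis=-1)  # (n_atoms,)
    e = tf.reduce_sum(r2, keepdims=True)      # (1,)
    return tf.stack([e, 2.0 * e], axis=-1)    # (1, 2)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    eng = toy_energy(x)                                    # (batch, states)
    eng_sum = tf.reduce_sum(eng, axis=0, keepdims=False)   # (states,)
    e_states = [eng_sum[i] for i in range(eng_sum.shape[-1])]  # slice inside the tape

# One backward pass per energy state; gradients stacked on a trailing state axis.
e_grad = tf.concat([tf.expand_dims(tape.gradient(e_i, x), axis=-1)
                    for e_i in e_states], axis=-1)         # (n_atoms, 3, states)
forces = -e_grad  # forces are the negative energy gradient
del tape          # release the persistent tape

The tape must be persistent because `tape.gradient` is called once per state, and the slicing has to happen inside the recording context so each per-state scalar stays connected to `x`.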
@@ -164,21 +164,22 @@ def _call_grad_torch(self, inputs, training=False, **kwargs):
             e_grad = torch.squeeze(e_grad, dim=-1)
         return eng, e_grad

-    # def _call_grad_jax(self, inputs, training=False, **kwargs):
-    #
-    #     def energy_reduce(*inputs, pos: int = 0):
-    #         eng = self.energy_model(inputs, training=training, **kwargs)
-    #         eng_sum = jnp.sum(eng, axis=0)[pos]
-    #         return eng_sum
-    #
-    #     grad_fn = jax.grad(energy_reduce, argnums=self.coordinate_input)
-    #     all_grad = [grad_fn(*inputs, pos=i) for i in range(self._expected_energy_states)]
-    #     eng = self.energy_model(inputs, training=training, **kwargs)
-    #     e_grad = jnp.concatenate([jnp.expand_dims(x[1], axis=-1) for x in all_grad], axis=-1)
-    #
-    #     if self.output_squeeze_states:
-    #         e_grad = jnp.squeeze(e_grad, axis=-1)
-    #     return eng, e_grad
+    def _call_grad_jax(self, inputs, training=False, **kwargs):
+
+        @partial(jax.jit, static_argnames=['pos'])
+        def energy_reduce(*inputs, pos: int = 0):
+            eng_temp = self.energy_model.call(inputs, training=training, **kwargs)
+            eng_sum = jnp.sum(eng_temp, axis=0)[pos]
+            return eng_sum
+
+        grad_fn = jax.grad(energy_reduce, argnums=self.coordinate_input)
+        all_grad = [grad_fn(*inputs, pos=i) for i in range(self._expected_energy_states)]
+        eng = self.energy_model.call(inputs, training=training, **kwargs)
+        e_grad = jnp.concatenate([jnp.expand_dims(x, axis=-1) for x in all_grad], axis=-1)
+
+        if self.output_squeeze_states:
+            e_grad = jnp.squeeze(e_grad, axis=-1)
+        return eng, e_grad

     def call(self, inputs, training=False, **kwargs):

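The re-enabled `_call_grad_jax` follows the usual JAX recipe: wrap the energy model in a scalar-valued objective (the batch-summed energy of one state), differentiate it with `jax.grad` with respect to the coordinate input, and evaluate once per energy state. A self-contained sketch of the same pattern, with a hypothetical `toy_energy` in place of the kgcnn energy model and `n_states` in place of `self._expected_energy_states`:

import jax
import jax.numpy as jnp
from functools import partial

def toy_energy(z, coords):
    # Hypothetical stand-in for the energy model: (batch, states) energies.
    r2 = jnp.sum(coords ** 2, axis=-1)             # (batch, n_atoms)
    e = jnp.sum(r2, axis=-1, keepdims=True)        # (batch, 1)
    return jnp.concatenate([e, 2.0 * e], axis=-1)  # (batch, 2)

@partial(jax.jit, static_argnames=["pos"])
def energy_reduce(z, coords, pos: int = 0):
    # Scalar objective: batch-summed energy of state `pos`, as in the diff above.
    return jnp.sum(toy_energy(z, coords), axis=0)[pos]

z = jnp.zeros((1, 4))                                     # dummy non-coordinate input
coords = jnp.array([[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]])  # (batch, n_atoms, 3)

grad_fn = jax.grad(energy_reduce, argnums=1)  # differentiate w.r.t. the coordinate input
n_states = 2                                  # role of self._expected_energy_states
e_grad = jnp.concatenate(
    [jnp.expand_dims(grad_fn(z, coords, pos=i), axis=-1) for i in range(n_states)],
    axis=-1)                                  # (batch, n_atoms, 3, states)
forces = -e_grad  # forces are the negative energy gradient

Marking `pos` as static in `jax.jit` keeps the state index a Python constant, so each state gets its own compiled specialization rather than retracing on every call.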
