Commit 322c37e

Forces for jax are still not working. Removed it again.
PatReis committed Feb 23, 2024
1 parent 177f133 commit 322c37e
Showing 7 changed files with 842 additions and 53 deletions.
kgcnn/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -10,5 +10,5 @@
 __safe_scatter_max_min_to_zero__ = True

 # Geometry
-__geom_euclidean_norm_add_eps__ = False
+__geom_euclidean_norm_add_eps__ = True  # Set to false for exact sqrt computation for geometric layers.
 __geom_euclidean_norm_no_nan__ = True  # Only used for inverse norm.
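The flipped default above trades exact norms for gradient stability: the derivative of sqrt(x) diverges as x approaches 0, so zero-length difference vectors (self-pairs, padded nodes) can make geometric layers emit NaN gradients unless a small epsilon keeps the sqrt argument positive. A minimal sketch of the trade-off, not kgcnn's actual layer code (the function name and eps value are illustrative):

import numpy as np

def euclidean_norm(x, axis=-1, add_eps=True, eps=1e-12):
    """Euclidean norm with an optional epsilon under the square root.

    With add_eps=True the sqrt argument stays strictly positive, so the
    derivative is finite even at x = 0, at the cost of a tiny bias in the
    exact value of the norm.
    """
    squared = np.sum(np.square(x), axis=axis)
    if add_eps:
        squared = squared + eps
    return np.sqrt(squared)

print(euclidean_norm(np.zeros((2, 3))))                  # ~[1e-06, 1e-06]
print(euclidean_norm(np.zeros((2, 3)), add_eps=False))   # exact: [0., 0.]

With the flag enabled, the norm of a zero vector comes out as sqrt(eps) rather than exactly zero, which is the inexactness the inline comment refers to.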
kgcnn/backend/_jax.py (2 changes: 1 addition & 1 deletion)

@@ -76,7 +76,7 @@ def repeat_static_length(x, repeats, axis=None, total_repeat_length: int = None)


 def decompose_ragged_tensor(x):
-    raise NotImplementedError("Operation supported this backend '%s'." % __name__)
+    raise NotImplementedError("Operation not supported by this backend '%s'." % __name__)


 def norm(x, ord='fro', axis=None, keepdims=False):
kgcnn/backend/_numpy.py (2 changes: 1 addition & 1 deletion)

@@ -1,2 +1,2 @@
 def decompose_ragged_tensor(x):
-    raise NotImplementedError("Operation supported this backend '%s'." % __name__)
+    raise NotImplementedError("Operation not supported by this backend '%s'." % __name__)
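Both backends raise here because neither jax nor numpy has a native ragged tensor type; this commit only fixes the wording of the error message. As a sketch of what the operation would have to produce, here is the decomposition of a TensorFlow ragged tensor (illustrative of the concept only, not kgcnn's TensorFlow backend code):

import tensorflow as tf

ragged = tf.ragged.constant([[1.0, 2.0, 3.0], [4.0]], ragged_rank=1)

# Decomposing separates the flat values from the batch partition info,
# which corresponds to the disjoint representation graph layers use.
values = ragged.values               # [1., 2., 3., 4.]
row_ids = ragged.value_rowids()      # [0, 0, 0, 1] -> graph id per node
row_lengths = ragged.row_lengths()   # [3, 1]       -> nodes per graph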
kgcnn/models/casting.py (17 changes: 13 additions & 4 deletions)

@@ -13,7 +13,11 @@
     The return tensor type is determined by :obj:`output_tensor_type` . Options are:

     graph:
-        Tensor: Graph labels of shape `(batch, F)` .
+        Tensor: Graph labels of either type
+            - padded (Tensor): Batched tensor of graphs of shape `(batch, F)` .
+            - disjoint (Tensor): Tensor with potential disjoint padded graphs of shape `(batch, F)` .
+            - ragged (Tensor): Tensor of shape `(batch, F)` . No ragged dimension is needed here.
+
     nodes:
         Tensor: Node labels for the graph of either type:

@@ -54,9 +58,14 @@ def template_cast_output(model_outputs,

     # Output embedding choice
     if output_embedding == 'graph':
-        # Here we could also modify the behaviour for direct disjoint without removing the padding via
-        # remove_padded_disjoint_from_batched_output
-        out = CastDisjointToBatchedGraphState(**cast_disjoint_kwargs)(out)
+        if output_tensor_type in ["padded", "masked"]:
+            # Here we could also modify the behaviour for direct disjoint without removing the padding via
+            # remove_padded_disjoint_from_batched_output
+            out = CastDisjointToBatchedGraphState(**cast_disjoint_kwargs)(out)
+        elif output_tensor_type in ["ragged", "jagged"]:
+            pass
+        else:
+            pass
     elif output_embedding == 'node':
         if output_tensor_type in ["padded", "masked"]:
             if "static_batched_node_output_shape" in cast_disjoint_kwargs:
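The new branching keeps graph-level outputs consistent with output_tensor_type: only padded/masked outputs need the cast back from the disjoint representation to a batched tensor, while ragged/jagged (and the fall-through default) pass the tensor on unchanged, since graph labels are already (batch, F) in every layout. A hedged illustration of the three layouts the casting templates move between, with made-up toy shapes:

import numpy as np

# Illustrative only: a batch of two graphs with 3 and 1 nodes, F = 2 features.

# padded: (batch, N_max, F) plus a mask marking the real nodes
padded = np.zeros((2, 3, 2))
mask = np.array([[1, 1, 1], [1, 0, 0]], dtype=bool)

# disjoint: all nodes concatenated to (N_total, F) plus a graph-id vector
disjoint_nodes = np.zeros((4, 2))
batch_id = np.array([0, 0, 0, 1])

# graph-level outputs are (batch, F) in every layout, which is why the
# "ragged" branch above can simply pass them through
graph_labels = np.zeros((2, 2))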
kgcnn/models/force.py (43 changes: 22 additions & 21 deletions)

@@ -1,5 +1,6 @@
 import keras as ks
 import keras.saving
+from functools import partial
 from typing import Union
 from kgcnn.models.utils import get_model_class
 from keras.saving import deserialize_keras_object, serialize_keras_object

@@ -12,9 +13,9 @@
     import tensorflow as tf
 elif backend() == "torch":
     import torch
-elif backend() == "jax":
-    import jax.numpy as jnp
-    import jax
+# elif backend() == "jax":
+#     import jax.numpy as jnp
+#     import jax
 else:
     raise NotImplementedError("Backend '%s' not supported for force model." % backend())

@@ -103,7 +104,7 @@ def __init__(self,
         else:
             self.output_as_dict_use = False

-        energy_output_config = outputs[self.output_as_dict_names[0]] if self.output_as_dict_use else output[0]
+        energy_output_config = outputs[self.output_as_dict_names[0]] if self.output_as_dict_use else outputs[0]
         self._expected_energy_states = energy_output_config["shape"][0]

         # We can try to infer the model inputs from energy model, if not given explicit.

@@ -116,8 +117,8 @@
             self._call_grad_backend = self._call_grad_tf
         elif backend() == "torch":
             self._call_grad_backend = self._call_grad_torch
-        elif backend() == "jax":
-            self._call_grad_backend = self._call_grad_jax
+        # elif backend() == "jax":
+        #     self._call_grad_backend = self._call_grad_jax
         else:
             raise NotImplementedError("Backend '%s' not supported for force model." % backend())

@@ -163,21 +164,21 @@ def _call_grad_torch(self, inputs, training=False, **kwargs):
             e_grad = torch.squeeze(e_grad, dim=-1)
         return eng, e_grad

-    def _call_grad_jax(self, inputs, training=False, **kwargs):
-
-        def energy_reduce(*inputs, pos: int = 0):
-            eng = self.energy_model(inputs, training=training, **kwargs)
-            eng_sum = jnp.sum(eng, axis=0)[pos]
-            return eng_sum
-
-        grad_fn = jax.value_and_grad(energy_reduce, argnums=self.coordinate_input)
-        states = [grad_fn(*inputs, pos=i) for i in range(self._expected_energy_states)]
-        eng = jnp.concatenate([jnp.expand_dims(x[0], axis=-1) for x in states], axis=-1)
-        e_grad = jnp.concatenate([jnp.expand_dims(x[1], axis=-1) for x in states], axis=-1)
-
-        if self.output_squeeze_states:
-            e_grad = jnp.squeeze(e_grad, axis=-1)
-        return eng, e_grad
+    # def _call_grad_jax(self, inputs, training=False, **kwargs):
+    #
+    #     def energy_reduce(*inputs, pos: int = 0):
+    #         eng = self.energy_model(inputs, training=training, **kwargs)
+    #         eng_sum = jnp.sum(eng, axis=0)[pos]
+    #         return eng_sum
+    #
+    #     grad_fn = jax.grad(energy_reduce, argnums=self.coordinate_input)
+    #     all_grad = [grad_fn(*inputs, pos=i) for i in range(self._expected_energy_states)]
+    #     eng = self.energy_model(inputs, training=training, **kwargs)
+    #     e_grad = jnp.concatenate([jnp.expand_dims(x[1], axis=-1) for x in all_grad], axis=-1)
+    #
+    #     if self.output_squeeze_states:
+    #         e_grad = jnp.squeeze(e_grad, axis=-1)
+    #     return eng, e_grad

     def call(self, inputs, training=False, **kwargs):
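For context on what was removed: a force model obtains forces as the negative gradient of the predicted energy with respect to atomic coordinates, which is what _call_grad_jax attempted via a per-state scalar reduction differentiated at the coordinate input (argnums=self.coordinate_input). A self-contained sketch of that energy-to-force pattern in plain JAX, using a toy pairwise energy instead of kgcnn's energy model (all names here are illustrative):

import jax
import jax.numpy as jnp

def energy(coords):
    """Toy pairwise 1/r energy standing in for a trained energy model."""
    diff = coords[:, None, :] - coords[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff * diff, axis=-1) + 1e-12)  # eps as in the norm flag above
    inv = jnp.where(dist > 1e-5, 1.0 / dist, 0.0)           # zero out self-pairs
    return 0.5 * jnp.sum(inv)

# One call gives the scalar energy and dE/dR; the force is the negative gradient.
energy_and_grad = jax.value_and_grad(energy)
coords = jnp.array([[0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [0.0, 1.5, 0.0]])
e, de_dr = energy_and_grad(coords)
forces = -de_dr  # shape (3, 3): one force vector per atom

Note that jax.grad returns only the gradient, not a (value, gradient) pair, so the x[1] indexing in the commented-out draft above would slice the gradient array rather than select a gradient — consistent with the commit message that forces for jax still do not work.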