Skip to content

Commit

Permalink
update keras core.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 11, 2023
1 parent cbe5a0d commit 29e5dcc
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
23 changes: 23 additions & 0 deletions kgcnn/backend/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import jax.numpy as jnp


def scatter_reduce_sum(indices, values, shape):
zeros = jnp.zeros(shape, values.dtype)
return zeros.at[indices].add(values)


def scatter_reduce_min(indices, values, shape):
zeros = jnp.zeros(shape, values.dtype.max, values.dtype)
return zeros.at[indices].min(values)


def scatter_reduce_max(indices, values, shape):
zeros = jnp.full(shape, values.dtype.min, values.dtype)
return zeros.at[indices].max(values)


def scatter_reduce_mean(indices, values, shape):
zeros = jnp.zeros(shape, values.dtype)
counts = jnp.zeros(shape, values.dtype)
counts.at[indices].add(jnp.ones_like(values))
return zeros.at[indices].add(values)/counts
8 changes: 4 additions & 4 deletions kgcnn/backend/tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import tensorflow as tf


def scatter_reduce_sum(indices, values, shape):
return tf.scatter_nd(indices, values, tf.cast(shape, dtype="int64"))


def scatter_reduce_min(indices, values, shape):
target = tf.fill(shape, values.dtype.limits[1], dtype=values.dtype)
return tf.tensor_scatter_nd_min(target, indices, values)
Expand All @@ -16,10 +20,6 @@ def scatter_reduce_mean(indices, values, shape):
return tf.scatter_nd(indices, values, shape)/counts


def scatter_reduce_sum(indices, values, shape):
return tf.scatter_nd(indices, values, tf.cast(shape, dtype="int64"))


def scatter_reduce_softmax(indices, values, shape):
# if normalize:
# data_segment_max = tf.math.segment_max(data, segment_ids)
Expand Down

0 comments on commit 29e5dcc

Please sign in to comment.