Commit a4fe47c
continue keras core integration
PatReis committed Sep 23, 2023
1 parent b000b00 commit a4fe47c
Showing 2 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions kgcnn/ops/core.py
@@ -5,7 +5,7 @@
import kgcnn.backend as kgcnn_backend


- class RepeatStaticLength(Operation):
+ class _RepeatStaticLength(Operation):

def __init__(self, total_repeat_length: int, axis=None):
super().__init__()
@@ -38,5 +38,5 @@ def repeat_static_length(x, repeats, total_repeat_length: int, axis=None):
Output tensor.
"""
if any_symbolic_tensors((x, repeats)):
- return RepeatStaticLength(axis=axis, total_repeat_length=total_repeat_length).symbolic_call(x, repeats)
+ return _RepeatStaticLength(axis=axis, total_repeat_length=total_repeat_length).symbolic_call(x, repeats)
return kgcnn_backend.repeat_static_length(x, repeats, axis=axis, total_repeat_length=total_repeat_length)
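
For context, a minimal usage sketch of the public wrapper above (not part of this commit). It assumes repeat_static_length mirrors numpy.repeat semantics with a statically known output length, and that the active backend accepts eager NumPy inputs:

import numpy as np
from kgcnn.ops.core import repeat_static_length

x = np.array([1, 2, 3])        # values to repeat
repeats = np.array([2, 0, 3])  # per-element repeat counts
# total_repeat_length is assumed to equal int(repeats.sum()) so that the
# output shape is known at graph-construction time.
out = repeat_static_length(x, repeats, total_repeat_length=5, axis=0)
# Expected under numpy.repeat-like semantics: [1, 1, 3, 3, 3]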
20 changes: 10 additions & 10 deletions kgcnn/ops/scatter.py
@@ -4,7 +4,7 @@
from keras_core import Operation


- class ScatterMax(Operation):
+ class _ScatterMax(Operation):
def call(self, indices, values, shape):
return kgcnn_backend.scatter_reduce_max(indices, values, shape)

@@ -14,11 +14,11 @@ def compute_output_spec(self, indices, values, shape):

def scatter_reduce_max(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
- return ScatterMax().symbolic_call(indices, values, shape)
+ return _ScatterMax().symbolic_call(indices, values, shape)
return kgcnn_backend.scatter_reduce_max(indices, values, shape)


- class ScatterMin(Operation):
+ class _ScatterMin(Operation):
def call(self, indices, values, shape):
return kgcnn_backend.scatter_reduce_min(indices, values, shape)

@@ -28,11 +28,11 @@ def compute_output_spec(self, indices, values, shape):

def scatter_reduce_min(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
- return ScatterMin().symbolic_call(indices, values, shape)
+ return _ScatterMin().symbolic_call(indices, values, shape)
return kgcnn_backend.scatter_reduce_min(indices, values, shape)


- class ScatterMean(Operation):
+ class _ScatterMean(Operation):
def call(self, indices, values, shape):
return kgcnn_backend.scatter_reduce_mean(indices, values, shape)

@@ -42,11 +42,11 @@ def compute_output_spec(self, indices, values, shape):

def scatter_reduce_mean(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
- return ScatterMean().symbolic_call(indices, values, shape)
+ return _ScatterMean().symbolic_call(indices, values, shape)
return kgcnn_backend.scatter_reduce_mean(indices, values, shape)


- class ScatterSum(Operation):
+ class _ScatterSum(Operation):
def call(self, indices, values, shape):
return kgcnn_backend.scatter_reduce_sum(indices, values, shape)

@@ -56,11 +56,11 @@ def compute_output_spec(self, indices, values, shape):

def scatter_reduce_sum(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
- return ScatterSum().symbolic_call(indices, values, shape)
+ return _ScatterSum().symbolic_call(indices, values, shape)
return kgcnn_backend.scatter_reduce_sum(indices, values, shape)


- class ScatterSoftmax(Operation):
+ class _ScatterSoftmax(Operation):

def __init__(self, normalize: bool = False):
super().__init__()
@@ -75,5 +75,5 @@ def compute_output_spec(self, indices, values, shape):

def scatter_reduce_softmax(indices, values, shape, normalize: bool = False):
if any_symbolic_tensors((indices, values, shape)):
- return ScatterSoftmax(normalize=normalize).symbolic_call(indices, values, shape)
+ return _ScatterSoftmax(normalize=normalize).symbolic_call(indices, values, shape)
return kgcnn_backend.scatter_reduce_softmax(indices, values, shape, normalize=normalize)
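
Each public scatter_reduce_* wrapper above dispatches to its private Operation class for symbolic tensors (so it can be used while building functional models) and calls the backend implementation directly otherwise. A minimal usage sketch (not part of this commit; it assumes scatter-add semantics for scatter_reduce_sum and that shape may be passed as a plain tuple):

import numpy as np
from kgcnn.ops.scatter import scatter_reduce_sum

values = np.array([[1.0], [2.0], [4.0]])  # e.g. per-edge features
indices = np.array([0, 0, 1])             # target row for each feature row
shape = (2, 1)                            # shape of the pooled output tensor
out = scatter_reduce_sum(indices, values, shape)
# Expected under scatter-add semantics: [[3.0], [4.0]]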
