From a4fe47c711e7d0e1947af1c2dca52e1a5c4938fd Mon Sep 17 00:00:00 2001 From: PatReis Date: Sat, 23 Sep 2023 21:27:40 +0200 Subject: [PATCH] continue keras core integration --- kgcnn/ops/core.py | 4 ++-- kgcnn/ops/scatter.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/kgcnn/ops/core.py b/kgcnn/ops/core.py index a6f9d976..8f2b531b 100644 --- a/kgcnn/ops/core.py +++ b/kgcnn/ops/core.py @@ -5,7 +5,7 @@ import kgcnn.backend as kgcnn_backend -class RepeatStaticLength(Operation): +class _RepeatStaticLength(Operation): def __init__(self, total_repeat_length: int, axis=None): super().__init__() @@ -38,5 +38,5 @@ def repeat_static_length(x, repeats, total_repeat_length: int, axis=None): Output tensor. """ if any_symbolic_tensors((x, repeats)): - return RepeatStaticLength(axis=axis, total_repeat_length=total_repeat_length).symbolic_call(x, repeats) + return _RepeatStaticLength(axis=axis, total_repeat_length=total_repeat_length).symbolic_call(x, repeats) return kgcnn_backend.repeat_static_length(x, repeats, axis=axis, total_repeat_length=total_repeat_length) diff --git a/kgcnn/ops/scatter.py b/kgcnn/ops/scatter.py index 153c79cb..de7c8363 100644 --- a/kgcnn/ops/scatter.py +++ b/kgcnn/ops/scatter.py @@ -4,7 +4,7 @@ from keras_core import Operation -class ScatterMax(Operation): +class _ScatterMax(Operation): def call(self, indices, values, shape): return kgcnn_backend.scatter_reduce_max(indices, values, shape) @@ -14,11 +14,11 @@ def compute_output_spec(self, indices, values, shape): def scatter_reduce_max(indices, values, shape): if any_symbolic_tensors((indices, values, shape)): - return ScatterMax().symbolic_call(indices, values, shape) + return _ScatterMax().symbolic_call(indices, values, shape) return kgcnn_backend.scatter_reduce_max(indices, values, shape) -class ScatterMin(Operation): +class _ScatterMin(Operation): def call(self, indices, values, shape): return kgcnn_backend.scatter_reduce_min(indices, values, shape) @@ -28,11 +28,11 @@ def compute_output_spec(self, indices, values, shape): def scatter_reduce_min(indices, values, shape): if any_symbolic_tensors((indices, values, shape)): - return ScatterMin().symbolic_call(indices, values, shape) + return _ScatterMin().symbolic_call(indices, values, shape) return kgcnn_backend.scatter_reduce_min(indices, values, shape) -class ScatterMean(Operation): +class _ScatterMean(Operation): def call(self, indices, values, shape): return kgcnn_backend.scatter_reduce_mean(indices, values, shape) @@ -42,11 +42,11 @@ def compute_output_spec(self, indices, values, shape): def scatter_reduce_mean(indices, values, shape): if any_symbolic_tensors((indices, values, shape)): - return ScatterMean().symbolic_call(indices, values, shape) + return _ScatterMean().symbolic_call(indices, values, shape) return kgcnn_backend.scatter_reduce_mean(indices, values, shape) -class ScatterSum(Operation): +class _ScatterSum(Operation): def call(self, indices, values, shape): return kgcnn_backend.scatter_reduce_sum(indices, values, shape) @@ -56,11 +56,11 @@ def compute_output_spec(self, indices, values, shape): def scatter_reduce_sum(indices, values, shape): if any_symbolic_tensors((indices, values, shape)): - return ScatterSum().symbolic_call(indices, values, shape) + return _ScatterSum().symbolic_call(indices, values, shape) return kgcnn_backend.scatter_reduce_sum(indices, values, shape) -class ScatterSoftmax(Operation): +class _ScatterSoftmax(Operation): def __init__(self, normalize: bool = False): super().__init__() @@ -75,5 +75,5 @@ def compute_output_spec(self, indices, values, shape): def scatter_reduce_softmax(indices, values, shape, normalize: bool = False): if any_symbolic_tensors((indices, values, shape)): - return ScatterSoftmax(normalize=normalize).symbolic_call(indices, values, shape) + return _ScatterSoftmax(normalize=normalize).symbolic_call(indices, values, shape) return kgcnn_backend.scatter_reduce_softmax(indices, values, shape, normalize=normalize) \ No newline at end of file