Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 13, 2023
1 parent 6dfba04 commit f154b57
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 41 deletions.
48 changes: 23 additions & 25 deletions kgcnn/layers_core/mlp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,26 @@
import keras_core as ks
from keras_core.layers import Dense, Layer, Activation, Dropout
from keras_core.layers import LayerNormalization, GroupNormalization, BatchNormalization, UnitNormalization
from kgcnn.layers_core.norm import global_normalization_args
from kgcnn.layers_core.norm import GraphNormalization, GraphInstanceNormalization


global_normalization_args = {
"UnitNormalization": (
"axis"
),
"BatchNormalization": (
"axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer",
"gamma_regularizer", "beta_constraint", "gamma_constraint", "momentum", "moving_mean_initializer",
"moving_variance_initializer"
),
"GroupNormalization": (
"groups", "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer",
"gamma_regularizer", "beta_constraint", "gamma_constraint"
),
"LayerNormalization": (
"axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer",
"gamma_regularizer", "beta_constraint", "gamma_constraint"
)
}


class MLPBase(Layer):
Expand Down Expand Up @@ -259,16 +277,6 @@ def __init__(self, units, **kwargs):
"BatchNormalization": BatchNormalization,
"GroupNormalization": GroupNormalization,
"LayerNormalization": LayerNormalization,
"GraphNormalization": GraphNormalization,
"GraphInstanceNormalization": GraphInstanceNormalization
}
requires_batch_classes = {
"UnitNormalization": False,
"BatchNormalization": False,
"GroupNormalization": False,
"LayerNormalization": False,
"GraphNormalization": True,
"GraphInstanceNormalization": True
}
if not self._supress_dense:
self.mlp_dense_layer_list = [
Expand All @@ -289,24 +297,17 @@ def __init__(self, units, **kwargs):
**self._get_conf_for_keys(self._key_dict_norm[self._conf_normalization_technique[i]], "norm", i)
) if self._conf_use_normalization[i] else None for i in range(self._depth)
]
self._norm_requires_batch_info = [
requires_batch_classes[self._conf_normalization_technique[i]] if self._conf_use_normalization[
i] else None for i in range(self._depth)
]

def build(self, input_shape):
"""Build layer."""
x_shape, batch = (input_shape[0], input_shape[1:]) if isinstance(input_shape, list) else (input_shape, [])
x_shape, _ = (input_shape[0], input_shape[1:]) if isinstance(input_shape, list) else (input_shape, [])
for i in range(self._depth):
self.mlp_dense_layer_list[i].build(x_shape)
x_shape = self.mlp_dense_layer_list[i].compute_output_shape(x_shape)
if self._conf_use_dropout[i]:
self.mlp_dropout_layer_list[i].build(x_shape)
if self._conf_use_normalization[i]:
if self._norm_requires_batch_info[i]:
self.mlp_norm_layer_list[i].build([x_shape] + batch)
else:
self.mlp_norm_layer_list[i].build(x_shape)
self.mlp_norm_layer_list[i].build(x_shape)
self.mlp_activation_layer_list[i].build(x_shape)
self.built = True

Expand All @@ -326,10 +327,7 @@ def call(self, inputs, **kwargs):
if self._conf_use_dropout[i]:
x = self.mlp_dropout_layer_list[i](x, **kwargs)
if self._conf_use_normalization[i]:
if self._norm_requires_batch_info[i]:
x = self.mlp_norm_layer_list[i]([x]+batch, **kwargs)
else:
x = self.mlp_norm_layer_list[i](x, **kwargs)
x = self.mlp_norm_layer_list[i](x, **kwargs)
x = self.mlp_activation_layer_list[i](x, **kwargs)
out = x
return out
Expand Down
16 changes: 0 additions & 16 deletions kgcnn/layers_core/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,6 @@


global_normalization_args = {
"UnitNormalization": [
"axis"
],
"BatchNormalization": (
"axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer",
"gamma_regularizer", "beta_constraint", "gamma_constraint", "momentum", "moving_mean_initializer",
"moving_variance_initializer"
),
"GroupNormalization": (
"groups", "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer",
"gamma_regularizer", "beta_constraint", "gamma_constraint"
),
"LayerNormalization": (
"axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer",
"gamma_regularizer", "beta_constraint", "gamma_constraint"
),
"GraphNormalization": (
"mean_shift", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "alpha_initializer",
"beta_regularizer", "gamma_regularizer", "beta_constraint", "alpha_constraint", "gamma_constraint",
Expand Down

0 comments on commit f154b57

Please sign in to comment.