From f154b5708a220b0caa05a0d410e3823c57fcffda Mon Sep 17 00:00:00 2001 From: PatReis Date: Wed, 13 Sep 2023 15:29:52 +0200 Subject: [PATCH] continue keras core integration --- kgcnn/layers_core/mlp.py | 48 +++++++++++++++++++-------------------- kgcnn/layers_core/norm.py | 16 ------------- 2 files changed, 23 insertions(+), 41 deletions(-) diff --git a/kgcnn/layers_core/mlp.py b/kgcnn/layers_core/mlp.py index 767f9b98..71a5e850 100644 --- a/kgcnn/layers_core/mlp.py +++ b/kgcnn/layers_core/mlp.py @@ -1,8 +1,26 @@ import keras_core as ks from keras_core.layers import Dense, Layer, Activation, Dropout from keras_core.layers import LayerNormalization, GroupNormalization, BatchNormalization, UnitNormalization -from kgcnn.layers_core.norm import global_normalization_args -from kgcnn.layers_core.norm import GraphNormalization, GraphInstanceNormalization + + +global_normalization_args = { + "UnitNormalization": ( + "axis" + ), + "BatchNormalization": ( + "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer", + "gamma_regularizer", "beta_constraint", "gamma_constraint", "momentum", "moving_mean_initializer", + "moving_variance_initializer" + ), + "GroupNormalization": ( + "groups", "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer", + "gamma_regularizer", "beta_constraint", "gamma_constraint" + ), + "LayerNormalization": ( + "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer", + "gamma_regularizer", "beta_constraint", "gamma_constraint" + ) +} class MLPBase(Layer): @@ -259,16 +277,6 @@ def __init__(self, units, **kwargs): "BatchNormalization": BatchNormalization, "GroupNormalization": GroupNormalization, "LayerNormalization": LayerNormalization, - "GraphNormalization": GraphNormalization, - "GraphInstanceNormalization": GraphInstanceNormalization - } - requires_batch_classes = { - "UnitNormalization": False, - "BatchNormalization": False, - "GroupNormalization": False, - "LayerNormalization": False, - "GraphNormalization": True, - "GraphInstanceNormalization": True } if not self._supress_dense: self.mlp_dense_layer_list = [ @@ -289,24 +297,17 @@ def __init__(self, units, **kwargs): **self._get_conf_for_keys(self._key_dict_norm[self._conf_normalization_technique[i]], "norm", i) ) if self._conf_use_normalization[i] else None for i in range(self._depth) ] - self._norm_requires_batch_info = [ - requires_batch_classes[self._conf_normalization_technique[i]] if self._conf_use_normalization[ - i] else None for i in range(self._depth) - ] def build(self, input_shape): """Build layer.""" - x_shape, batch = (input_shape[0], input_shape[1:]) if isinstance(input_shape, list) else (input_shape, []) + x_shape, _ = (input_shape[0], input_shape[1:]) if isinstance(input_shape, list) else (input_shape, []) for i in range(self._depth): self.mlp_dense_layer_list[i].build(x_shape) x_shape = self.mlp_dense_layer_list[i].compute_output_shape(x_shape) if self._conf_use_dropout[i]: self.mlp_dropout_layer_list[i].build(x_shape) if self._conf_use_normalization[i]: - if self._norm_requires_batch_info[i]: - self.mlp_norm_layer_list[i].build([x_shape] + batch) - else: - self.mlp_norm_layer_list[i].build(x_shape) + self.mlp_norm_layer_list[i].build(x_shape) self.mlp_activation_layer_list[i].build(x_shape) self.built = True @@ -326,10 +327,7 @@ def call(self, inputs, **kwargs): if self._conf_use_dropout[i]: x = self.mlp_dropout_layer_list[i](x, **kwargs) if self._conf_use_normalization[i]: - if self._norm_requires_batch_info[i]: - x = self.mlp_norm_layer_list[i]([x]+batch, **kwargs) - else: - x = self.mlp_norm_layer_list[i](x, **kwargs) + x = self.mlp_norm_layer_list[i](x, **kwargs) x = self.mlp_activation_layer_list[i](x, **kwargs) out = x return out diff --git a/kgcnn/layers_core/norm.py b/kgcnn/layers_core/norm.py index 7b5dccc4..f05233da 100644 --- a/kgcnn/layers_core/norm.py +++ b/kgcnn/layers_core/norm.py @@ -4,22 +4,6 @@ global_normalization_args = { - "UnitNormalization": [ - "axis" - ], - "BatchNormalization": ( - "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer", - "gamma_regularizer", "beta_constraint", "gamma_constraint", "momentum", "moving_mean_initializer", - "moving_variance_initializer" - ), - "GroupNormalization": ( - "groups", "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer", - "gamma_regularizer", "beta_constraint", "gamma_constraint" - ), - "LayerNormalization": ( - "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer", - "gamma_regularizer", "beta_constraint", "gamma_constraint" - ), "GraphNormalization": ( "mean_shift", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "alpha_initializer", "beta_regularizer", "gamma_regularizer", "beta_constraint", "alpha_constraint", "gamma_constraint",