diff --git a/layers/model_builder.py b/layers/model_builder.py index eda2bb9..5679a7e 100644 --- a/layers/model_builder.py +++ b/layers/model_builder.py @@ -37,6 +37,7 @@ def __init__( use_bias=False, groups=1, conv_func=tf.keras.layers.Conv2D, + norm_func=normalization, name=None, ): @@ -54,7 +55,7 @@ def __init__( name="{}_conv".format(name), ) - self.bn = None if not use_bn else normalization(trainable=trainable, name="{}_bn".format(name)) + self.bn = None if not use_bn else norm_func(trainable=trainable, name="{}_bn".format(name)) self.activation = activation