diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py index c7e2f0e..fdc3c56 100644 --- a/phygnn/layers/custom_layers.py +++ b/phygnn/layers/custom_layers.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- """Custom tf layers.""" -import numpy as np import logging + +import numpy as np import tensorflow as tf logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ def __init__(self, paddings, mode='REFLECT'): self.mode = mode def compute_output_shape(self, input_shape): - """computes output shape after padding + """Computes output shape after padding Parameters ---------- @@ -47,7 +48,7 @@ def compute_output_shape(self, input_shape): return tf.TensorShape(output_shape) def call(self, x): - """calls the padding routine + """Calls the padding routine Parameters ---------- @@ -82,7 +83,7 @@ def __init__(self, axis=3): self._axis = axis def call(self, x): - """calls the expand dims operation + """Calls the expand dims operation Parameters ---------- @@ -113,7 +114,7 @@ def __init__(self, multiples): self._mult = tf.constant(multiples, tf.int32) def call(self, x): - """calls the tile operation + """Calls the tile operation Parameters ---------- @@ -168,7 +169,7 @@ def build(self, input_shape): self._rand_shape = tf.constant(shape, dtype=tf.dtypes.int32) def call(self, x): - """calls the tile operation + """Calls the tile operation Parameters ---------- @@ -221,7 +222,7 @@ def _check_shape(input_shape): assert len(input_shape) == 5, msg def call(self, x): - """calls the flatten axis operation + """Calls the flatten axis operation Parameters ---------- @@ -601,8 +602,18 @@ class Sup3rAdder(tf.keras.layers.Layer): """Layer to add high-resolution data to a sup3r model in the middle of a super resolution forward pass.""" + def __init__(self, name=None): + """ + Parameters + ---------- + name : str | None + Unique str identifier of the adder layer. Usually the name of the + hi-resolution feature used in the addition. + """ + self.name = name + def call(self, x, hi_res_adder): - """adds hi-resolution data to the input tensor x in the middle of a + """Adds hi-resolution data to the input tensor x in the middle of a sup3r resolution network. Parameters @@ -626,8 +637,18 @@ class Sup3rConcat(tf.keras.layers.Layer): """Layer to concatenate a high-resolution feature to a sup3r model in the middle of a super resolution forward pass.""" + def __init__(self, name=None): + """ + Parameters + ---------- + name : str | None + Unique str identifier for the concat layer. Usually the name of the + hi-resolution feature used in the concatenation. + """ + self.name = name + def call(self, x, hi_res_feature): - """concatenates a hi-resolution feature to the input tensor x in the + """Concatenates a hi-resolution feature to the input tensor x in the middle of a sup3r resolution network. Parameters