Skip to content

Commit

Permalink
Merge pull request #42 from NREL/bnb/input_feature_attrs
Browse files Browse the repository at this point in the history
input_feature names for concat/adder layers
  • Loading branch information
bnb32 authored Sep 19, 2023
2 parents 8a09e3e + 616f9ae commit 9e88552
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
"""Custom tf layers."""
import numpy as np
import logging

import numpy as np
import tensorflow as tf

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -29,7 +30,7 @@ def __init__(self, paddings, mode='REFLECT'):
self.mode = mode

def compute_output_shape(self, input_shape):
"""computes output shape after padding
"""Computes output shape after padding
Parameters
----------
Expand All @@ -47,7 +48,7 @@ def compute_output_shape(self, input_shape):
return tf.TensorShape(output_shape)

def call(self, x):
"""calls the padding routine
"""Calls the padding routine
Parameters
----------
Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(self, axis=3):
self._axis = axis

def call(self, x):
"""calls the expand dims operation
"""Calls the expand dims operation
Parameters
----------
Expand Down Expand Up @@ -113,7 +114,7 @@ def __init__(self, multiples):
self._mult = tf.constant(multiples, tf.int32)

def call(self, x):
"""calls the tile operation
"""Calls the tile operation
Parameters
----------
Expand Down Expand Up @@ -168,7 +169,7 @@ def build(self, input_shape):
self._rand_shape = tf.constant(shape, dtype=tf.dtypes.int32)

def call(self, x):
"""calls the tile operation
"""Calls the tile operation
Parameters
----------
Expand Down Expand Up @@ -221,7 +222,7 @@ def _check_shape(input_shape):
assert len(input_shape) == 5, msg

def call(self, x):
"""calls the flatten axis operation
"""Calls the flatten axis operation
Parameters
----------
Expand Down Expand Up @@ -601,8 +602,18 @@ class Sup3rAdder(tf.keras.layers.Layer):
"""Layer to add high-resolution data to a sup3r model in the middle of a
super resolution forward pass."""

def __init__(self, name=None):
"""
Parameters
----------
name : str | None
Unique str identifier of the adder layer. Usually the name of the
hi-resolution feature used in the addition.
"""
self.name = name

def call(self, x, hi_res_adder):
"""adds hi-resolution data to the input tensor x in the middle of a
"""Adds hi-resolution data to the input tensor x in the middle of a
sup3r resolution network.
Parameters
Expand All @@ -626,8 +637,18 @@ class Sup3rConcat(tf.keras.layers.Layer):
"""Layer to concatenate a high-resolution feature to a sup3r model in the
middle of a super resolution forward pass."""

def __init__(self, name=None):
"""
Parameters
----------
name : str | None
Unique str identifier for the concat layer. Usually the name of the
hi-resolution feature used in the concatenation.
"""
self.name = name

def call(self, x, hi_res_feature):
"""concatenates a hi-resolution feature to the input tensor x in the
"""Concatenates a hi-resolution feature to the input tensor x in the
middle of a sup3r resolution network.
Parameters
Expand Down

0 comments on commit 9e88552

Please sign in to comment.