Gb/fix gpool #50

Merged
merged 4 commits on May 28, 2024
4 changes: 2 additions & 2 deletions phygnn/__init__.py
@@ -5,12 +5,12 @@
from .base import CustomNetwork, GradientUtils
from .phygnn import PhysicsGuidedNeuralNetwork
from .layers import Layers, HiddenLayers
from .layers.custom_layers import GaussianKernelInit2D
from .layers.custom_layers import GaussianAveragePooling2D
from .utilities import PreProcess, tf_isin, tf_log10
from phygnn.version import __version__
from tensorflow.keras.utils import get_custom_objects

get_custom_objects()['GaussianKernelInit2D'] = GaussianKernelInit2D
get_custom_objects()['GaussianAveragePooling2D'] = GaussianAveragePooling2D

__author__ = """Grant Buster"""
__email__ = "grant.buster@nrel.gov"
148 changes: 103 additions & 45 deletions phygnn/layers/custom_layers.py
@@ -130,6 +130,109 @@ def call(self, x):
return tf.tile(x, self._mult)


class GaussianAveragePooling2D(tf.keras.layers.Layer):
"""Custom layer to implement tensorflow average pooling layer but with a
gaussian kernel. This is basically a gaussian smoothing layer with a fixed
convolution window that limits the area of effect"""

def __init__(self, pool_size, strides=None, padding='valid', sigma=1,
**kwargs):
"""
Parameters
----------
pool_size : int
Pooling window size. This sets the number of pixels in each
dimension that will be averaged into an output pixel. Only a single
integer is specified; the same window length will be used for both
dimensions. For example, if ``pool_size=2`` and ``strides=2`` then
the output dimension will be half of the input.
strides : int, tuple of 2 ints, or None
Strides values. If None, ``tf.nn.convolution`` defaults to a stride
of 1 in each spatial dimension.
padding : str
One of `"valid"` or `"same"` (case-insensitive). `"valid"` means no
padding. `"same"` results in padding evenly to the left/right or
up/down of the input such that output has the same height/width
dimension as the input.
sigma : float
Sigma parameter for gaussian distribution
kwargs : dict
Extra kwargs for tf.keras.layers.Layer
"""

super().__init__(**kwargs)
assert isinstance(pool_size, int), 'pool_size must be int!'
self._pool_size = pool_size
self._strides = strides
self._padding = padding.upper()
self._sigma = sigma

target_shape = (self._pool_size, self._pool_size, 1, 1)
self._kernel = self._make_2D_gaussian_kernel(self._pool_size,
self._sigma)
self._kernel = np.expand_dims(self._kernel, -1)
self._kernel = np.expand_dims(self._kernel, -1)
assert self._kernel.shape == target_shape
self._kernel = tf.convert_to_tensor(self._kernel, dtype=tf.float32)

@staticmethod
def _make_2D_gaussian_kernel(edge_len, sigma=1.):
"""Creates 2D gaussian kernel with side length `edge_len` and a sigma
of `sigma`
Parameters
----------
edge_len : int
Edge size of the kernel
sigma : float
Sigma parameter for gaussian distribution
Returns
-------
kernel : np.ndarray
2D kernel with shape (edge_len, edge_len)
"""
ax = np.linspace(-(edge_len - 1) / 2., (edge_len - 1) / 2., edge_len)
gauss = np.exp(-0.5 * np.square(ax) / np.square(sigma))
kernel = np.outer(gauss, gauss)
kernel = kernel / np.sum(kernel)
return kernel.astype(np.float32)
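
For intuition, reproducing this helper's math for ``edge_len=3`` and ``sigma=1`` gives a symmetric kernel that peaks at the center and sums to one (values rounded):

import numpy as np

ax = np.linspace(-1.0, 1.0, 3)         # [-1, 0, 1]
gauss = np.exp(-0.5 * np.square(ax))   # [0.607, 1.0, 0.607]
kernel = np.outer(gauss, gauss)
kernel = kernel / np.sum(kernel)       # normalize so the window is a weighted average
# kernel ~ [[0.075, 0.124, 0.075],
#           [0.124, 0.204, 0.124],
#           [0.075, 0.124, 0.075]]
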

def get_config(self):
"""Implementation of get_config method from tf.keras.layers.Layer for
saving/loading as part of keras sequential model.

Returns
-------
config : dict
"""
config = super().get_config().copy()
config.update({
'pool_size': self._pool_size,
'strides': self._strides,
'padding': self._padding,
'sigma': self._sigma,
})
return config

def call(self, x):
"""Operates on x with the specified function
Parameters
----------
x : tf.Tensor
Input tensor
Returns
-------
x : tf.Tensor
Output tensor operated on by the specified function
"""
out = []
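# apply the same single-channel gaussian kernel to each feature channel separately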
for idf in range(x.shape[-1]):
fslice = slice(idf, idf + 1)
iout = tf.nn.convolution(x[..., fslice], self._kernel,
strides=self._strides,
padding=self._padding)
out.append(iout)
out = tf.concat(out, -1, name='concat')
return out
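
Taken together, each output pixel is a gaussian-weighted average of its pooling window, computed independently for every feature channel. A short usage sketch (shapes and values below are illustrative):

import numpy as np
import tensorflow as tf

from phygnn.layers.custom_layers import GaussianAveragePooling2D

layer = GaussianAveragePooling2D(pool_size=3, strides=1, padding='valid', sigma=1)
x = tf.random.uniform((4, 10, 10, 2))  # (n_obs, y, x, n_features)
y = layer(x)
assert y.shape == (4, 8, 8, 2)  # 3x3 valid window with stride 1

# the first output pixel equals the weighted average of the first 3x3 window
window = x[0, :3, :3, 0].numpy()
weights = layer._kernel[:, :, 0, 0].numpy()
assert np.isclose((window * weights).sum(), y[0, 0, 0, 0].numpy(), atol=1e-5)
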


class GaussianNoiseAxis(tf.keras.layers.Layer):
"""Layer to apply random noise along a given axis."""

@@ -189,51 +292,6 @@ def call(self, x):
return x * rand_tensor


class GaussianKernelInit2D(tf.keras.initializers.Initializer):
"""Convolutional kernel initializer that creates a symmetric 2D array with
a gaussian distribution. This can be used with Conv2D as a gaussian average
pooling layer if trainable=False
"""

def __init__(self, stdev=1):
"""
Parameters
----------
stdev : float
Standard deviation of the gaussian distribution defining the kernel
values
"""
self.stdev = stdev

def __call__(self, shape, dtype=tf.float32):
"""
Parameters
---------
shape : tuple
Shape of the input tensor, typically (y, x, n_features, n_obs)
dtype : None | tf.DType
Tensorflow datatype e.g., tf.float32

Returns
-------
kernel : tf.Tensor
Kernel tensor of shape (y, x, n_features, n_obs) for use in a
Conv2D layer.
"""

ax = np.linspace(-(shape[0] - 1) / 2., (shape[0] - 1) / 2., shape[0])
kernel = np.exp(-0.5 * np.square(ax) / np.square(self.stdev))
kernel = np.outer(kernel, kernel)
kernel = kernel / np.sum(kernel)

kernel = np.expand_dims(kernel, (2, 3))
kernel = np.repeat(kernel, shape[2], axis=2)
kernel = np.repeat(kernel, shape[3], axis=3)

kernel = tf.convert_to_tensor(kernel, dtype=dtype)
return kernel
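
For reference, the removed initializer achieved gaussian pooling indirectly by freezing a standard Conv2D, roughly as sketched here (mirroring the removed test code further down); the new ``GaussianAveragePooling2D`` layer replaces that pattern with a dedicated, weight-free layer:

import tensorflow as tf

# pre-PR pattern; GaussianKernelInit2D no longer exists after this PR
from phygnn.layers.custom_layers import GaussianKernelInit2D

layer = tf.keras.layers.Conv2D(filters=16, kernel_size=5, strides=1,
                               padding='valid',
                               kernel_initializer=GaussianKernelInit2D(stdev=1),
                               trainable=False)
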


class FlattenAxis(tf.keras.layers.Layer):
"""Layer to flatten an axis from a 5D spatiotemporal Tensor into axis-0
observations."""
2 changes: 1 addition & 1 deletion phygnn/version.py
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Physics Guided Neural Network version."""

__version__ = '0.0.27'
__version__ = '0.0.28'
43 changes: 21 additions & 22 deletions tests/test_layers.py
@@ -15,7 +15,7 @@
SpatioTemporalExpansion,
TileLayer,
FunctionalLayer,
GaussianKernelInit2D,
GaussianAveragePooling2D,
)
from phygnn.layers.handlers import HiddenLayers, Layers
from phygnn import TfModel
@@ -450,51 +450,50 @@ def test_functional_layer():
assert "must be one of" in str(excinfo.value)


def test_gaussian_kernel():
"""Test the gaussian kernel initializer for gaussian average pooling"""
def test_gaussian_pooling():
"""Test the gaussian average pooling layer"""

kernels = []
biases = []
for stdev in [1, 2]:
kinit = GaussianKernelInit2D(stdev=stdev)
layer = tf.keras.layers.Conv2D(filters=16, kernel_size=5, strides=1,
padding='valid',
kernel_initializer=kinit)
layer = GaussianAveragePooling2D(pool_size=5, strides=1, sigma=stdev)
_ = layer(np.ones((24, 100, 100, 35)))
kernel = layer.weights[0].numpy()
bias = layer.weights[1].numpy()
kernel = layer._kernel.numpy()
kernels.append(kernel)
biases.append(bias)

assert (kernel[:, :, 0, 0] == kernel[:, :, -1, -1]).all()
assert kernel[:, :, 0, 0].sum() == 1
assert (bias == 0).all()
assert kernel[2, 2, 0, 0] == kernel.max()
assert kernel[0, 0, 0, 0] == kernel.min()
assert kernel[-1, -1, 0, 0] == kernel.min()

assert kernels[1].max() < kernels[0].max()
assert kernels[1].min() > kernels[0].min()

layers = [{'class': 'Conv2D', 'filters': 16, 'kernel_size': 3,
'kernel_initializer': GaussianKernelInit2D(),
'trainable': False}]
model1 = TfModel.build(['a'], ['b'], hidden_layers=layers,
input_layer=False, output_layer=False)
x_in = np.random.uniform(0, 1, (10, 12, 12, 1))
layers = [{'class': 'GaussianAveragePooling2D', 'pool_size': 12,
'strides': 1}]
model1 = TfModel.build(['a', 'b', 'c'], ['d'], hidden_layers=layers,
input_layer=False, output_layer=False,
normalize=False)
x_in = np.random.uniform(0, 1, (1, 12, 12, 3))
out1 = model1.predict(x_in)
kernel1 = model1.layers[0].weights[0][:, :, 0, 0].numpy()
kernel1 = model1.layers[0]._kernel[:, :, 0, 0].numpy()

for idf in range(out1.shape[-1]):
test = (x_in[0, :, :, idf] * kernel1).sum()
assert np.allclose(test, out1[..., idf])

assert out1.shape[1] == out1.shape[2] == 1
assert out1[0, 0, 0, 0] != out1[0, 0, 0, 1] != out1[0, 0, 0, 2]

with TemporaryDirectory() as td:
model_path = os.path.join(td, 'test_model')
model1.save_model(model_path)
model2 = TfModel.load(model_path)

kernel2 = model2.layers[0].weights[0][:, :, 0, 0].numpy()
kernel2 = model2.layers[0]._kernel[:, :, 0, 0].numpy()
out2 = model2.predict(x_in)
assert np.allclose(kernel1, kernel2)
assert np.allclose(out1, out2)

layer = model2.layers[0]
x_in = np.random.uniform(0, 1, (10, 24, 24, 1))
x_in = np.random.uniform(0, 1, (10, 24, 24, 3))
_ = model2.predict(x_in)