Skip to content

Commit

Permalink
Replace tf.keras reference layer with a NumPy implementation in conv1d test (#725)
Browse files Browse the repository at this point in the history
  • Loading branch information
sampathweb authored Aug 14, 2023
1 parent 8a6264d commit 4a579e0
Showing 1 changed file with 92 additions and 14 deletions.
106 changes: 92 additions & 14 deletions keras_core/layers/convolutional/conv_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import math

import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
from numpy.lib.stride_tricks import as_strided

from keras_core import layers
from keras_core import testing
@@ -245,6 +248,85 @@ def test_bad_init_args(self):


class ConvCorrectnessTest(testing.TestCase, parameterized.TestCase):
def _same_padding(self, input_size, kernel_size, stride):
# P = ((S-1)*W-S+K)/2, with K = kernel size, S = stride, W = input size
padding = int(
math.ceil(((stride - 1) * input_size - stride + kernel_size) / 2)
)
return padding

def _np_conv1d(
self,
x,
kernel_weights,
bias_weights,
strides,
padding,
data_format,
dilation_rate,
groups,
):
if data_format == "channels_first":
x = x.swapaxes(1, 2)
if isinstance(strides, (tuple, list)):
h_stride = strides[0]
else:
h_stride = strides
if isinstance(dilation_rate, (tuple, list)):
dilation_rate = dilation_rate[0]
kernel_size, ch_in, ch_out = kernel_weights.shape

if dilation_rate > 1:
new_kernel_size = kernel_size + (dilation_rate - 1) * (
kernel_size - 1
)
new_kernel_weights = np.zeros(
(new_kernel_size, ch_in, ch_out), dtype=kernel_weights.dtype
)
new_kernel_weights[::dilation_rate] = kernel_weights
kernel_weights = new_kernel_weights
kernel_size = kernel_weights.shape[0]

if padding != "valid":
n_batch, h_x, _ = x.shape
h_pad = self._same_padding(h_x, kernel_size, h_stride)
npad = [(0, 0)] * x.ndim
if padding == "causal":
npad[1] = (h_pad * 2, 0)
else:
npad[1] = (h_pad, h_pad)
x = np.pad(x, pad_width=npad, mode="constant", constant_values=0)

n_batch, h_x, _ = x.shape
h_out = int((h_x - kernel_size) / h_stride) + 1

kernel_weights = kernel_weights.reshape(-1, ch_out)
bias_weights = bias_weights.reshape(1, ch_out)

out_grps = []
for grp in range(1, groups + 1):
x_in = x[..., (grp - 1) * ch_in : grp * ch_in]
stride_shape = (n_batch, h_out, kernel_size, ch_in)
strides = (
x_in.strides[0],
h_stride * x_in.strides[1],
x_in.strides[1],
x_in.strides[2],
)
inner_dim = kernel_size * ch_in
x_strided = as_strided(
x_in, shape=stride_shape, strides=strides
).reshape(n_batch, h_out, inner_dim)
ch_out_groups = ch_out // groups
kernel_weights_grp = kernel_weights[
..., (grp - 1) * ch_out_groups : grp * ch_out_groups
]
bias_weights_grp = bias_weights[
..., (grp - 1) * ch_out_groups : grp * ch_out_groups
]
out_grps.append(x_strided @ kernel_weights_grp + bias_weights_grp)
return np.concatenate(out_grps, axis=-1)

@parameterized.parameters(
{
"filters": 5,
@@ -302,31 +384,27 @@ def test_conv1d(
dilation_rate=dilation_rate,
groups=groups,
)
tf_keras_layer = tf.keras.layers.Conv1D(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
)

inputs = np.random.normal(size=[2, 8, 4])
layer.build(input_shape=inputs.shape)
tf_keras_layer.build(input_shape=inputs.shape)

kernel_shape = layer.kernel.shape
kernel_weights = np.random.normal(size=kernel_shape)
bias_weights = np.random.normal(size=(filters,))
layer.kernel.assign(kernel_weights)
tf_keras_layer.kernel.assign(kernel_weights)

layer.bias.assign(bias_weights)
tf_keras_layer.bias.assign(bias_weights)

outputs = layer(inputs)
expected = tf_keras_layer(inputs)
expected = self._np_conv1d(
inputs,
kernel_weights,
bias_weights,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
groups=groups,
)
self.assertAllClose(outputs, expected)

@parameterized.parameters(

0 comments on commit 4a579e0

Please sign in to comment.