From 0f499f048c0606de3e14163f16e8bf049708e6f1 Mon Sep 17 00:00:00 2001 From: Julian Hoever Date: Sat, 12 Aug 2023 11:55:21 +0200 Subject: [PATCH] fix: add dummy batch dimension to meet the requirements of the batch norm --- elasticai/creator/nn/batch_normed_conv1d/layer.py | 7 ++++--- elasticai/creator/nn/batch_normed_linear/layer.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/elasticai/creator/nn/batch_normed_conv1d/layer.py b/elasticai/creator/nn/batch_normed_conv1d/layer.py index 375a0678..f0280e58 100644 --- a/elasticai/creator/nn/batch_normed_conv1d/layer.py +++ b/elasticai/creator/nn/batch_normed_conv1d/layer.py @@ -58,15 +58,16 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: input_shape = ( (inputs.shape[0], self._conv1d.in_channels, -1) if has_batches - else (self._conv1d.in_channels, -1) + else (1, self._conv1d.in_channels, -1) ) output_shape = (inputs.shape[0], -1) if has_batches else (-1,) + x = inputs.view(*input_shape) x = self._conv1d(x) x = self._batch_norm(x) x = self._arithmetics.quantize(x) - outputs = x.view(*output_shape) - return outputs + + return x.view(*output_shape) def translate(self, name: str) -> FPConv1dDesign: def float_to_signed_int(value: float | list) -> int | list: diff --git a/elasticai/creator/nn/batch_normed_linear/layer.py b/elasticai/creator/nn/batch_normed_linear/layer.py index 3c222b63..56e22cd0 100644 --- a/elasticai/creator/nn/batch_normed_linear/layer.py +++ b/elasticai/creator/nn/batch_normed_linear/layer.py @@ -47,9 +47,16 @@ def __init__( ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: - x = self._linear(inputs) + has_batches = inputs.dim() == 2 + input_shape = inputs.shape if has_batches else (1, -1) + output_shape = (inputs.shape[0], -1) if has_batches else (-1,) + + x = inputs.view(*input_shape) + x = self._linear(x) x = self._batch_norm(x) - return self._arithmetics.quantize(x) + x = self._arithmetics.quantize(x) + + return x.view(*output_shape) def translate(self, name: str) -> FPLinearDesign: def float_to_signed_int(value: float | list) -> int | list: