Skip to content

Commit

Permalink
feat: use conv1d arithmetics function to implement conv1d module
Browse files Browse the repository at this point in the history
  • Loading branch information
julianhoever committed May 27, 2023
1 parent 1cab190 commit 69778be
Showing 1 changed file with 17 additions and 54 deletions.
71 changes: 17 additions & 54 deletions elasticai/creator/base_modules/conv1d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any

import torch
from torch.nn.functional import pad

from elasticai.creator.base_modules.arithmetics import Arithmetics

Expand All @@ -15,6 +14,8 @@ def __init__(
kernel_size: int | tuple[int],
stride: int | tuple[int] = 1,
padding: int | tuple[int] | str = 0,
dilation: int | tuple[int] = 1,
groups: int = 1,
bias: bool = True,
device: Any = None,
dtype: Any = None,
Expand All @@ -25,8 +26,8 @@ def __init__(
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=1,
groups=1,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode="zeros",
device=device,
Expand All @@ -38,55 +39,17 @@ def __init__(
def _flatten_tuple(x: int | tuple[int, ...]) -> int:
return x[0] if isinstance(x, tuple) else x

def _pad_inputs(self, inputs: torch.Tensor) -> torch.Tensor:
kernel_size = self._flatten_tuple(self.kernel_size)
stride = self._flatten_tuple(self.stride)

if self.padding == "valid":
padding = (0, 0)
elif self.padding == "same":
if stride != 1:
raise ValueError("'same' padding only supports stride of 1.")
padding = (kernel_size // 2 + kernel_size % 2 - 1, kernel_size // 2)
elif isinstance(self.padding, (int, tuple)):
pad_value = self._flatten_tuple(self.padding)
padding = (pad_value, pad_value)
else:
raise ValueError(f"Padding {self.padding} is not supported.")
return pad(inputs, pad=padding, mode="constant", value=0)

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
has_batches = inputs.dim() == 3
if not has_batches:
inputs = inputs.expand(1, -1, -1)

inputs = self._pad_inputs(inputs)

batch_size, _, input_length = inputs.shape
kernel_size = self._flatten_tuple(self.kernel_size)
stride = self._flatten_tuple(self.stride)

output_length = (input_length - kernel_size) // stride + 1
outputs = torch.empty(batch_size, self.out_channels, output_length)

weight = self._arithmetics.quantize(self.weight)
bias = None if self.bias is None else self._arithmetics.quantize(self.bias)

for window_start_idx in range(output_length):
start_idx = window_start_idx * stride
input_slice = inputs[:, :, start_idx : start_idx + kernel_size]

weighted_slice = self._arithmetics.mul(input_slice, weight)
output_single_conv = self._arithmetics.sum(weighted_slice, dim=(1, 2))

if bias is not None:
output_single_conv = self._arithmetics.add(output_single_conv, bias)

outputs[:, :, window_start_idx] = output_single_conv.view(
batch_size, self.out_channels
)

if not has_batches:
outputs = outputs.squeeze(dim=0)

return outputs
return self._arithmetics.conv1d(
inputs=inputs,
weights=self.weight,
bias=self.bias,
stride=self._flatten_tuple(self.stride),
padding=(
self.padding
if isinstance(self.padding, str)
else self._flatten_tuple(self.padding)
),
dilation=self._flatten_tuple(self.dilation),
groups=self.groups,
)

0 comments on commit 69778be

Please sign in to comment.