diff --git a/elasticai/creator/base_modules/autograd_functions/step_function_inputs.py b/elasticai/creator/base_modules/autograd_functions/step_function_inputs.py new file mode 100644 index 00000000..272bd877 --- /dev/null +++ b/elasticai/creator/base_modules/autograd_functions/step_function_inputs.py @@ -0,0 +1,36 @@ +from typing import Any, cast + +import torch + + +class StepFunctionInputs(torch.autograd.Function): + @staticmethod + def jvp(ctx: Any, *grad_inputs: Any) -> Any: + raise NotImplementedError() + + @staticmethod + def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: + if len(args) != 4: + raise TypeError( + "apply() takes exactly four arguments " + "(inputs: torch.Tensor, minimum: float, maximum: float, steps: int)" + ) + inputs: torch.Tensor = args[0] + minimum, maximum, steps = cast(tuple[float, float, int], args[1:4]) + + if steps < 2: + raise ValueError( + f"Number of steps cannot be less than or equal to 1 (steps == {steps})." + ) + + input_lut = torch.linspace(minimum, maximum, steps).flip(dims=[0]) + clipped_inputs = inputs.clamp(min=minimum, max=maximum) + outputs = clipped_inputs.detach().apply_( + lambda x: input_lut[int(sum(input_lut >= x) - 1)] + ) + + return outputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + return *grad_outputs, None diff --git a/elasticai/creator/base_modules/tanh.py b/elasticai/creator/base_modules/tanh.py new file mode 100644 index 00000000..65f1488d --- /dev/null +++ b/elasticai/creator/base_modules/tanh.py @@ -0,0 +1,30 @@ +from typing import cast + +import torch + +from elasticai.creator.base_modules.arithmetics import Arithmetics +from elasticai.creator.base_modules.autograd_functions.step_function_inputs import ( + StepFunctionInputs, +) + + +class Tanh(torch.nn.Tanh): + def __init__( + self, + arithmetics: Arithmetics, + num_steps: int, + sampling_intervall: tuple[float, float] = (-5, 5), + ) -> None: + super().__init__() + self._arithmetics = arithmetics + self.num_steps = num_steps + self.sampling_intervall = sampling_intervall + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + step_inputs = cast( + torch.Tensor, + StepFunctionInputs.apply(inputs, *self.sampling_intervall, self.num_steps), + ) + quantized_step_inputs = self._arithmetics.quantize(step_inputs) + outputs = torch.nn.functional.tanh(quantized_step_inputs) + return self._arithmetics.quantize(outputs) diff --git a/tests/base_modules/autograd_functions/test_step_function_inputs.py b/tests/base_modules/autograd_functions/test_step_function_inputs.py new file mode 100644 index 00000000..45ef919b --- /dev/null +++ b/tests/base_modules/autograd_functions/test_step_function_inputs.py @@ -0,0 +1,42 @@ +from collections.abc import Iterable +from typing import cast + +import pytest +import torch + +from elasticai.creator.base_modules.autograd_functions.step_function_inputs import ( + StepFunctionInputs, +) +from tests.tensor_test_case import assertTensorEqual + + +@pytest.mark.parametrize( + "minimum,maximum,steps,inputs,outputs", + [ + (-3, 3, 2, range(-4, 5), [-3, -3, 3, 3, 3, 3, 3, 3, 3]), + (-3, 3, 3, range(-4, 5), [-3, -3, 0, 0, 0, 3, 3, 3, 3]), + (-5, -2, 3, range(-6, 0), [-5, -5, -3.5, -2, -2, -2]), + (2, 5, 3, range(1, 7), [2, 2, 3.5, 5, 5, 5]), + ], +) +def test_inputs_correctly_mapped_to_step_function_inputs( + minimum: float, + maximum: float, + steps: int, + inputs: Iterable[float], + outputs: Iterable[float], +) -> None: + actual_outputs = cast( + torch.Tensor, + StepFunctionInputs.apply( + torch.tensor(inputs, dtype=torch.float32), minimum, maximum, steps + ), + ) + assertTensorEqual(list(outputs), actual_outputs) + + +@pytest.mark.parametrize("steps", [1, 0]) +def test_raises_error_when_steps_less_than_or_equal_one(steps: int) -> None: + inputs = torch.tensor([1, 2], dtype=torch.float32) + with pytest.raises(ValueError): + StepFunctionInputs.apply(inputs, -1, 1, steps) diff --git a/tests/base_modules/test_tanh.py b/tests/base_modules/test_tanh.py new file mode 100644 index 00000000..4aa5d932 --- /dev/null +++ b/tests/base_modules/test_tanh.py @@ -0,0 +1,15 @@ +import pytest +import torch + +from elasticai.creator.base_modules.float_arithmetics import FloatArithmetics +from elasticai.creator.base_modules.tanh import Tanh +from tests.tensor_test_case import assertTensorEqual + + +def test_single_sample() -> None: + tanh = Tanh(arithmetics=FloatArithmetics(), num_steps=2) + inputs = torch.tensor([-10, -3, -2, -1, 0, 1, 2, 3, 10]) + assertTensorEqual( + expected=[-1, -1, -1, -1, -1, 1, 1, 1, 1], + actual=tanh(inputs), + )