feat: implement autograd fn to map inputs to a subset of inputs
1 parent 24e737e
commit 26c6ec7
Showing 4 changed files with 123 additions and 0 deletions.
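In short, the new autograd function snaps each input value onto a small, evenly spaced grid between a given minimum and maximum (values outside that interval are clamped first), so downstream layers only ever see this fixed subset of input values. A minimal usage sketch, reusing the grid and expected values from the test file added in this commit:

import torch

from elasticai.creator.base_modules.autograd_functions.step_function_inputs import (
    StepFunctionInputs,
)

# Snap nine inputs onto the three-point grid {-3, 0, 3}: values outside
# [-3, 3] are clamped, everything else is mapped up to the next grid point.
inputs = torch.arange(-4, 5, dtype=torch.float32)
outputs = StepFunctionInputs.apply(inputs, -3.0, 3.0, 3)
print(outputs)  # tensor([-3., -3., 0., 0., 0., 3., 3., 3., 3.])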

36 changes: 36 additions & 0 deletions
elasticai/creator/base_modules/autograd_functions/step_function_inputs.py
@@ -0,0 +1,36 @@
from typing import Any, cast

import torch


class StepFunctionInputs(torch.autograd.Function):
    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        raise NotImplementedError()

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
        if len(args) != 4:
            raise TypeError(
                "apply() takes exactly four arguments "
                "(inputs: torch.Tensor, minimum: float, maximum: float, steps: int)"
            )
        inputs: torch.Tensor = args[0]
        minimum, maximum, steps = cast(tuple[float, float, int], args[1:4])

        if steps < 2:
            raise ValueError(
                f"Number of steps cannot be less than or equal to 1 (steps == {steps})."
            )

        input_lut = torch.linspace(minimum, maximum, steps).flip(dims=[0])
        clipped_inputs = inputs.clamp(min=minimum, max=maximum)
        outputs = clipped_inputs.detach().apply_(
            lambda x: input_lut[int(sum(input_lut >= x) - 1)]
        )

        return outputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        # Straight-through gradient: pass the output gradient back to `inputs`
        # unchanged; the non-tensor arguments (minimum, maximum, steps) get None.
        return *grad_outputs, None, None, None
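The forward mapping is piecewise constant, so its true derivative is zero almost everywhere; the backward pass is therefore meant to hand the incoming gradient straight through to `inputs` (a straight-through estimator), while the three scalar arguments receive no gradient. A small sketch of what that implies, assuming backward returns the output gradient unchanged for `inputs` and None for minimum, maximum, and steps:

import torch

from elasticai.creator.base_modules.autograd_functions.step_function_inputs import (
    StepFunctionInputs,
)

# With a straight-through backward, the gradient of the summed step outputs
# with respect to the raw inputs is simply a tensor of ones.
inputs = torch.arange(-4, 5, dtype=torch.float32, requires_grad=True)
outputs = StepFunctionInputs.apply(inputs, -3.0, 3.0, 3)
outputs.sum().backward()
print(inputs.grad)  # tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])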

30 changes: 30 additions & 0 deletions
elasticai/creator/base_modules/tanh.py
@@ -0,0 +1,30 @@
from typing import cast

import torch

from elasticai.creator.base_modules.arithmetics import Arithmetics
from elasticai.creator.base_modules.autograd_functions.step_function_inputs import (
    StepFunctionInputs,
)


class Tanh(torch.nn.Tanh):
    def __init__(
        self,
        arithmetics: Arithmetics,
        num_steps: int,
        sampling_intervall: tuple[float, float] = (-5, 5),
    ) -> None:
        super().__init__()
        self._arithmetics = arithmetics
        self.num_steps = num_steps
        self.sampling_intervall = sampling_intervall

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        step_inputs = cast(
            torch.Tensor,
            StepFunctionInputs.apply(inputs, *self.sampling_intervall, self.num_steps),
        )
        quantized_step_inputs = self._arithmetics.quantize(step_inputs)
        outputs = torch.nn.functional.tanh(quantized_step_inputs)
        return self._arithmetics.quantize(outputs)
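The Tanh module chains the step mapping, quantization, and the tanh nonlinearity, so the nonlinearity only ever sees at most num_steps distinct values. A short sketch, using the same FloatArithmetics as the tanh test added in this commit and a five-point grid over the default interval (-5, 5):

import torch

from elasticai.creator.base_modules.float_arithmetics import FloatArithmetics
from elasticai.creator.base_modules.tanh import Tanh

# Inputs are first snapped to the 5-point grid over [-5, 5] before tanh is
# applied, so (assuming quantize is deterministic and element-wise) the
# activation produces at most five distinct output values.
tanh = Tanh(arithmetics=FloatArithmetics(), num_steps=5)
x = torch.linspace(-6.0, 6.0, steps=25)
y = tanh(x)
print(torch.unique(y))  # at most five distinct activation levels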

42 changes: 42 additions & 0 deletions
tests/base_modules/autograd_functions/test_step_function_inputs.py
@@ -0,0 +1,42 @@
from collections.abc import Iterable
from typing import cast

import pytest
import torch

from elasticai.creator.base_modules.autograd_functions.step_function_inputs import (
    StepFunctionInputs,
)
from tests.tensor_test_case import assertTensorEqual


@pytest.mark.parametrize(
    "minimum,maximum,steps,inputs,outputs",
    [
        (-3, 3, 2, range(-4, 5), [-3, -3, 3, 3, 3, 3, 3, 3, 3]),
        (-3, 3, 3, range(-4, 5), [-3, -3, 0, 0, 0, 3, 3, 3, 3]),
        (-5, -2, 3, range(-6, 0), [-5, -5, -3.5, -2, -2, -2]),
        (2, 5, 3, range(1, 7), [2, 2, 3.5, 5, 5, 5]),
    ],
)
def test_inputs_correctly_mapped_to_step_function_inputs(
    minimum: float,
    maximum: float,
    steps: int,
    inputs: Iterable[float],
    outputs: Iterable[float],
) -> None:
    actual_outputs = cast(
        torch.Tensor,
        StepFunctionInputs.apply(
            torch.tensor(inputs, dtype=torch.float32), minimum, maximum, steps
        ),
    )
    assertTensorEqual(list(outputs), actual_outputs)


@pytest.mark.parametrize("steps", [1, 0])
def test_raises_error_when_steps_less_than_or_equal_one(steps: int) -> None:
    inputs = torch.tensor([1, 2], dtype=torch.float32)
    with pytest.raises(ValueError):
        StepFunctionInputs.apply(inputs, -1, 1, steps)

15 changes: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import pytest
import torch

from elasticai.creator.base_modules.float_arithmetics import FloatArithmetics
from elasticai.creator.base_modules.tanh import Tanh
from tests.tensor_test_case import assertTensorEqual


def test_single_sample() -> None:
    tanh = Tanh(arithmetics=FloatArithmetics(), num_steps=2)
    inputs = torch.tensor([-10, -3, -2, -1, 0, 1, 2, 3, 10])
    assertTensorEqual(
        expected=[-1, -1, -1, -1, -1, 1, 1, 1, 1],
        actual=tanh(inputs),
    )