Skip to content

Commit

Permalink
feat: implement autograd fn to map inputs to a subset of inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
julianhoever committed May 29, 2023
1 parent 24e737e commit 26c6ec7
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any, cast

import torch


class StepFunctionInputs(torch.autograd.Function):
@staticmethod
def jvp(ctx: Any, *grad_inputs: Any) -> Any:
raise NotImplementedError()

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
if len(args) != 4:
raise TypeError(
"apply() takes exactly four arguments "
"(inputs: torch.Tensor, minimum: float, maximum: float, steps: int)"
)
inputs: torch.Tensor = args[0]
minimum, maximum, steps = cast(tuple[float, float, int], args[1:4])

if steps < 2:
raise ValueError(
f"Number of steps cannot be less than or equal to 1 (steps == {steps})."
)

input_lut = torch.linspace(minimum, maximum, steps).flip(dims=[0])
clipped_inputs = inputs.clamp(min=minimum, max=maximum)
outputs = clipped_inputs.detach().apply_(
lambda x: input_lut[int(sum(input_lut >= x) - 1)]
)

return outputs

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
return *grad_outputs, None
30 changes: 30 additions & 0 deletions elasticai/creator/base_modules/tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import cast

import torch

from elasticai.creator.base_modules.arithmetics import Arithmetics
from elasticai.creator.base_modules.autograd_functions.step_function_inputs import (
StepFunctionInputs,
)


class Tanh(torch.nn.Tanh):
def __init__(
self,
arithmetics: Arithmetics,
num_steps: int,
sampling_intervall: tuple[float, float] = (-5, 5),
) -> None:
super().__init__()
self._arithmetics = arithmetics
self.num_steps = num_steps
self.sampling_intervall = sampling_intervall

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
step_inputs = cast(
torch.Tensor,
StepFunctionInputs.apply(inputs, *self.sampling_intervall, self.num_steps),
)
quantized_step_inputs = self._arithmetics.quantize(step_inputs)
outputs = torch.nn.functional.tanh(quantized_step_inputs)
return self._arithmetics.quantize(outputs)
42 changes: 42 additions & 0 deletions tests/base_modules/autograd_functions/test_step_function_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from collections.abc import Iterable
from typing import cast

import pytest
import torch

from elasticai.creator.base_modules.autograd_functions.step_function_inputs import (
StepFunctionInputs,
)
from tests.tensor_test_case import assertTensorEqual


@pytest.mark.parametrize(
"minimum,maximum,steps,inputs,outputs",
[
(-3, 3, 2, range(-4, 5), [-3, -3, 3, 3, 3, 3, 3, 3, 3]),
(-3, 3, 3, range(-4, 5), [-3, -3, 0, 0, 0, 3, 3, 3, 3]),
(-5, -2, 3, range(-6, 0), [-5, -5, -3.5, -2, -2, -2]),
(2, 5, 3, range(1, 7), [2, 2, 3.5, 5, 5, 5]),
],
)
def test_inputs_correctly_mapped_to_step_function_inputs(
minimum: float,
maximum: float,
steps: int,
inputs: Iterable[float],
outputs: Iterable[float],
) -> None:
actual_outputs = cast(
torch.Tensor,
StepFunctionInputs.apply(
torch.tensor(inputs, dtype=torch.float32), minimum, maximum, steps
),
)
assertTensorEqual(list(outputs), actual_outputs)


@pytest.mark.parametrize("steps", [1, 0])
def test_raises_error_when_steps_less_than_or_equal_one(steps: int) -> None:
inputs = torch.tensor([1, 2], dtype=torch.float32)
with pytest.raises(ValueError):
StepFunctionInputs.apply(inputs, -1, 1, steps)
15 changes: 15 additions & 0 deletions tests/base_modules/test_tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
import torch

from elasticai.creator.base_modules.float_arithmetics import FloatArithmetics
from elasticai.creator.base_modules.tanh import Tanh
from tests.tensor_test_case import assertTensorEqual


def test_single_sample() -> None:
tanh = Tanh(arithmetics=FloatArithmetics(), num_steps=2)
inputs = torch.tensor([-10, -3, -2, -1, 0, 1, 2, 3, 10])
assertTensorEqual(
expected=[-1, -1, -1, -1, -1, 1, 1, 1, 1],
actual=tanh(inputs),
)

0 comments on commit 26c6ec7

Please sign in to comment.