Skip to content

Commit

Permalink
feat: pass step lut to identity step function and improve readablility
Browse files Browse the repository at this point in the history
  • Loading branch information
julianhoever committed Jun 3, 2023
1 parent d607e98 commit c1b6747
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions elasticai/creator/base_modules/tanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@ def __init__(
) -> None:
super().__init__()
self._arithmetics = arithmetics
self.num_steps = num_steps
self.sampling_intervall = sampling_intervall
self._step_lut = torch.linspace(*sampling_intervall, num_steps)

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
def _quantized_step_inputs(self, inputs: torch.Tensor) -> torch.Tensor:
step_inputs = cast(
torch.Tensor,
IdentityStepFunction.apply(
inputs, *self.sampling_intervall, self.num_steps
),
torch.Tensor, IdentityStepFunction.apply(inputs, self._step_lut)
)
quantized_step_inputs = self._arithmetics.quantize(step_inputs)
outputs = torch.nn.functional.tanh(quantized_step_inputs)
return self._arithmetics.quantize(step_inputs)

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = self._quantized_step_inputs(inputs)
outputs = torch.nn.functional.tanh(inputs)
return self._arithmetics.quantize(outputs)

0 comments on commit c1b6747

Please sign in to comment.