
fix: remove unnecessary output quantization of the SiLU base module
julianhoever committed Aug 29, 2023
1 parent a5aecb3 commit 350faa5
Showing 6 changed files with 7 additions and 36 deletions.
15 changes: 2 additions & 13 deletions elasticai/creator/base_modules/silu_with_trainable_scale_beta.py
@@ -1,22 +1,11 @@
-from typing import Protocol
-
 import torch
 
-from elasticai.creator.base_modules.math_operations import Quantize
-
-
-class MathOperations(Quantize, Protocol):
-    ...
-
 
 class SiLUWithTrainableScaleBeta(torch.nn.SiLU):
-    def __init__(self, operations: MathOperations) -> None:
+    def __init__(self) -> None:
         super().__init__(inplace=False)
-        self._operations = operations
         self.scale = torch.nn.Parameter(torch.ones(1, requires_grad=True))
         self.beta = torch.nn.Parameter(torch.zeros(1, requires_grad=True))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        a = self.scale * super().forward(input) + self.beta
-        x = self._operations.quantize(a)
-        return x
+        return self.scale * super().forward(input) + self.beta
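Taken together, the hunk above leaves the base module as a plain full-precision torch.nn.SiLU with trainable affine parameters. A minimal usage sketch of the post-change class, reassembled from the diff (the random input is illustrative only):

import torch

# The SiLU base module as it reads after this commit: no MathOperations
# dependency and no output quantization in forward().
class SiLUWithTrainableScaleBeta(torch.nn.SiLU):
    def __init__(self) -> None:
        super().__init__(inplace=False)
        self.scale = torch.nn.Parameter(torch.ones(1, requires_grad=True))
        self.beta = torch.nn.Parameter(torch.zeros(1, requires_grad=True))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.scale * super().forward(input) + self.beta

x = torch.randn(4)
y = SiLUWithTrainableScaleBeta()(x)  # full-precision: scale * silu(x) + beta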
3 changes: 1 addition & 2 deletions elasticai/creator/base_modules/torch_math_operations.py
@@ -3,10 +3,9 @@
 from .conv1d import MathOperations as Conv1dOps
 from .linear import MathOperations as LinearOps
 from .lstm_cell import MathOperations as LSTMOps
-from .silu_with_trainable_scale_beta import MathOperations as SiluOps
 
 
-class TorchMathOperations(LinearOps, Conv1dOps, LSTMOps, SiluOps):
+class TorchMathOperations(LinearOps, Conv1dOps, LSTMOps):
     def quantize(self, a: torch.Tensor) -> torch.Tensor:
         return a

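Dropping SiluOps leaves the plain torch backend unchanged in behavior: its quantize is the identity, so nothing was being quantized there anyway. A quick check, assuming the import path shown above and that the class takes no constructor arguments:

import torch
from elasticai.creator.base_modules.torch_math_operations import TorchMathOperations

ops = TorchMathOperations()             # assumption: default constructor
x = torch.randn(3)
assert torch.equal(ops.quantize(x), x)  # identity, per the hunk above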
5 changes: 1 addition & 4 deletions elasticai/creator/nn/binary/_math_operations.py
@@ -5,14 +5,11 @@
 from elasticai.creator.base_modules.conv1d import MathOperations as Conv1dOps
 from elasticai.creator.base_modules.linear import MathOperations as LinearOps
 from elasticai.creator.base_modules.lstm_cell import MathOperations as LSTMOps
-from elasticai.creator.base_modules.silu_with_trainable_scale_beta import (
-    MathOperations as SiLUOps,
-)
 
 from ._binary_quantization_function import Binarize
 
 
-class MathOperations(LinearOps, Conv1dOps, LSTMOps, SiLUOps):
+class MathOperations(LinearOps, Conv1dOps, LSTMOps):
     def quantize(self, a: torch.Tensor) -> torch.Tensor:
         return cast(torch.Tensor, Binarize.apply(a))

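The binary backend quantizes via Binarize.apply. Its implementation is not part of this diff; a typical straight-through-estimator version of such a function looks like the following sketch (an assumption, not the repo's code):

import torch

# Hedged sketch: sign binarization with a straight-through gradient, the
# usual shape of a Binarize autograd.Function (the repo's actual
# _binary_quantization_function is not shown on this page).
class Binarize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a: torch.Tensor) -> torch.Tensor:
        return torch.where(a >= 0, torch.ones_like(a), -torch.ones_like(a))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output  # straight-through: pass gradients unchanged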
5 changes: 1 addition & 4 deletions elasticai/creator/nn/fixed_point/_math_operations.py
@@ -6,15 +6,12 @@
 from elasticai.creator.base_modules.conv1d import MathOperations as Conv1dOps
 from elasticai.creator.base_modules.linear import MathOperations as LinearOps
 from elasticai.creator.base_modules.lstm_cell import MathOperations as LSTMOps
-from elasticai.creator.base_modules.silu_with_trainable_scale_beta import (
-    MathOperations as SiLUOps,
-)
 
 from ._round_to_fixed_point import RoundToFixedPoint
 from ._two_complement_fixed_point_config import FixedPointConfig
 
 
-class MathOperations(LinearOps, Conv1dOps, LSTMOps, SiLUOps):
+class MathOperations(LinearOps, Conv1dOps, LSTMOps):
     def __init__(self, config: FixedPointConfig) -> None:
         self.config = config

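The fixed-point backend's quantize (below the visible hunk) delegates to RoundToFixedPoint with a FixedPointConfig built from total_bits and frac_bits. As a hedged illustration of what two's-complement fixed-point rounding means numerically (the function name and formula are assumptions, not the repo's code):

import torch

# Hedged sketch of two's-complement fixed-point rounding onto a grid with
# `total_bits` bits, `frac_bits` of them fractional (illustrative only;
# the repo's RoundToFixedPoint is not shown on this page).
def round_to_fixed_point(a: torch.Tensor, total_bits: int, frac_bits: int) -> torch.Tensor:
    step = 2.0 ** -frac_bits                      # resolution of the grid
    min_val = -(2 ** (total_bits - 1)) * step     # most negative representable value
    max_val = (2 ** (total_bits - 1) - 1) * step  # most positive representable value
    return torch.clamp(torch.round(a / step) * step, min_val, max_val)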
10 changes: 1 addition & 9 deletions
@@ -1,10 +1,6 @@
 from elasticai.creator.base_modules.silu_with_trainable_scale_beta import (
     SiLUWithTrainableScaleBeta as SiLUWithTrainableScaleBetaBase,
 )
-from elasticai.creator.nn.fixed_point._math_operations import MathOperations
-from elasticai.creator.nn.fixed_point._two_complement_fixed_point_config import (
-    FixedPointConfig,
-)
 
 from .precomputed_module import PrecomputedModule

@@ -18,11 +14,7 @@ def __init__(
         sampling_intervall: tuple[float, float] = (-10, 10),
     ) -> None:
         super().__init__(
-            base_module=SiLUWithTrainableScaleBetaBase(
-                operations=MathOperations(
-                    config=FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)
-                ),
-            ),
+            base_module=SiLUWithTrainableScaleBetaBase(),
             total_bits=total_bits,
             frac_bits=frac_bits,
             num_steps=num_steps,
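This simplification is what the commit title calls unnecessary output quantization: the PrecomputedModule already receives total_bits and frac_bits directly, so it can sample the float activation over sampling_intervall and quantize the samples itself. A hedged sketch of that sampling step (the function name and defaults are assumptions; PrecomputedModule's real code is not shown here):

import torch

# Hedged sketch: precomputing an activation table over the sampling interval.
# Because the wrapper quantizes the sampled values itself, a second quantize
# inside the base module's forward() was redundant.
def precompute_table(fn, sampling_intervall=(-10.0, 10.0), num_steps=256):
    xs = torch.linspace(*sampling_intervall, num_steps)
    with torch.no_grad():
        ys = fn(xs)
    return xs, ys  # the wrapper would round these onto the fixed-point grid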
5 changes: 1 addition & 4 deletions elasticai/creator/nn/float/_math_operations.py
@@ -5,14 +5,11 @@
 from elasticai.creator.base_modules.conv1d import MathOperations as Conv1dOps
 from elasticai.creator.base_modules.linear import MathOperations as LinearOps
 from elasticai.creator.base_modules.lstm_cell import MathOperations as LSTMOps
-from elasticai.creator.base_modules.silu_with_trainable_scale_beta import (
-    MathOperations as SiLUOps,
-)
 
 from ._round_to_float import RoundToFloat
 
 
-class MathOperations(LinearOps, Conv1dOps, LSTMOps, SiLUOps):
+class MathOperations(LinearOps, Conv1dOps, LSTMOps):
     def __init__(self, mantissa_bits: int, exponent_bits: int) -> None:
         self.mantissa_bits = mantissa_bits
         self.exponent_bits = exponent_bits
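The float backend rounds to a reduced-precision float defined by mantissa_bits and exponent_bits. A hedged sketch of the mantissa part of such a rounding (exponent range clipping omitted; the repo's RoundToFloat is not shown on this page):

import torch

# Hedged sketch: round a tensor to `mantissa_bits` of mantissa precision
# using frexp/ldexp (illustrative; exponent_bits handling is omitted).
def round_mantissa(a: torch.Tensor, mantissa_bits: int) -> torch.Tensor:
    mantissa, exponent = torch.frexp(a)  # a = mantissa * 2**exponent
    scale = 2.0 ** mantissa_bits
    return torch.ldexp(torch.round(mantissa * scale) / scale, exponent)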
