feat(translation): add support for single buffered module to sequential
BREAKING CHANGE
Showing 17 changed files with 231 additions and 95 deletions.
One file was renamed without changes.
@@ -0,0 +1,47 @@
import torch

from elasticai.creator.nn._two_complement_fixed_point_config import (
    TwoComplementFixedPointConfig,
)
from elasticai.creator.nn.arithmetics import Arithmetics
from elasticai.creator.nn.autograd_functions.fixed_point_quantization import (
    FixedPointDequantFunction,
    FixedPointQuantFunction,
)


class FixedPointArithmetics(Arithmetics):
    def __init__(self, config: TwoComplementFixedPointConfig) -> None:
        self.config = config

    def quantize(self, a: torch.Tensor) -> torch.Tensor:
        return self.round(self.clamp(a))

    def clamp(self, a: torch.Tensor) -> torch.Tensor:
        return torch.clamp(
            a, min=self.config.minimum_as_rational, max=self.config.maximum_as_rational
        )

    def round(self, a: torch.Tensor) -> torch.Tensor:
        def float_to_int(x: torch.Tensor) -> torch.Tensor:
            return FixedPointQuantFunction.apply(x, self.config)

        def int_to_fixed_point(x: torch.Tensor) -> torch.Tensor:
            return FixedPointDequantFunction.apply(x, self.config)

        return int_to_fixed_point(float_to_int(a))

    def add(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return self.clamp(a + b)

    def sum(self, tensor: torch.Tensor, *tensors: torch.Tensor) -> torch.Tensor:
        summed = tensor
        for t in tensors:
            summed += t
        return self.clamp(summed)

    def mul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return self.round(self.clamp(a * b))

    def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return self.round(self.clamp(torch.matmul(a, b)))
@@ -0,0 +1,29 @@
import torch

from elasticai.creator.nn.arithmetics import Arithmetics


class FloatArithmetics(Arithmetics):
    def quantize(self, a: torch.Tensor) -> torch.Tensor:
        return a

    def clamp(self, a: torch.Tensor) -> torch.Tensor:
        return a

    def round(self, a: torch.Tensor) -> torch.Tensor:
        return a

    def add(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a + b

    def sum(self, tensor: torch.Tensor, *tensors: torch.Tensor) -> torch.Tensor:
        summed = tensor
        for t in tensors:
            summed += t
        return summed

    def mul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a * b

    def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.matmul(a, b)
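For contrast, the float variant above makes quantize, clamp, and round identity operations, so code written against the Arithmetics interface can run in plain floating point during development and later be switched to FixedPointArithmetics without changes. A short sketch:

```python
import torch

ops = FloatArithmetics()

x = torch.tensor([0.3, 1.7, -9.0])
assert torch.equal(ops.quantize(x), x)  # identity: no clamping or rounding in float mode
print(ops.matmul(x.unsqueeze(0), x.unsqueeze(1)))  # delegates directly to torch.matmul
```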
elasticai/creator/translatable_modules/vhdl/fp_linear_1d.py (13 additions, 0 deletions)
@@ -0,0 +1,13 @@
from elasticai.creator.hdl.translatable import Saveable
from elasticai.creator.hdl.vhdl.designs.fp_linear_1d import FPLinear1d as FPLinearDesign
from elasticai.creator.nn.linear import FixedPointLinear


class FPLinear1d(FixedPointLinear):
    def translate(self) -> Saveable:
        return FPLinearDesign(
            frac_bits=self.frac_bits,
            total_bits=self.total_bits,
            in_feature_num=self.in_features,
            out_feature_num=self.out_features,
        )
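A hypothetical usage sketch for the new wrapper. The FixedPointLinear constructor arguments used here (in_features, out_features, total_bits, frac_bits) are assumed for illustration and are not shown in this commit; only the translate method is.

```python
# Assumed constructor signature; the real FixedPointLinear API may differ.
layer = FPLinear1d(in_features=4, out_features=2, total_bits=8, frac_bits=4)

# translate() builds the corresponding FPLinear1d VHDL design from the layer's
# hyperparameters and returns it as a Saveable for later persistence.
design = layer.translate()
```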