Merge pull request #136 from es-ude/127-implementation-for-fully-connected-layer
Implementation for linear layer
Showing 34 changed files with 1,321 additions and 633 deletions.
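The linear layer added here ultimately has to realize y = W·x + b on FixedPoint values (total_bits/frac_bits, as used throughout the diffs below). As a rough orientation only, and not the repository's VHDL implementation, the arithmetic amounts to an integer multiply-accumulate followed by a shift back into the fixed-point format; a minimal Python sketch:

    def linear_1d_fixed_point(x, weight, bias, frac_bits=4):
        # Quantize real values to integers with scale 2**frac_bits.
        scale = 1 << frac_bits

        def quantize(value: float) -> int:
            return round(value * scale)

        x_q = [quantize(v) for v in x]
        w_q = [quantize(v) for v in weight]
        b_q = quantize(bias)
        # The integer products carry 2*frac_bits fractional bits,
        # so shift right by frac_bits before adding the bias.
        acc = sum(xi * wi for xi, wi in zip(x_q, w_q))
        y_q = (acc >> frac_bits) + b_q
        return y_q / scale

    # 0.5*1 + 0.25*2 + 1.0*3 + 1 = 5.0
    print(linear_1d_fixed_point([0.5, 0.25, 1.0], [1.0, 2.0, 3.0], 1.0))  # 5.0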
New file (65 additions): an example script that translates a PyTorch LSTM + linear model to VHDL code.
import argparse
from pathlib import Path

import torch

from elasticai.creator.vhdl.number_representations import FixedPoint
from elasticai.creator.vhdl.translator.abstract.layers import (
    Linear1dTranslationArgs,
    LSTMTranslationArgs,
)
from elasticai.creator.vhdl.translator.pytorch import translator
from elasticai.creator.vhdl.translator.pytorch.build_function_mappings import (
    DEFAULT_BUILD_FUNCTION_MAPPING,
)


def read_commandline_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--build_dir", required=True, type=Path)
    return parser.parse_args()


class LSTMModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=1, hidden_size=10)
        self.linear = torch.nn.Linear(in_features=10, out_features=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.lstm(x)[0])


def main() -> None:
    args = read_commandline_args()

    model = LSTMModel()

    # Per-layer translation arguments, keyed by the name of the translatable type.
    fixed_point_factory = FixedPoint.get_factory(total_bits=8, frac_bits=4)
    work_library_name = "xil_defaultlib"
    translation_args = dict(
        LSTMTranslatable=LSTMTranslationArgs(
            fixed_point_factory=fixed_point_factory,
            sigmoid_resolution=(-2.5, 2.5, 256),
            tanh_resolution=(-1, 1, 256),
            work_library_name=work_library_name,
        ),
        Linear1dTranslatable=Linear1dTranslationArgs(
            fixed_point_factory=fixed_point_factory,
            work_library_name=work_library_name,
        ),
    )

    # Map the PyTorch modules to their translatable representations.
    translatable_layers = translator.translate_model(
        model=model, build_function_mapping=DEFAULT_BUILD_FUNCTION_MAPPING
    )

    # Generate the VHDL code representation and write it to the build directory.
    code_repr = translator.generate_code(
        translatable_layers=translatable_layers, translation_args=translation_args
    )

    translator.save_code(code_repr=code_repr, path=args.build_dir)


if __name__ == "__main__":
    main()
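Leaving the VHDL translation aside, the PyTorch part of the script can be sanity-checked on its own. The shapes below follow standard torch.nn.LSTM/Linear semantics with sequence-first, batched input; nothing in the diff asserts them, and LSTMModel is assumed to be in scope from the script above:

    import torch

    model = LSTMModel()       # the class defined in the example script above
    x = torch.randn(7, 1, 1)  # (seq_len=7, batch=1, input_size=1)
    y = model(x)
    print(y.shape)            # torch.Size([7, 1, 1]): one scalar output per time step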
elasticai/creator/tests/vhdl/components/test_linear_1d_component.py (new file, 28 additions):
import unittest

from elasticai.creator.vhdl.components import Linear1dComponent
from elasticai.creator.vhdl.number_representations import FixedPoint


class Linear1dComponentTest(unittest.TestCase):
    def setUp(self) -> None:
        self.component = Linear1dComponent(
            in_features=20,
            out_features=1,
            fixed_point_factory=FixedPoint.get_factory(total_bits=16, frac_bits=8),
            work_library_name="work",
        )

    def test_derives_correct_data_width(self) -> None:
        self.assertEqual(self.component.data_width, 16)

    def test_calculates_correct_addr_width(self) -> None:
        self.assertEqual(self.component.addr_width, 5)

    def test_out_features_larger_1_raises_not_implemented_error(self) -> None:
        with self.assertRaises(NotImplementedError):
            _ = Linear1dComponent(
                in_features=3,
                out_features=2,
                fixed_point_factory=FixedPoint.get_factory(total_bits=8, frac_bits=4),
            )
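The expected addr_width of 5 is consistent with an address just wide enough to index all 20 input features, i.e. ceil(log2(in_features)); this sizing rule is inferred from the test values rather than quoted from Linear1dComponent:

    import math

    in_features = 20
    print(math.ceil(math.log2(in_features)))  # 5, the asserted addr_width
    # data_width (16) simply mirrors total_bits of the FixedPoint factory.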
elasticai/creator/tests/vhdl/components/test_lstm_component.py (new file, 27 additions):
import unittest

from elasticai.creator.vhdl.components import LSTMComponent
from elasticai.creator.vhdl.number_representations import FixedPoint


class LSTMComponentTest(unittest.TestCase):
    def setUp(self) -> None:
        self.lstm = LSTMComponent(
            input_size=5,
            hidden_size=3,
            fixed_point_factory=FixedPoint.get_factory(total_bits=8, frac_bits=4),
            work_library_name="xil_defaultlib",
        )

    def test_fixed_point_params_correct_derived(self):
        self.assertEqual(8, self.lstm.data_width)
        self.assertEqual(4, self.lstm.frac_width)

    def test_x_h_addr_width_correct_set(self):
        self.assertEqual(3, self.lstm.x_h_addr_width)

    def test_hidden_addr_width_correct_set(self):
        self.assertEqual(3, self.lstm.hidden_addr_width)

    def test_w_addr_width_correct_set(self):
        self.assertEqual(5, self.lstm.w_addr_width)
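Here data_width and frac_width mirror total_bits and frac_bits of the fixed-point factory. The address widths also line up with simple log2 sizings of the underlying memories, although that derivation is only a guess read off the test values, not taken from LSTMComponent itself:

    import math

    input_size, hidden_size = 5, 3
    # Assumed sizing rules (inferred from the asserted values, not from the component's source):
    print(math.ceil(math.log2(input_size + hidden_size)))                  # 3 -> x_h_addr_width
    print(math.ceil(math.log2(hidden_size * (input_size + hidden_size))))  # 5 -> w_addr_width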
elasticai/creator/tests/vhdl/components/test_rom_component.py (new file, 26 additions):
import unittest

from elasticai.creator.vhdl.components import RomComponent
from elasticai.creator.vhdl.number_representations import FixedPoint


class RomComponentTest(unittest.TestCase):
    def setUp(self) -> None:
        fp = FixedPoint.get_factory(total_bits=16, frac_bits=8)
        self.rom = RomComponent(
            rom_name="test_rom",
            values=[fp(i) for i in range(20)],
            resource_option="auto",
        )

    def test_data_width_correct_derived(self) -> None:
        self.assertEqual(self.rom.data_width, 16)

    def test_addr_width_correct_calculated(self) -> None:
        self.assertEqual(self.rom.addr_width, 5)

    def test_correct_number_of_values(self) -> None:
        self.assertEqual(len(self.rom.hex_values), 32)

    def test_values_correct_padded(self) -> None:
        self.assertEqual(['x"0000"'] * 12, self.rom.hex_values[20:])
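The padding assertion follows from rounding the ROM depth up to the next power of two: 20 stored values give addr_width = ceil(log2(20)) = 5, hence 2**5 = 32 entries, of which the trailing 12 are zero-filled 16-bit words. A small sketch of that arithmetic, assuming this rounding rule:

    import math

    num_values = 20
    addr_width = math.ceil(math.log2(num_values))  # 5
    depth = 2 ** addr_width                        # 32 ROM entries in total
    padding = depth - num_values                   # 12 zero words, i.e. ['x"0000"'] * 12
    print(addr_width, depth, padding)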
elasticai/creator/tests/vhdl/translator/abstract/layers/test_linear_1d_translatable.py (new file, 33 additions):
import unittest

from elasticai.creator.vhdl.components import (
    Linear1dComponent,
    LSTMCommonComponent,
    RomComponent,
)
from elasticai.creator.vhdl.number_representations import FixedPoint
from elasticai.creator.vhdl.translator.abstract.layers import (
    Linear1dTranslatable,
    Linear1dTranslationArgs,
)


class Linear1dTranslatableTest(unittest.TestCase):
    def setUp(self) -> None:
        self.linear = Linear1dTranslatable(weight=[[1, 2, 3]], bias=[1])
        self.translation_args = Linear1dTranslationArgs(
            fixed_point_factory=FixedPoint.get_factory(total_bits=8, frac_bits=4)
        )

    def test_contains_all_needed_components(self) -> None:
        vhdl_components = self.linear.translate(self.translation_args)

        target_components = [
            (Linear1dComponent, "linear_1d.vhd"),
            (RomComponent, "w_rom.vhd"),
            (RomComponent, "b_rom.vhd"),
            (LSTMCommonComponent, "lstm_common.vhd"),
        ]
        actual_components = [(type(x), x.file_name) for x in vhdl_components]

        self.assertEqual(actual_components, target_components)
...ai/creator/tests/vhdl/translator/pytorch/build_functions/test_linear_1d_build_function.py (new file, 25 additions):
import unittest

import torch

from elasticai.creator.vhdl.translator.pytorch.build_functions import build_linear_1d


def arange_parameter(
    start: int, end: int, shape: tuple[int, ...]
) -> torch.nn.Parameter:
    return torch.nn.Parameter(
        torch.reshape(torch.arange(start, end, dtype=torch.float32), shape)
    )


class Linear1dBuildFunctionTest(unittest.TestCase):
    def setUp(self) -> None:
        self.linear = torch.nn.Linear(in_features=3, out_features=1)
        self.linear.weight = arange_parameter(start=1, end=4, shape=(1, -1))
        self.linear.bias = arange_parameter(start=1, end=2, shape=(-1,))

    def test_weights_and_bias_correct_set(self) -> None:
        linear1d = build_linear_1d(self.linear)
        self.assertEqual(linear1d.weight, [[1.0, 2.0, 3.0]])
        self.assertEqual(linear1d.bias, [1.0])
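The expected values follow directly from the arange_parameter helper: torch.arange(1, 4) produces [1., 2., 3.], reshaped into a (1, 3) weight matrix, while torch.arange(1, 2) yields the single-element bias. This can be verified with plain PyTorch, independent of build_linear_1d:

    import torch

    weight = torch.reshape(torch.arange(1, 4, dtype=torch.float32), (1, -1))
    bias = torch.reshape(torch.arange(1, 2, dtype=torch.float32), (-1,))
    print(weight.tolist())  # [[1.0, 2.0, 3.0]]
    print(bias.tolist())    # [1.0]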