-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(vhdl): implement the translation of a linear1d layer
- Loading branch information
1 parent
fc1f20e
commit b627e78
Showing
10 changed files
with
290 additions
and
2 deletions.
There are no files selected for viewing
33 changes: 33 additions & 0 deletions
33
elasticai/creator/tests/vhdl/translator/abstract/layers/test_linear_1d_translatable.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import unittest | ||
|
||
from elasticai.creator.vhdl.components import ( | ||
Linear1dComponent, | ||
LSTMCommonComponent, | ||
RomComponent, | ||
) | ||
from elasticai.creator.vhdl.number_representations import FixedPoint | ||
from elasticai.creator.vhdl.translator.abstract.layers import ( | ||
Linear1dTranslatable, | ||
Linear1dTranslationArgs, | ||
) | ||
|
||
|
||
class Linear1dTranslatableTest(unittest.TestCase): | ||
def setUp(self) -> None: | ||
self.linear = Linear1dTranslatable(weight=[[1, 2, 3]], bias=[1]) | ||
self.translation_args = Linear1dTranslationArgs( | ||
fixed_point_factory=FixedPoint.get_factory(total_bits=8, frac_bits=4) | ||
) | ||
|
||
def test_contains_all_needed_components(self) -> None: | ||
vhdl_components = self.linear.translate(self.translation_args) | ||
|
||
target_components = [ | ||
(Linear1dComponent, "linear_1d.vhd"), | ||
(RomComponent, "w_rom.vhd"), | ||
(RomComponent, "b_rom.vhd"), | ||
(LSTMCommonComponent, "lstm_common.vhd"), | ||
] | ||
actual_components = [(type(x), x.file_name) for x in vhdl_components] | ||
|
||
self.assertEqual(actual_components, target_components) |
25 changes: 25 additions & 0 deletions
25
...ai/creator/tests/vhdl/translator/pytorch/build_functions/test_linear_1d_build_function.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import unittest | ||
|
||
import torch | ||
|
||
from elasticai.creator.vhdl.translator.pytorch.build_functions import build_linear_1d | ||
|
||
|
||
def arange_parameter( | ||
start: int, end: int, shape: tuple[int, ...] | ||
) -> torch.nn.Parameter: | ||
return torch.nn.Parameter( | ||
torch.reshape(torch.arange(start, end, dtype=torch.float32), shape) | ||
) | ||
|
||
|
||
class Linear1dBuildFunctionTest(unittest.TestCase): | ||
def setUp(self) -> None: | ||
self.linear = torch.nn.Linear(in_features=3, out_features=1) | ||
self.linear.weight = arange_parameter(start=1, end=4, shape=(1, -1)) | ||
self.linear.bias = arange_parameter(start=1, end=2, shape=(-1,)) | ||
|
||
def test_weights_and_bias_correct_set(self) -> None: | ||
linear1d = build_linear_1d(self.linear) | ||
self.assertEqual(linear1d.weight, [[1.0, 2.0, 3.0]]) | ||
self.assertEqual(linear1d.bias, [1.0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import math | ||
from typing import Callable | ||
|
||
from elasticai.creator.resource_utils import read_text | ||
from elasticai.creator.vhdl.language import Code | ||
from elasticai.creator.vhdl.number_representations import FixedPoint | ||
|
||
|
||
class Linear1dComponent: | ||
def __init__( | ||
self, | ||
in_features: int, | ||
out_features: int, | ||
fixed_point_factory: Callable[[float], FixedPoint], | ||
) -> None: | ||
|
||
if out_features != 1: | ||
raise NotImplementedError( | ||
"Currently only one bias is supported (which implies that out_features must be 1)." | ||
) | ||
|
||
self.in_features = in_features | ||
self.out_features = out_features | ||
self.data_width = self._derive_data_width(fixed_point_factory) | ||
self.addr_width = self._calculate_addr_width(in_features * out_features) | ||
|
||
@staticmethod | ||
def _derive_data_width(fixed_point_factory: Callable[[float], FixedPoint]) -> int: | ||
return fixed_point_factory(0).total_bits | ||
|
||
@staticmethod | ||
def _calculate_addr_width(num_items: int) -> int: | ||
return max(1, math.ceil(math.log2(num_items))) | ||
|
||
@property | ||
def file_name(self) -> str: | ||
return "linear_1d.vhd" | ||
|
||
def __call__(self) -> Code: | ||
template = read_text("elasticai.creator.vhdl.templates", "linear_1d.tpl.vhd") | ||
|
||
code = template.format( | ||
addr_width=self.addr_width, | ||
data_width=self.out_features, | ||
in_feature_count=self.in_features, | ||
out_feature_count=self.out_features, | ||
) | ||
|
||
yield from code.splitlines() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
LIBRARY ieee; | ||
USE ieee.std_logic_1164.all; | ||
use ieee.numeric_std.all; | ||
|
||
library work; | ||
use work.lstm_common.all; | ||
|
||
entity linear_1d is | ||
generic ( | ||
ADDR_WIDTH :integer := {addr_width}; | ||
DATA_WIDTH :integer := {data_width}; | ||
IN_FEATURE_COUNT :integer := {in_feature_count}; | ||
OUT_FEATURE_COUNT :integer := {out_feature_count} | ||
); | ||
port ( | ||
clock : in std_logic; | ||
addr : out std_logic_vector(ADDR_WIDTH-1 downto 0); | ||
enable : in std_logic; | ||
x_in : in std_logic_vector(DATA_WIDTH-1 downto 0); | ||
y_out : out std_logic_vector(DATA_WIDTH-1 downto 0); | ||
done : out std_logic | ||
) ; | ||
end linear_1d ; | ||
|
||
architecture rtl of linear_1d is | ||
|
||
signal addr_s : std_logic_vector(ADDR_WIDTH-1 downto 0) ; | ||
signal test_mul : signed(2*DATA_WIDTH-1 downto 0) ; | ||
signal test_sum : signed(2*DATA_WIDTH-1 downto 0) ; | ||
signal w_in, b_in : signed(DATA_WIDTH-1 downto 0) ; | ||
signal std_w_in, std_b_in : std_logic_vector(DATA_WIDTH-1 downto 0) ; | ||
signal n_clock : std_logic ; | ||
begin | ||
|
||
n_clock <= not clock ; | ||
|
||
process(clock) | ||
variable var_addr : integer range 0 to 2**ADDR_WIDTH-1:= 0; | ||
variable fsm : integer:=0; | ||
variable temp_mul : signed(2*DATA_WIDTH-1 downto 0); | ||
variable sum : signed(2*DATA_WIDTH-1 downto 0):=(others=>'0'); | ||
variable temp_x : signed(DATA_WIDTH-1 downto 0); | ||
variable temp_w : signed(DATA_WIDTH-1 downto 0); | ||
variable prefetc_flag:std_logic; | ||
begin | ||
if rising_edge(clock) then | ||
if enable = '0' then | ||
var_addr := 0; | ||
fsm := 0; | ||
sum := (others=>'0'); | ||
temp_mul := (others=>'0'); | ||
done <= '0'; | ||
else | ||
|
||
if fsm=0 then | ||
fsm := 1; | ||
elsif fsm =1 then | ||
if prefetc_flag='0' then | ||
prefetc_flag := '1'; | ||
temp_x := signed(x_in); | ||
temp_w := signed(w_in); | ||
temp_mul := multiply_16_8_without_cut(temp_x,temp_w); | ||
else | ||
sum := sum + temp_mul; | ||
var_addr := var_addr + 1; | ||
if var_addr=IN_FEATURE_COUNT then | ||
fsm := 2; | ||
var_addr := 0; | ||
end if; | ||
prefetc_flag := '0'; | ||
end if; | ||
elsif fsm =2 then | ||
done <= '1'; | ||
y_out <= std_logic_vector(cut_16_to_8(test_sum)+signed(b_in)); | ||
end if; | ||
end if; | ||
addr_s <= std_logic_vector(to_unsigned(var_addr, ADDR_WIDTH)); | ||
|
||
test_mul <= temp_mul; | ||
test_sum <= sum; | ||
end if; | ||
end process; | ||
|
||
addr <= addr_s; | ||
|
||
-- Weights | ||
rom_w : entity work.w_rom(rtl) | ||
port map ( | ||
clk => n_clock, | ||
en => '1', | ||
addr => addr_s, | ||
data => std_w_in | ||
); | ||
w_in <= signed(std_w_in); | ||
|
||
-- Bias | ||
rom_b : entity work.b_rom(rtl) | ||
port map ( | ||
clk => n_clock, | ||
en => '1', | ||
addr => (others=>'0'), -- ToDo: at the moment only one bias is supported | ||
data => std_b_in | ||
); | ||
b_in <= signed(std_b_in); | ||
|
||
end architecture ; -- rtl |
4 changes: 4 additions & 0 deletions
4
elasticai/creator/vhdl/translator/abstract/layers/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
49 changes: 49 additions & 0 deletions
49
elasticai/creator/vhdl/translator/abstract/layers/linear_1d_translatable.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from collections.abc import Iterable | ||
from dataclasses import dataclass | ||
from itertools import chain | ||
from typing import Callable | ||
|
||
import numpy as np | ||
|
||
from elasticai.creator.vhdl.components import ( | ||
Linear1dComponent, | ||
LSTMCommonComponent, | ||
RomComponent, | ||
) | ||
from elasticai.creator.vhdl.number_representations import FixedPoint | ||
from elasticai.creator.vhdl.vhdl_component import VHDLModule | ||
|
||
|
||
@dataclass | ||
class Linear1dTranslationArgs: | ||
fixed_point_factory: Callable[[float], FixedPoint] | ||
|
||
|
||
@dataclass | ||
class Linear1dTranslatable: | ||
weight: list[list[float]] | ||
bias: list[float] | ||
|
||
def translate(self, args: Linear1dTranslationArgs) -> VHDLModule: | ||
def to_fp(values: Iterable[float]) -> list[FixedPoint]: | ||
return list(map(args.fixed_point_factory, values)) | ||
|
||
out_features, in_features = np.shape(self.weight) | ||
|
||
yield Linear1dComponent( | ||
in_features=in_features, | ||
out_features=out_features, | ||
fixed_point_factory=args.fixed_point_factory, | ||
) | ||
|
||
flat_weight = chain(*self.weight) | ||
|
||
yield RomComponent( | ||
rom_name="w_rom", values=to_fp(flat_weight), resource_option="auto" | ||
) | ||
|
||
yield RomComponent( | ||
rom_name="b_rom", values=to_fp(self.bias), resource_option="auto" | ||
) | ||
|
||
yield LSTMCommonComponent() |
10 changes: 8 additions & 2 deletions
10
elasticai/creator/vhdl/translator/pytorch/build_function_mappings.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,14 @@ | ||
from elasticai.creator.vhdl.translator.build_function_mapping import ( | ||
BuildFunctionMapping, | ||
) | ||
from elasticai.creator.vhdl.translator.pytorch.build_functions import build_lstm | ||
from elasticai.creator.vhdl.translator.pytorch.build_functions import ( | ||
build_linear_1d, | ||
build_lstm, | ||
) | ||
|
||
DEFAULT_BUILD_FUNCTION_MAPPING = BuildFunctionMapping( | ||
mapping={"torch.nn.modules.rnn.LSTM": build_lstm} | ||
mapping={ | ||
"torch.nn.modules.rnn.LSTM": build_lstm, | ||
"torch.nn.modules.linear.Linear": build_linear_1d, | ||
} | ||
) |
3 changes: 3 additions & 0 deletions
3
elasticai/creator/vhdl/translator/pytorch/build_functions/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
from elasticai.creator.vhdl.translator.pytorch.build_functions.linear_1d_build_function import ( | ||
build_linear_1d, | ||
) | ||
from elasticai.creator.vhdl.translator.pytorch.build_functions.lstm_build_function import ( | ||
build_lstm, | ||
) |
12 changes: 12 additions & 0 deletions
12
elasticai/creator/vhdl/translator/pytorch/build_functions/linear_1d_build_function.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import torch | ||
|
||
from elasticai.creator.vhdl.translator.abstract.layers import Linear1dTranslatable | ||
|
||
|
||
def build_linear_1d(linear: torch.nn.Linear) -> Linear1dTranslatable: | ||
def to_list(tensor: torch.Tensor) -> list: | ||
return tensor.detach().numpy().tolist() | ||
|
||
return Linear1dTranslatable( | ||
weight=to_list(linear.weight), bias=to_list(linear.bias) | ||
) |