-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #258 from es-ude/252-implemention-for-relu-activat…
…ion-function Implemention for ReLU activation function
- Loading branch information
Showing
6 changed files
with
149 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import torch | ||
|
||
|
||
class ReLU(torch.nn.ReLU): | ||
def __init__(self) -> None: | ||
super().__init__() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from elasticai.creator.hdl.code_generation.template import ( | ||
InProjectTemplate, | ||
module_to_package, | ||
) | ||
from elasticai.creator.hdl.design_base.design import Design, Port | ||
from elasticai.creator.hdl.design_base.ports import create_port_for_base_design | ||
from elasticai.creator.hdl.savable import Path | ||
|
||
|
||
class FPReLU(Design): | ||
def __init__(self, name: str, data_width: int, use_clock: bool) -> None: | ||
super().__init__(name) | ||
self._data_width = data_width | ||
self._clock_option = "true" if use_clock else "false" | ||
|
||
@property | ||
def port(self) -> Port: | ||
return create_port_for_base_design( | ||
x_width=self._data_width, y_width=self._data_width | ||
) | ||
|
||
def save_to(self, destination: Path) -> None: | ||
template = InProjectTemplate( | ||
package=module_to_package(self.__module__), | ||
file_name="fp_relu.tpl.vhd", | ||
parameters=dict( | ||
layer_name=self.name, | ||
data_width=str(self._data_width), | ||
clock_option=self._clock_option, | ||
), | ||
) | ||
destination.create_subpath(self.name).as_file(".vhd").write(template) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from elasticai.creator.base_modules.relu import ReLU | ||
from elasticai.creator.hdl.design_base.design import Design | ||
from elasticai.creator.hdl.translatable import Translatable | ||
|
||
from .design import FPReLU as FPReLUDesign | ||
|
||
|
||
class FPReLU(Translatable, ReLU): | ||
def __init__(self, total_bits: int, use_clock: bool = False) -> None: | ||
super().__init__() | ||
self._total_bits = total_bits | ||
self._use_clock = use_clock | ||
|
||
def translate(self, name: str) -> Design: | ||
return FPReLUDesign( | ||
name=name, | ||
data_width=self._total_bits, | ||
use_clock=self._use_clock, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import cast | ||
|
||
from elasticai.creator.in_memory_path import InMemoryFile, InMemoryPath | ||
from elasticai.creator.nn.vhdl.relu.layer import FPReLU | ||
|
||
|
||
def test_vhdl_code_matches_expected() -> None: | ||
expected = """-- This is the ReLU implementation for fixed-point data | ||
-- it only checks the highest bit of the input data | ||
-- when the CLOCK_OPTION is enabled, please notice the data only updates until the clock arises. | ||
-- Version: 1.0 | ||
-- Created by: Chao | ||
-- Last modified date: 2022.11.06 | ||
library ieee; | ||
use ieee.std_logic_1164.all; | ||
use ieee.numeric_std.all; | ||
entity relu is | ||
generic ( | ||
DATA_WIDTH : integer := 16; | ||
CLOCK_OPTION : boolean := true | ||
); | ||
port ( | ||
enable : in std_logic; | ||
clock : in std_logic; | ||
x : in std_logic_vector(DATA_WIDTH-1 downto 0); | ||
y : out std_logic_vector(DATA_WIDTH-1 downto 0) | ||
); | ||
end entity relu; | ||
architecture rtl of relu is | ||
signal fp_input : signed(DATA_WIDTH-1 downto 0) := (others=>'0'); | ||
signal fp_output : signed(DATA_WIDTH-1 downto 0) := (others=>'0'); | ||
begin | ||
fp_input <= signed(x); | ||
y <= std_logic_vector(fp_output); | ||
clocked: if CLOCK_OPTION generate | ||
main_process : process (enable, clock) | ||
begin | ||
if (enable = '0') then | ||
fp_output <= to_signed(0, DATA_WIDTH); | ||
elsif (rising_edge(clock)) then | ||
if fp_input < 0 then | ||
fp_output <= to_signed(0, DATA_WIDTH); | ||
else | ||
fp_output <= fp_input; | ||
end if; | ||
end if; | ||
end process; | ||
end generate; | ||
async: if (not CLOCK_OPTION) generate | ||
process (enable, fp_input) | ||
begin | ||
if enable = '0' then | ||
fp_output <= to_signed(0, DATA_WIDTH); | ||
else | ||
if fp_input < 0 then | ||
fp_output <= to_signed(0, DATA_WIDTH); | ||
else | ||
fp_output <= fp_input; | ||
end if; | ||
end if; | ||
end process; | ||
end generate; | ||
end architecture rtl; | ||
""".splitlines() | ||
relu = FPReLU(total_bits=16, use_clock=True) | ||
build_path = InMemoryPath("build", parent=None) | ||
design = relu.translate("relu") | ||
design.save_to(build_path) | ||
actual = cast(InMemoryFile, build_path["relu"]).text | ||
assert actual == expected |