
Commit

Merge pull request #258 from es-ude/252-implemention-for-relu-activation-function

Implementation for ReLU activation function
julianhoever authored Jun 8, 2023
2 parents 1800a76 + 62c1555 commit a55d1c0
Showing 6 changed files with 149 additions and 20 deletions.
6 changes: 6 additions & 0 deletions elasticai/creator/base_modules/relu.py
@@ -0,0 +1,6 @@
import torch


class ReLU(torch.nn.ReLU):
    def __init__(self) -> None:
        super().__init__()
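Since this class is a thin wrapper around torch.nn.ReLU, it behaves exactly like the stock PyTorch activation. A minimal sketch (not part of this commit) of the expected behavior:

import torch

from elasticai.creator.base_modules.relu import ReLU

relu = ReLU()
relu(torch.tensor([-2.0, 0.0, 3.0]))  # -> tensor([0., 0., 3.]): negatives clamp to zero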
The second changed file is a new, empty file.
32 changes: 32 additions & 0 deletions elasticai/creator/nn/vhdl/relu/design.py
@@ -0,0 +1,32 @@
from elasticai.creator.hdl.code_generation.template import (
    InProjectTemplate,
    module_to_package,
)
from elasticai.creator.hdl.design_base.design import Design, Port
from elasticai.creator.hdl.design_base.ports import create_port_for_base_design
from elasticai.creator.hdl.savable import Path


class FPReLU(Design):
    def __init__(self, name: str, data_width: int, use_clock: bool) -> None:
        super().__init__(name)
        self._data_width = data_width
        self._clock_option = "true" if use_clock else "false"

    @property
    def port(self) -> Port:
        return create_port_for_base_design(
            x_width=self._data_width, y_width=self._data_width
        )

    def save_to(self, destination: Path) -> None:
        template = InProjectTemplate(
            package=module_to_package(self.__module__),
            file_name="fp_relu.tpl.vhd",
            parameters=dict(
                layer_name=self.name,
                data_width=str(self._data_width),
                clock_option=self._clock_option,
            ),
        )
        destination.create_subpath(self.name).as_file(".vhd").write(template)
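For context (not part of the diff): a hedged sketch of how this design class might be driven directly, reusing the InMemoryPath helper that the test file below also imports. The name "my_relu" is illustrative.

from elasticai.creator.in_memory_path import InMemoryPath
from elasticai.creator.nn.vhdl.relu.design import FPReLU

design = FPReLU(name="my_relu", data_width=16, use_clock=False)
build_dir = InMemoryPath("build", parent=None)
design.save_to(build_dir)  # renders fp_relu.tpl.vhd as build/my_relu.vhd (in memory)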
36 changes: 16 additions & 20 deletions elasticai/creator/nn/vhdl/relu/fp_relu.tpl.vhd
@@ -9,28 +9,25 @@ library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-entity fp_relu_${layer_name} is
-generic (
-DATA_WIDTH : integer := ${data_width};
-CLOCK_OPTION : boolean := ${clock_option}
-);
-port (
-enable : in std_logic;
-clock : in std_logic;
-input : in std_logic_vector(DATA_WIDTH-1 downto 0);
-output : out std_logic_vector(DATA_WIDTH-1 downto 0)
-);
-end entity fp_relu_${layer_name};
-
-architecture rtl of fp_relu_${layer_name} is
-
+entity ${layer_name} is
+generic (
+DATA_WIDTH : integer := ${data_width};
+CLOCK_OPTION : boolean := ${clock_option}
+);
+port (
+enable : in std_logic;
+clock : in std_logic;
+x : in std_logic_vector(DATA_WIDTH-1 downto 0);
+y : out std_logic_vector(DATA_WIDTH-1 downto 0)
+);
+end entity ${layer_name};
+
+architecture rtl of ${layer_name} is
signal fp_input : signed(DATA_WIDTH-1 downto 0) := (others=>'0');
signal fp_output : signed(DATA_WIDTH-1 downto 0) := (others=>'0');

begin

-fp_input <= signed(input);
-output <= std_logic_vector(fp_output);
+fp_input <= signed(x);
+y <= std_logic_vector(fp_output);

clocked: if CLOCK_OPTION generate
main_process : process (enable, clock)
@@ -62,5 +59,4 @@ begin
end if;
end process;
end generate;
-
end architecture rtl;
19 changes: 19 additions & 0 deletions elasticai/creator/nn/vhdl/relu/layer.py
@@ -0,0 +1,19 @@
from elasticai.creator.base_modules.relu import ReLU
from elasticai.creator.hdl.design_base.design import Design
from elasticai.creator.hdl.translatable import Translatable

from .design import FPReLU as FPReLUDesign


class FPReLU(Translatable, ReLU):
    def __init__(self, total_bits: int, use_clock: bool = False) -> None:
        super().__init__()
        self._total_bits = total_bits
        self._use_clock = use_clock

    def translate(self, name: str) -> Design:
        return FPReLUDesign(
            name=name,
            data_width=self._total_bits,
            use_clock=self._use_clock,
        )
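Because FPReLU mixes Translatable into the torch-based ReLU, the same object can run in a PyTorch forward pass and then be lowered to a hardware design. A short sketch under that assumption (not part of this commit):

import torch

from elasticai.creator.nn.vhdl.relu.layer import FPReLU

layer = FPReLU(total_bits=16, use_clock=False)
layer(torch.tensor([-1.0, 2.0]))   # software path: plain ReLU -> tensor([0., 2.])
design = layer.translate("relu")   # hardware path: an FPReLU design named "relu"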
76 changes: 76 additions & 0 deletions tests/nn/test_relu.py
@@ -0,0 +1,76 @@
from typing import cast

from elasticai.creator.in_memory_path import InMemoryFile, InMemoryPath
from elasticai.creator.nn.vhdl.relu.layer import FPReLU


def test_vhdl_code_matches_expected() -> None:
    expected = """-- This is the ReLU implementation for fixed-point data
-- it only checks the highest bit of the input data
-- when the CLOCK_OPTION is enabled, please notice the data only updates until the clock arises.
-- Version: 1.0
-- Created by: Chao
-- Last modified date: 2022.11.06
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;
entity relu is
generic (
DATA_WIDTH : integer := 16;
CLOCK_OPTION : boolean := true
);
port (
enable : in std_logic;
clock : in std_logic;
x : in std_logic_vector(DATA_WIDTH-1 downto 0);
y : out std_logic_vector(DATA_WIDTH-1 downto 0)
);
end entity relu;
architecture rtl of relu is
signal fp_input : signed(DATA_WIDTH-1 downto 0) := (others=>'0');
signal fp_output : signed(DATA_WIDTH-1 downto 0) := (others=>'0');
begin
fp_input <= signed(x);
y <= std_logic_vector(fp_output);
clocked: if CLOCK_OPTION generate
main_process : process (enable, clock)
begin
if (enable = '0') then
fp_output <= to_signed(0, DATA_WIDTH);
elsif (rising_edge(clock)) then
if fp_input < 0 then
fp_output <= to_signed(0, DATA_WIDTH);
else
fp_output <= fp_input;
end if;
end if;
end process;
end generate;
async: if (not CLOCK_OPTION) generate
process (enable, fp_input)
begin
if enable = '0' then
fp_output <= to_signed(0, DATA_WIDTH);
else
if fp_input < 0 then
fp_output <= to_signed(0, DATA_WIDTH);
else
fp_output <= fp_input;
end if;
end if;
end process;
end generate;
end architecture rtl;
""".splitlines()
    relu = FPReLU(total_bits=16, use_clock=True)
    build_path = InMemoryPath("build", parent=None)
    design = relu.translate("relu")
    design.save_to(build_path)
    actual = cast(InMemoryFile, build_path["relu"]).text
    assert actual == expected
