Skip to content

Commit

Permalink
feat(tests): add tests for the conv1d design
Browse files Browse the repository at this point in the history
  • Loading branch information
julianhoever committed Sep 13, 2023
1 parent 79494fe commit c2c94bd
Showing 1 changed file with 114 additions and 0 deletions.
114 changes: 114 additions & 0 deletions elasticai/creator/nn/fixed_point/conv1d/design_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import cast

import pytest

from elasticai.creator.file_generation.in_memory_path import InMemoryFile, InMemoryPath

from .design import Conv1d


@pytest.fixture
def conv1d_design() -> Conv1d:
return Conv1d(
name="conv1d",
total_bits=16,
frac_bits=8,
in_channels=1,
out_channels=2,
kernel_size=3,
signal_length=4,
weights=[[[1, 1, 1]], [[1, 1, 1]]],
bias=[1, 1],
)


def save_design(design: Conv1d) -> dict[str, str]:
destination = InMemoryPath("conv1d", parent=None)
design.save_to(destination)
files = cast(list[InMemoryFile], list(destination.children.values()))
return {file.name: "\n".join(file.text) for file in files}


def test_saved_design_contains_needed_files(conv1d_design: Conv1d) -> None:
saved_files = save_design(conv1d_design)

expected_files = {"conv1d_w_rom.vhd", "conv1d_b_rom.vhd", "conv1d.vhd"}
actual_files = set(saved_files.keys())

assert expected_files == actual_files


def test_weight_rom_code_generated_correctly(conv1d_design: Conv1d) -> None:
expected_code = """library ieee;
use ieee.std_logic_1164.all;
use ieee.std_logic_unsigned.all;
entity conv1d_w_rom is
port (
clk : in std_logic;
en : in std_logic;
addr : in std_logic_vector(3-1 downto 0);
data : out std_logic_vector(16-1 downto 0)
);
end entity conv1d_w_rom;
architecture rtl of conv1d_w_rom is
type conv1d_w_rom_array_t is array (0 to 2**3-1) of std_logic_vector(16-1 downto 0);
signal ROM : conv1d_w_rom_array_t:=("0000000000000001","0000000000000001","0000000000000001","0000000000000001","0000000000000001","0000000000000001","0000000000000000","0000000000000000");
attribute rom_style : string;
attribute rom_style of ROM : signal is "auto";
begin
ROM_process: process(clk)
begin
if rising_edge(clk) then
if (en = '1') then
data <= ROM(conv_integer(addr));
end if;
end if;
end process ROM_process;
end architecture rtl;"""
saved_files = save_design(conv1d_design)
actual_code = saved_files["conv1d_w_rom.vhd"]
assert expected_code == actual_code


def test_bias_rom_code_generated_correctly(conv1d_design: Conv1d) -> None:
expected_code = """library ieee;
use ieee.std_logic_1164.all;
use ieee.std_logic_unsigned.all;
entity conv1d_b_rom is
port (
clk : in std_logic;
en : in std_logic;
addr : in std_logic_vector(1-1 downto 0);
data : out std_logic_vector(16-1 downto 0)
);
end entity conv1d_b_rom;
architecture rtl of conv1d_b_rom is
type conv1d_b_rom_array_t is array (0 to 2**1-1) of std_logic_vector(16-1 downto 0);
signal ROM : conv1d_b_rom_array_t:=("0000000000000001","0000000000000001");
attribute rom_style : string;
attribute rom_style of ROM : signal is "auto";
begin
ROM_process: process(clk)
begin
if rising_edge(clk) then
if (en = '1') then
data <= ROM(conv_integer(addr));
end if;
end if;
end process ROM_process;
end architecture rtl;"""
saved_files = save_design(conv1d_design)
actual_code = saved_files["conv1d_b_rom.vhd"]
assert expected_code == actual_code


def test_conv1d_code_generated_correctly(conv1d_design: Conv1d) -> None:
expected_code = """-- Dummy File for testing implementation of conv1d Design
16
8
1
2
3"""
saved_files = save_design(conv1d_design)
actual_code = saved_files["conv1d.vhd"]
assert expected_code == actual_code

0 comments on commit c2c94bd

Please sign in to comment.