Skip to content

Commit

Permalink
refactor(translation): refactor autowiring for sequential network module
Browse files Browse the repository at this point in the history
  • Loading branch information
glencoe committed Mar 10, 2023
1 parent 5402782 commit 431862f
Show file tree
Hide file tree
Showing 29 changed files with 733 additions and 164 deletions.
2 changes: 1 addition & 1 deletion elasticai/creator/hdl/code_generation/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._abstract_base_template import AbstractBaseTemplate
from .abstract_base_template import AbstractBaseTemplate
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from itertools import chain, repeat
from string import Template as StringTemplate
from typing import Iterable, Iterator, cast
from typing import Iterable, Iterator, Mapping, cast

from elasticai.creator.resource_utils import read_text


class AbstractBaseTemplate(ABC):
Expand Down Expand Up @@ -36,28 +39,10 @@ def update_parameters(self, **parameters: str | Iterable[str]) -> None:
(
single_line_parameters,
multiline_parameters,
) = self._split_single_and_multiline_parameters(parameters)
) = _split_single_and_multiline_parameters(parameters)
self._parameters.update(**single_line_parameters)
self._multiline_parameters.update(**multiline_parameters)

@staticmethod
def _split_single_and_multiline_parameters(
parameters: dict[str, str | Iterable[str]],
) -> tuple[dict[str, str], dict[str, Iterable[str]]]:
single_line_parameters: dict[str, str] = dict(
cast(
Iterator[tuple[str, str]],
filter(lambda i: isinstance(i[1], str), parameters.items()),
)
)
multiline_parameters = dict(
cast(
Iterator[tuple[str, Iterable[str]]],
filter(lambda i: not isinstance(i[1], str), parameters.items()),
)
)
return single_line_parameters, multiline_parameters

@property
def single_line_parameters(self) -> dict[str, str]:
return dict(**self._parameters)
Expand All @@ -83,6 +68,69 @@ def lines(self) -> list[str]:
return list(lines)


def module_to_package(module: str) -> str:
return ".".join(module.split(".")[:-1])


@dataclass
class TemplateConfig:
"""
Used in design definition, by the hw designer.
HW designer just provides template configs and port definitions as a design.
Contributor of a new translatable module provides design and ml module as well as how to map parameters
Creator takes these and uses the template expander to generate the correct file
"""

package: str
file_name: str
parameters: dict[str, str | list[str]]


class TemplateExpander:
"""
Used during translation by the creator tool. HW designer does not need to touch this or inherit from it.
"""

def _read_raw_template(self) -> Iterator[str]:
return read_text(
self.config.package,
self.config.file_name,
)

def __init__(self, config: TemplateConfig):
super().__init__()
self.config = config

def lines(self) -> list[str]:
single_line_params, multi_line_params = _split_single_and_multiline_parameters(
self.config.parameters
)
template = self._read_raw_template()
_lines = _expand_template(template, **single_line_params)
_lines = _expand_multiline_template(_lines, **multi_line_params)
return list(_lines)


def _split_single_and_multiline_parameters(
parameters: Mapping[str, str | Iterable[str]],
) -> tuple[dict[str, str], dict[str, Iterable[str]]]:
single_line_parameters: dict[str, str] = dict(
cast(
Iterator[tuple[str, str]],
filter(lambda i: isinstance(i[1], str), parameters.items()),
)
)
multiline_parameters = dict(
cast(
Iterator[tuple[str, Iterable[str]]],
filter(lambda i: not isinstance(i[1], str), parameters.items()),
)
)
return single_line_parameters, multiline_parameters


def _expand_multiline_template(
template: str | list[str] | Iterator[str], **kwargs: Iterable[str]
) -> Iterator[str]:
Expand Down
4 changes: 4 additions & 0 deletions elasticai/creator/hdl/code_generation/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@

def calculate_address_width(num_items: int) -> int:
return max(1, math.ceil(math.log2(num_items)))


def to_hex(number: int, bit_width: int) -> str:
return f"{number:0{math.ceil(bit_width / 4)}x}"
10 changes: 4 additions & 6 deletions elasticai/creator/hdl/design_base/network_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,27 @@ def __init__(self, name: str, x_width: int, y_width: int):
outgoing=[_signals.y(y_width)],
)

@property
def port(self) -> Port:
return self._port


class BufferedNetworkBlock(Design, ABC):
def __init__(
self,
name: str,
x_width: int,
y_width: int,
x_count: int,
y_count: int,
):
super().__init__(name)
in_signals = [
_signals.enable(),
_signals.clock(),
_signals.x(x_width),
_signals.y_address(calculate_address_width(y_width)),
_signals.y_address(calculate_address_width(y_count)),
]
out_signals = [
_signals.done(),
_signals.y(y_width),
_signals.x_address(calculate_address_width(x_width)),
_signals.x_address(calculate_address_width(x_count)),
]
self._port = Port(incoming=in_signals, outgoing=out_signals)

Expand Down
33 changes: 33 additions & 0 deletions elasticai/creator/hdl/design_base/ports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from elasticai.creator.hdl.code_generation.code_generation import (
calculate_address_width,
)
from elasticai.creator.hdl.design_base import std_signals as _signals
from elasticai.creator.hdl.design_base.design import Port


def create_port_for_base_design(x_width: int, y_width: int):
return Port(
incoming=[
_signals.enable(),
_signals.clock(),
_signals.x(x_width),
],
outgoing=[_signals.y(y_width)],
)


def create_port_for_buffered_design(
x_width: int, y_width: int, x_count: int, y_count: int
) -> Port:
in_signals = [
_signals.enable(),
_signals.clock(),
_signals.x(x_width),
_signals.y_address(calculate_address_width(y_count)),
]
out_signals = [
_signals.done(),
_signals.y(y_width),
_signals.x_address(calculate_address_width(x_count)),
]
return Port(incoming=in_signals, outgoing=out_signals)
10 changes: 0 additions & 10 deletions elasticai/creator/hdl/vhdl/code_files/network_component.py

This file was deleted.

2 changes: 1 addition & 1 deletion elasticai/creator/hdl/vhdl/code_generation/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .code_generation import create_instance, signal_definition
from .code_generation import create_instance, signal_definition, to_vhdl_hex_string
31 changes: 31 additions & 0 deletions elasticai/creator/hdl/vhdl/code_generation/code_generation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from abc import abstractmethod
from typing import Protocol, Sequence

from elasticai.creator.hdl.code_generation.code_generation import to_hex


def _sorted_dict(items: dict[str, str]) -> dict[str, str]:
return dict((key, items[key]) for key in sorted(items))

Expand Down Expand Up @@ -30,6 +36,27 @@ def create_connections(mapping: dict[str, str]) -> list[str]:
return connections


class Signal(Protocol):
@property
@abstractmethod
def name(self) -> str:
...

@property
@abstractmethod
def width(self) -> int:
...


def create_signal_definitions(prefix: str, signals: Sequence[Signal]):
return sorted(
[
signal_definition(name=f"{prefix}{signal.name}", width=signal.width)
for signal in signals
]
)


def signal_definition(
*,
name: str,
Expand All @@ -56,3 +83,7 @@ def hex_representation(hex_value: str) -> str:

def bin_representation(bin_value: str) -> str:
return f'"{bin_value}"'


def to_vhdl_hex_string(number: int, bit_width: int) -> str:
return f"'x{to_hex(number, bit_width)}'"
38 changes: 20 additions & 18 deletions elasticai/creator/hdl/vhdl/code_generation/template.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from typing import Iterator
from elasticai.creator.hdl.code_generation.abstract_base_template import (
TemplateConfig,
TemplateExpander,
)

from elasticai.creator.hdl.code_generation import AbstractBaseTemplate
from elasticai.creator.resource_utils import read_text


class Template(AbstractBaseTemplate):
def _read_raw_template(self) -> Iterator[str]:
return read_text(
self._template_package, f"{self._template_name}{self._template_file_suffix}"
class Template:
def __init__(
self,
base_name: str,
package: str = "elasticai.creator.hdl.vhdl.template_resources",
suffix: str = ".tpl.vhd",
):
self._template_name = base_name
self._internal_template = TemplateExpander(
TemplateConfig(
file_name=f"{base_name}{suffix}", package=package, parameters=dict()
)
)

_template_package = "elasticai.creator.hdl.vhdl.template_resources"
_template_file_suffix = ".tpl.vhd"

def __init__(self, base_name: str, **parameters: str | tuple[str] | list[str]):
super().__init__(**parameters)
self._template_name = base_name
self._saved_raw_template: list[str] = []
def update_parameters(self, **parameters: str | list[str]):
self._internal_template.config.parameters.update(parameters)

@property
def name(self) -> str:
return f"{self._template_name}{self._template_file_suffix}"
def lines(self) -> list[str]:
return self._internal_template.lines()
38 changes: 0 additions & 38 deletions elasticai/creator/hdl/vhdl/designs/fp_hard_sigmoid.py

This file was deleted.

23 changes: 20 additions & 3 deletions elasticai/creator/hdl/vhdl/designs/fp_linear_1d.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from typing import Optional

from elasticai.creator.hdl.design_base.network_blocks import BufferedNetworkBlock
from elasticai.creator.hdl.code_generation.abstract_base_template import (
module_to_package,
)
from elasticai.creator.hdl.design_base.design import Design, Port
from elasticai.creator.hdl.design_base.ports import create_port_for_buffered_design
from elasticai.creator.hdl.design_base.ports import (
create_port_for_buffered_design as create_port,
)
from elasticai.creator.hdl.translatable import Path
from elasticai.creator.hdl.vhdl.code_generation.template import Template


class FPLinear1d(BufferedNetworkBlock):
class FPLinear1d(Design):
@property
def port(self) -> Port:
return self._port

def __init__(
self,
*,
Expand All @@ -19,8 +30,12 @@ def __init__(
):
super().__init__(
name="fp_linear1d" if name is None else name,
)
self._port = create_port(
x_width=total_bits,
y_width=total_bits,
x_count=in_feature_num,
y_count=out_feature_num,
)
self.in_feature_num = in_feature_num
self.out_feature_num = out_feature_num
Expand All @@ -45,7 +60,9 @@ def _template_parameters(self) -> dict[str, str]:
)

def save_to(self, destination: Path):
template = Template(base_name="fp_linear_1d")
template = Template(
base_name="fp_linear_1d", package=module_to_package(self.__module__)
)
template.update_parameters(
layer_name=self.name,
work_library_name=self.work_library_name,
Expand Down
Loading

0 comments on commit 431862f

Please sign in to comment.