Skip to content

Commit

Permalink
fix(translation): correct values for x/y_address_width
Browse files Browse the repository at this point in the history
  • Loading branch information
glencoe committed Mar 10, 2023
1 parent 431862f commit c7af1af
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 122 deletions.
25 changes: 5 additions & 20 deletions elasticai/creator/hdl/design_base/signal.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,7 @@
from elasticai.creator.hdl.design_base.acceptor import Acceptor
from dataclasses import dataclass


class Signal(Acceptor):
def __init__(self, name: str, width: int, accepted_names: list[str]):
self._name = name
self._width = width
self._accepted_names = tuple(accepted_names)

@property
def name(self) -> str:
return self._name

@property
def width(self) -> int:
return self._width

def __hash__(self):
return hash((self.name, self.width, self._accepted_names))

def accepts(self, other: "Signal") -> bool:
return other.name in self._accepted_names and self.width == other.width
@dataclass(eq=True, frozen=True)
class Signal:
name: str
width: int
18 changes: 7 additions & 11 deletions elasticai/creator/hdl/design_base/std_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,28 @@


def x(width: int) -> Signal:
return Signal(name="x", width=width, accepted_names=["x", "y"])
return Signal(name="x", width=width)


def done() -> Signal:
return Signal(name="done", width=0, accepted_names=["done", "enable"])
return Signal(name="done", width=0)


def enable() -> Signal:
return Signal(name="enable", width=0, accepted_names=["done", "enable"])
return Signal(name="enable", width=0)


def clock() -> Signal:
return Signal(name="clock", width=0, accepted_names=["clock"])
return Signal(name="clock", width=0)


def y(width: int) -> Signal:
return Signal(name="y", width=width, accepted_names=["x", "y"])
return Signal(name="y", width=width)


def x_address(width: int) -> Signal:
return Signal(
name="x_address", width=width, accepted_names=["y_address", "x_address"]
)
return Signal(name="x_address", width=width)


def y_address(width: int) -> Signal:
return Signal(
name="y_address", width=width, accepted_names=["x_address", "y_address"]
)
return Signal(name="y_address", width=width)
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _get_io_pairs(self) -> dict[int, int]:

@property
def port(self) -> Port:
signal = partial(Signal, width=self._width, accepted_names=[])
signal = partial(Signal, width=self._width)
return Port(
incoming=[signal(name="x")],
outgoing=[signal(name="y")],
Expand Down
155 changes: 70 additions & 85 deletions elasticai/creator/hdl/vhdl/designs/sequential.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from abc import ABC, abstractmethod
from copy import copy
from dataclasses import dataclass
from enum import Enum
from functools import partial
from itertools import chain
from typing import Iterator

from elasticai.creator.hdl.code_generation.abstract_base_template import (
module_to_package,
)
from elasticai.creator.hdl.code_generation.code_generation import (
calculate_address_width,
)
from elasticai.creator.hdl.design_base import std_signals
from elasticai.creator.hdl.design_base.design import Design, Port
from elasticai.creator.hdl.design_base.signal import Signal
from elasticai.creator.hdl.design_base.std_signals import (
clock,
done,
Expand All @@ -23,10 +21,7 @@
y_address,
)
from elasticai.creator.hdl.translatable import Path
from elasticai.creator.hdl.vhdl.code_generation import (
create_instance,
signal_definition,
)
from elasticai.creator.hdl.vhdl.code_generation import create_instance
from elasticai.creator.hdl.vhdl.code_generation.code_generation import (
create_connections,
create_signal_definitions,
Expand All @@ -41,23 +36,22 @@ def __init__(
*,
x_width: int,
y_width: int,
x_address_width: int,
y_address_width: int,
):
super().__init__("Sequential")
self._x_width = x_width
self._y_width = y_width
self._y_address_width = y_address_width
self._x_address_width = x_address_width
self.instances: dict[str, Design] = {}
self._names: dict[str, int] = {}
self._library_name_for_instances = "work"
self._architecture_name_for_instances = "rtl"
self._autowirer = _AutoWirer
for design in sub_designs:
self._register_subdesign(design)

@property
def _y_address_width(self) -> int:
return calculate_address_width(self._y_width)

@property
def _x_address_width(self) -> int:
return calculate_address_width(self._x_width)

def _make_name_unique(self, name: str) -> str:
return f"i_{name}_{self._get_counter_for_name(name)}"

Expand All @@ -78,30 +72,6 @@ def _register_subdesign(self, d: Design):
def _qualified_signal_name(self, instance: str, signal: str) -> str:
return f"{instance}_{signal}"

def _default_port(self) -> Port:
return Port(incoming=[x(1), y_address(1)], outgoing=[y(1), x_address(1)])

def _get_width_of_signal_or_default(self, signal_name: str, default: int) -> int:
return self._get_width_for_signal_from_designs_or_default(
signal_name, designs=iter(self.instances.values()), default=default
)

@staticmethod
def _get_width_for_signal_from_designs_or_default(
selected_signal: str, default: int, designs: Iterator[Design]
) -> int:
for sub_design in designs:
if selected_signal in sub_design.port.signal_names:
return sub_design.port[selected_signal].width
return default

def _reversed_get_width_of_signal_or_default(
self, signal_name: str, default: int
) -> int:
return self._get_width_for_signal_from_designs_or_default(
signal_name, default, reversed(self.instances.values())
)

@property
def port(self) -> Port:
return Port(
Expand All @@ -122,23 +92,27 @@ def _save_subdesigns(self, destination: Path) -> None:
for name, design in self.instances.items():
design.save_to(destination.create_subpath(name))

@staticmethod
def _connection(a: str, b: str) -> str:
return f"{a} <= {b};"

def _create_dataflow_nodes(self) -> list["_BaseNode"]:
nodes: list[_BaseNode] = [_StartNode()]
def _create_dataflow_nodes(self) -> list["_DataFlowNode"]:
nodes: list[_DataFlowNode] = [
_StartNode(x_width=self._x_width, y_address_width=self._y_address_width)
]
for instance, design in self.instances.items():
node: _BaseNode = _DataFlowNode(instance=instance)
node.add_sinks([signal.name for signal in design.port.incoming])
node.add_sources([signal.name for signal in design.port.outgoing])
nodes.append(node)
nodes.append(_EndNode())
nodes.append(self._create_data_flow_node(instance, design))
nodes.append(
_EndNode(y_width=self._y_width, x_address_width=self._x_address_width)
)
return nodes

@staticmethod
def _create_data_flow_node(instance: str, design: Design) -> "_DataFlowNode":
node = _InstanceNode(instance=instance)
node.add_sinks(design.port.incoming)
node.add_sources(design.port.outgoing)
return node

def _generate_connections(self) -> list[str]:
nodes = self._create_dataflow_nodes()
wirer = _AutoWirer(nodes=nodes)
wirer = self._autowirer(nodes=nodes)
return create_connections(wirer.connect())

def _generate_instantiations(self) -> list[str]:
Expand All @@ -152,8 +126,8 @@ def _generate_instantiations(self) -> list[str]:
create_instance(
name=instance,
entity=design.name,
library="work",
architecture="rtl",
library=self._library_name_for_instances,
architecture=self._architecture_name_for_instances,
signal_mapping=signal_map,
)
)
Expand Down Expand Up @@ -186,7 +160,7 @@ def save_to(self, destination: Path):


class _AutoWirer:
def __init__(self, nodes: list["_BaseNode"]):
def __init__(self, nodes: list["_DataFlowNode"]):
self.nodes = nodes
self.mapping = {
"x": ["x", "y"],
Expand All @@ -197,40 +171,40 @@ def __init__(self, nodes: list["_BaseNode"]):
"done": ["done", "enable"],
"clock": ["clock"],
}
self.available_sources: dict[str, "_Source"] = {}
self.available_sources: dict[str, "_OwnedSignal"] = {}

def _pick_best_matching_source(self, sink: "_Sink") -> "_Source":
def _pick_best_matching_source(self, sink: "_OwnedSignal") -> "_OwnedSignal":
for source_name in self.mapping[sink.name]:
if source_name in self.available_sources:
source = self.available_sources[source_name]
return source
return _TopNode().sources[0]

def _update_available_sources(self, node: "_BaseNode") -> None:
def _update_available_sources(self, node: "_DataFlowNode") -> None:
self.available_sources.update({s.name: s for s in node.sources})

def connect(self) -> dict[str, str]:
connections: dict[str, str] = {}
for node in self.nodes:
for sink in node.sinks:
source = self._pick_best_matching_source(sink)
connections[sink.get_qualified_name()] = source.get_qualified_name()
connections[sink.qualified_name] = source.qualified_name
self._update_available_sources(node)

return connections


class _BaseNode(ABC):
class _DataFlowNode(ABC):
def __init__(self) -> None:
self.sinks: list[_Sink] = []
self.sources: list[_Source] = []
self.sinks: list[_OwnedSignal] = []
self.sources: list[_OwnedSignal] = []

def add_sinks(self, sinks: list[str]):
create_sink = partial(_Sink, owner=self)
def add_sinks(self, sinks: list[Signal]):
create_sink = partial(_OwnedSignal, owner=self)
self.sinks.extend(map(create_sink, sinks))

def add_sources(self, sources: list[str]):
create_source = partial(_Source, owner=self)
def add_sources(self, sources: list[Signal]):
create_source = partial(_OwnedSignal, owner=self)
self.sources.extend(map(create_source, sources))

@property
Expand All @@ -239,29 +213,42 @@ def prefix(self) -> str:
...


class _TopNode(_BaseNode):
class _TopNode(_DataFlowNode):
@property
def prefix(self) -> str:
return ""


class _StartNode(_TopNode):
def __init__(self):
def __init__(self, x_width: int, y_address_width: int):
super().__init__()
self.add_sources(["x", "y_address", "enable", "clock"])
self.add_sources(
[
std_signals.x(x_width),
std_signals.y_address(y_address_width),
std_signals.enable(),
std_signals.clock(),
]
)

@property
def prefix(self) -> str:
return ""


class _EndNode(_TopNode):
def __init__(self):
def __init__(self, y_width: int, x_address_width: int):
super().__init__()
self.add_sinks(["y", "x_address", "done"])
self.add_sinks(
[
std_signals.y(y_width),
std_signals.x_address(x_address_width),
std_signals.done(),
]
)


class _DataFlowNode(_BaseNode):
class _InstanceNode(_DataFlowNode):
def __init__(self, instance: str):
self.instance = instance
super().__init__()
Expand All @@ -271,21 +258,19 @@ def prefix(self) -> str:
return f"{self.instance}_"


class _Sink:
def __init__(self, name: str, owner: _DataFlowNode):
self.name = name
class _OwnedSignal:
def __init__(self, signal: Signal, owner: _DataFlowNode):
self.owner = owner
self.source: _DataFlowNode | None = None

def get_qualified_name(self) -> str:
return f"{self.owner.prefix}{self.name}"
self._signal = signal

@property
def name(self) -> str:
return self._signal.name

class _Source:
def __init__(self, name: str, owner: _DataFlowNode):
self.name = name
self.owner = owner
self.sinks: list[_DataFlowNode] = []
@property
def width(self) -> int:
return self._signal.width

def get_qualified_name(self) -> str:
@property
def qualified_name(self) -> str:
return f"{self.owner.prefix}{self.name}"
18 changes: 17 additions & 1 deletion elasticai/creator/translatable_modules/vhdl/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,30 @@ def __init__(self, submodules: tuple[Module]):
def translate(self) -> Design:
submodules: list[Module] = [cast(Module, m) for m in self.children()]
subdesigns = [m.translate() for m in submodules]
x_address_width = 1
y_address_width = 1
if len(subdesigns) == 0:
x_width = 1
y_width = 1

else:
front = subdesigns[0]
back = subdesigns[-1]
x_width = front.port["x"].width
y_width = back.port["y"].width
found_y_address = False
found_x_address = False
for design in subdesigns:
if "y_address" in design.port and not found_y_address:
found_y_address = True
y_address_width = design.port["y_address"].width
if "x_address" in design.port and not found_x_address:
found_x_address = True
x_address_width = back.port["x_address"].width
return _SequentialDesign(
sub_designs=subdesigns, x_width=x_width, y_width=y_width
sub_designs=subdesigns,
x_width=x_width,
y_width=y_width,
x_address_width=x_address_width,
y_address_width=y_address_width,
)
Loading

0 comments on commit c7af1af

Please sign in to comment.