Skip to content

Commit

Permalink
[Types] Use AggregateWireable API to avoid elaboration (#1153)
Browse files Browse the repository at this point in the history
* Use AggregateWireable API to avoid elaboration

Some cases in the repr and insert_wrap_casts logic were unnecessarily
expanding arrays.  This updates the logic to avoid elaboration when
possible, also updating it to use the AggregateWireable API.

Some of the smart_bits tests relied on this elaboration for checking the
repr of two circutis, so some code has been added there to force
elaboration so the repr works.

* Update function call
  • Loading branch information
leonardt authored Oct 11, 2022
1 parent 5c6cfc5 commit 0807db6
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 22 deletions.
12 changes: 6 additions & 6 deletions magma/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,8 @@ def _wire_children(self, o, debug_info):

def _should_wire_children(self, o):
return (
self._has_elaborated_children() or
o._has_elaborated_children() or
self.has_elaborated_children() or
o.has_elaborated_children() or
self.T.is_mixed()
)

Expand All @@ -632,12 +632,12 @@ def unwire(self, o=None, debug_info=None, keep_wired_when_contexts=False):
keep_wired_when_contexts=keep_wired_when_contexts)

def iswhole(self):
if self._has_elaborated_children():
if self.has_elaborated_children():
return Array._iswhole(self._collect_children(lambda x: x))
return True

def const(self):
if self._has_elaborated_children():
if self.has_elaborated_children():
return all(child.const()
for _, child in self._enumerate_children())
return False
Expand Down Expand Up @@ -843,7 +843,7 @@ def flatten(self):
return ts

def __repr__(self):
if self.name.anon() and self._has_elaborated_children():
if self.name.anon() and self.has_elaborated_children():
t_strs = ', '.join(repr(t) for t in self.ts)
return f'array([{t_strs}])'
return Type.__repr__(self)
Expand Down Expand Up @@ -887,7 +887,7 @@ def _collect_children(self, func):

return T

def _has_elaborated_children(self):
def has_elaborated_children(self):
return bool(self._ts) or bool(self._slices)

@aggregate_wireable_method
Expand Down
13 changes: 10 additions & 3 deletions magma/backend/coreir/insert_wrap_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from magma.t import In, Out, Direction
from magma.tuple import Tuple
from magma.wire import wire
from magma.wire_container import AggregateWireable


_NAMED_TYPES = (AsyncReset, AsyncResetN, Clock)
Expand Down Expand Up @@ -48,9 +49,15 @@ def wrap_if_named_type(self, port, definition):
# via recursion in case the children are named types (since
# .value() will return Array[N, T.flip()], the anon value may not
# have the namedtypes in its type)
if (is_clock_or_nested_clock(type(port), _NAMED_TYPES) or
is_clock_or_nested_clock(type(value), _NAMED_TYPES) or
value.anon()):
if (
is_clock_or_nested_clock(type(port), _NAMED_TYPES) or
is_clock_or_nested_clock(type(value), _NAMED_TYPES) or
(
value.anon() and
isinstance(value, AggregateWireable) and
value.has_elaborated_children()
)
):
for child, _ in port.connection_iter():
if not self.wrap_if_named_type(child, definition):
return False
Expand Down
18 changes: 10 additions & 8 deletions magma/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from magma.ref import TempNamedRef
from magma.t import In
from magma.view import PortView
from magma.wire_container import WiringLog
from magma.wire_container import WiringLog, AggregateWireable


__all__ = ['AnonymousCircuitType']
Expand Down Expand Up @@ -185,13 +185,15 @@ def _get_intermediate_values(value):
driver = value.value()
if driver is None:
return OrderedIdentitySet()
if getattr(type(driver), "N", False) and driver.name.anon():
conn_iter = list(value.connection_iter())
if len(conn_iter) > 1:
return functools.reduce(
operator.or_, (_get_intermediate_values(v) for v, _ in
conn_iter),
OrderedIdentitySet())
if (
isinstance(driver, AggregateWireable) and
driver.name.anon() and
driver.has_elaborated_children()
):
return functools.reduce(
operator.or_, (_get_intermediate_values(v) for v, _ in
value.connection_iter()),
OrderedIdentitySet())
values = OrderedIdentitySet()
while driver is not None:
values |= _add_intermediate_value(driver)
Expand Down
6 changes: 3 additions & 3 deletions magma/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _make_t(self, idx):
value.set_enclosing_when_context(self._enclosing_when_context)
return value

def _has_elaborated_children(self):
def has_elaborated_children(self):
return bool(self._ts)

def _enumerate_children(self):
Expand Down Expand Up @@ -330,8 +330,8 @@ def wire(self, o, debug_info):
)
return
if (self.is_mixed() or
self._has_elaborated_children() or
o._has_elaborated_children()):
self.has_elaborated_children() or
o.has_elaborated_children()):
for self_elem, o_elem in zip(self, o):
self_elem = magma_value(self_elem)
o_elem = magma_value(o_elem)
Expand Down
4 changes: 2 additions & 2 deletions magma/wire_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def rewire(self, o, debug_info=None):

class AggregateWireable(Wireable):
@abc.abstractmethod
def _has_elaborated_children(self):
def has_elaborated_children(self):
raise NotImplementedError()

@abc.abstractmethod
Expand Down Expand Up @@ -336,7 +336,7 @@ def aggregate_wireable_method(fn):

@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
if self._has_elaborated_children():
if self.has_elaborated_children():
return fn(self, *args, **kwargs)
wireable_fn = getattr(AggregateWireable, fn.__name__)
return wireable_fn(self, *args, **kwargs)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_smart/test_smart_bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class _Gold(m.Circuit):
O3=m.Out(m.UInt[16]))
O1 = m.UInt[8]()
O1 @= op(m.zext_to(io.I0, 12), io.I1)[:8]
O1[0] # force elaboration for repr test
io.O1 @= O1
io.O2 @= op(m.zext_to(io.I0, 12), io.I1)
io.O3 @= op(m.zext_to(io.I0, 16), m.zext_to(io.I1, 16))
Expand Down Expand Up @@ -162,6 +163,7 @@ class _Gold(m.Circuit):
O3=m.Out(m.UInt[16]))
O1 = m.UInt[4]()
O1 @= (io.I0 >> m.zext_to(io.I1, 8))[:4]
O1[0] # force elaboration for repr test
io.O1 @= O1
io.O2 @= m.zext_to(io.I0 >> m.zext_to(io.I1, 8), 8)
O3 = m.UInt[16]()
Expand Down Expand Up @@ -227,6 +229,7 @@ class _Gold(m.Circuit):
O2=m.Out(m.UInt[16]))
O1 = m.UInt[4]()
O1 @= op(io.I0)[:4]
O1[0] # force elaboration for repr test
io.O1 @= O1
io.O2 @= op(m.zext_to(io.I0, 16))

Expand Down

0 comments on commit 0807db6

Please sign in to comment.