From 0807db64f0fd629f423294a3df588f1af557c6d6 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Tue, 11 Oct 2022 14:14:28 -0700 Subject: [PATCH] [Types] Use AggregateWireable API to avoid elaboration (#1153) * Use AggregateWireable API to avoid elaboration Some cases in the repr and insert_wrap_casts logic were unnecessarily expanding arrays. This updates the logic to avoid elaboration when possible, also updating it to use the AggregateWireable API. Some of the smart_bits tests relied on this elaboration for checking the repr of two circutis, so some code has been added there to force elaboration so the repr works. * Update function call --- magma/array.py | 12 ++++++------ magma/backend/coreir/insert_wrap_casts.py | 13 ++++++++++--- magma/circuit.py | 18 ++++++++++-------- magma/tuple.py | 6 +++--- magma/wire_container.py | 4 ++-- tests/test_smart/test_smart_bits.py | 3 +++ 6 files changed, 34 insertions(+), 22 deletions(-) diff --git a/magma/array.py b/magma/array.py index 8ec6f765d..324b154b6 100644 --- a/magma/array.py +++ b/magma/array.py @@ -606,8 +606,8 @@ def _wire_children(self, o, debug_info): def _should_wire_children(self, o): return ( - self._has_elaborated_children() or - o._has_elaborated_children() or + self.has_elaborated_children() or + o.has_elaborated_children() or self.T.is_mixed() ) @@ -632,12 +632,12 @@ def unwire(self, o=None, debug_info=None, keep_wired_when_contexts=False): keep_wired_when_contexts=keep_wired_when_contexts) def iswhole(self): - if self._has_elaborated_children(): + if self.has_elaborated_children(): return Array._iswhole(self._collect_children(lambda x: x)) return True def const(self): - if self._has_elaborated_children(): + if self.has_elaborated_children(): return all(child.const() for _, child in self._enumerate_children()) return False @@ -843,7 +843,7 @@ def flatten(self): return ts def __repr__(self): - if self.name.anon() and self._has_elaborated_children(): + if self.name.anon() and self.has_elaborated_children(): t_strs = ', '.join(repr(t) for t in self.ts) return f'array([{t_strs}])' return Type.__repr__(self) @@ -887,7 +887,7 @@ def _collect_children(self, func): return T - def _has_elaborated_children(self): + def has_elaborated_children(self): return bool(self._ts) or bool(self._slices) @aggregate_wireable_method diff --git a/magma/backend/coreir/insert_wrap_casts.py b/magma/backend/coreir/insert_wrap_casts.py index 1c1df3b0c..02a377f18 100644 --- a/magma/backend/coreir/insert_wrap_casts.py +++ b/magma/backend/coreir/insert_wrap_casts.py @@ -7,6 +7,7 @@ from magma.t import In, Out, Direction from magma.tuple import Tuple from magma.wire import wire +from magma.wire_container import AggregateWireable _NAMED_TYPES = (AsyncReset, AsyncResetN, Clock) @@ -48,9 +49,15 @@ def wrap_if_named_type(self, port, definition): # via recursion in case the children are named types (since # .value() will return Array[N, T.flip()], the anon value may not # have the namedtypes in its type) - if (is_clock_or_nested_clock(type(port), _NAMED_TYPES) or - is_clock_or_nested_clock(type(value), _NAMED_TYPES) or - value.anon()): + if ( + is_clock_or_nested_clock(type(port), _NAMED_TYPES) or + is_clock_or_nested_clock(type(value), _NAMED_TYPES) or + ( + value.anon() and + isinstance(value, AggregateWireable) and + value.has_elaborated_children() + ) + ): for child, _ in port.connection_iter(): if not self.wrap_if_named_type(child, definition): return False diff --git a/magma/circuit.py b/magma/circuit.py index f116076ad..f9c97fcdf 100644 --- a/magma/circuit.py +++ b/magma/circuit.py @@ -40,7 +40,7 @@ from magma.ref import TempNamedRef from magma.t import In from magma.view import PortView -from magma.wire_container import WiringLog +from magma.wire_container import WiringLog, AggregateWireable __all__ = ['AnonymousCircuitType'] @@ -185,13 +185,15 @@ def _get_intermediate_values(value): driver = value.value() if driver is None: return OrderedIdentitySet() - if getattr(type(driver), "N", False) and driver.name.anon(): - conn_iter = list(value.connection_iter()) - if len(conn_iter) > 1: - return functools.reduce( - operator.or_, (_get_intermediate_values(v) for v, _ in - conn_iter), - OrderedIdentitySet()) + if ( + isinstance(driver, AggregateWireable) and + driver.name.anon() and + driver.has_elaborated_children() + ): + return functools.reduce( + operator.or_, (_get_intermediate_values(v) for v, _ in + value.connection_iter()), + OrderedIdentitySet()) values = OrderedIdentitySet() while driver is not None: values |= _add_intermediate_value(driver) diff --git a/magma/tuple.py b/magma/tuple.py index ef9825585..e58b358d1 100644 --- a/magma/tuple.py +++ b/magma/tuple.py @@ -266,7 +266,7 @@ def _make_t(self, idx): value.set_enclosing_when_context(self._enclosing_when_context) return value - def _has_elaborated_children(self): + def has_elaborated_children(self): return bool(self._ts) def _enumerate_children(self): @@ -330,8 +330,8 @@ def wire(self, o, debug_info): ) return if (self.is_mixed() or - self._has_elaborated_children() or - o._has_elaborated_children()): + self.has_elaborated_children() or + o.has_elaborated_children()): for self_elem, o_elem in zip(self, o): self_elem = magma_value(self_elem) o_elem = magma_value(o_elem) diff --git a/magma/wire_container.py b/magma/wire_container.py index 02eb06ac4..01db62836 100644 --- a/magma/wire_container.py +++ b/magma/wire_container.py @@ -247,7 +247,7 @@ def rewire(self, o, debug_info=None): class AggregateWireable(Wireable): @abc.abstractmethod - def _has_elaborated_children(self): + def has_elaborated_children(self): raise NotImplementedError() @abc.abstractmethod @@ -336,7 +336,7 @@ def aggregate_wireable_method(fn): @functools.wraps(fn) def wrapper(self, *args, **kwargs): - if self._has_elaborated_children(): + if self.has_elaborated_children(): return fn(self, *args, **kwargs) wireable_fn = getattr(AggregateWireable, fn.__name__) return wireable_fn(self, *args, **kwargs) diff --git a/tests/test_smart/test_smart_bits.py b/tests/test_smart/test_smart_bits.py index 92a2b13fb..0b46f3abf 100644 --- a/tests/test_smart/test_smart_bits.py +++ b/tests/test_smart/test_smart_bits.py @@ -71,6 +71,7 @@ class _Gold(m.Circuit): O3=m.Out(m.UInt[16])) O1 = m.UInt[8]() O1 @= op(m.zext_to(io.I0, 12), io.I1)[:8] + O1[0] # force elaboration for repr test io.O1 @= O1 io.O2 @= op(m.zext_to(io.I0, 12), io.I1) io.O3 @= op(m.zext_to(io.I0, 16), m.zext_to(io.I1, 16)) @@ -162,6 +163,7 @@ class _Gold(m.Circuit): O3=m.Out(m.UInt[16])) O1 = m.UInt[4]() O1 @= (io.I0 >> m.zext_to(io.I1, 8))[:4] + O1[0] # force elaboration for repr test io.O1 @= O1 io.O2 @= m.zext_to(io.I0 >> m.zext_to(io.I1, 8), 8) O3 = m.UInt[16]() @@ -227,6 +229,7 @@ class _Gold(m.Circuit): O2=m.Out(m.UInt[16])) O1 = m.UInt[4]() O1 @= op(io.I0)[:4] + O1[0] # force elaboration for repr test io.O1 @= O1 io.O2 @= op(m.zext_to(io.I0, 16))