Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt committed Aug 17, 2023
1 parent fdb1980 commit f80f3a1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 19 deletions.
4 changes: 4 additions & 0 deletions magma/primitives/when.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ def output_to_index(self) -> Dict[Type, int]:
def output_to_name(self):
return self._output_to_name

@property
def input_to_name(self):
return self._input_to_name

def check_existing_derived_ref(self, value, value_to_name, value_to_index):
"""If value is a child of an array or tuple that has already been added,
we return the child of the existing value, rather than adding a new
Expand Down
48 changes: 29 additions & 19 deletions magma/when.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,25 +675,27 @@ def otherwise():
return prev_block.new_otherwise_block()


def _get_builder_ports(builder, names):
"""Filter out removed ports."""
return [getattr(builder, name) for name in names if hasattr(builder, name)]


def _find_values_to_split(builder):
"""Detect output values that feed into inputs"""
"""Detect output values that feed into inputs."""
to_split = []
for x in builder.output_to_name.values():
x = getattr(builder, x, None)
if x is None:
continue
for y in builder._input_to_name.values():
y = getattr(builder, y, None)
if y is None:
continue
if y.trace() is x:
to_split.append(x)
break
outputs = _get_builder_ports(builder, builder.output_to_name.values())
inputs = _get_builder_ports(builder, builder.input_to_name.values())
for value in inputs:
source = value.trace()
# TODO(leonardt): handle nesting
if any(source is output for output in outputs):
to_split.append(source)
break
return to_split


def _emit_new_when_assign(value, driver_map, curr_block):
"""Reconstruct when logic in new set of blocks"""
"""Reconstruct when logic in new set of blocks."""
if isinstance(curr_block, _WhenBlock):
new_block = when(curr_block._info.condition)
elif isinstance(curr_block, _ElseWhenBlock):
Expand All @@ -712,15 +714,23 @@ def _emit_new_when_assign(value, driver_map, curr_block):
_emit_new_when_assign(value, driver_map, curr_block.otherwise_block)


def _build_driver_map(drivee):
"""
driver_map: for each context that drivee is driven, store the driver
"""
driver_map = {}
for ctx in drivee._wired_when_contexts:
wires = ctx.get_conditional_wires_for_drivee(drivee)
driver_map[ctx] = (wire.driver for wire in wires)
return driver_map


def split_when_cycles(builder, defn):
to_split = _find_values_to_split(builder)
for value in to_split:
driving = value.driving()
driver_map = {}
contexts = driving[0]._wired_when_contexts[:]
for ctx in contexts:
wires = ctx.get_conditional_wires_for_drivee(driving[0])
driver_map[ctx] = (wire.driver for wire in wires)
driver_map = _build_driver_map(driving[0])
for drivee in driving:
drivee.unwire()
_emit_new_when_assign(drivee, driver_map, contexts[0].root)
root = next(iter(driver_map.keys())).root
_emit_new_when_assign(drivee, driver_map, root)

0 comments on commit f80f3a1

Please sign in to comment.