From 0a992a0673c9b7c144b4cecedc2fa331eea9c160 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Tue, 11 Oct 2022 18:46:31 -0700 Subject: [PATCH 1/2] [Primitives] Adds default=None for list_lookup Rather than use an explicit default value as the standard pattern, it makes more sense to use the final element of the list since we're buildling a mux tree anyways. The existing explicit "default" value pattern might still be useful to emulate a certain "else" pattern that's different than the last element. In adding this, I noticed the mux tree was being built forwards, which made the generated code a bit convoluted for this default case, so I reversed the tree instead to make it look more natural in the generated code. --- magma/primitives/mux.py | 13 +++++---- .../test_mux_list_lookup_default_none.mlir | 14 +++++++++ tests/test_primitives/test_mux.py | 29 +++++++++++++++++++ 3 files changed, 51 insertions(+), 5 deletions(-) create mode 100644 tests/test_primitives/gold/test_mux_list_lookup_default_none.mlir diff --git a/magma/primitives/mux.py b/magma/primitives/mux.py index 84c6e1e6b..09cf9bf58 100644 --- a/magma/primitives/mux.py +++ b/magma/primitives/mux.py @@ -185,15 +185,18 @@ def dict_lookup(dict_, select, default=0): return output -def list_lookup(list_, select, default=0): +def list_lookup(list_, select, default=None): """ Use `select` as an index into `list` (similar to a case statement) `default` is used when `select` does not match any of the indices (e.g. - when the select width is longer than the list) and has a default value of - 0. + when the select width is longer than the list). If it is `None`, the last + element of the list will be used. """ output = default - for i, elem in enumerate(list_): - output = mux([output, elem], i == select) + if default is None: + output = list_[-1] + list_ = list_[:-1] + for i in range(len(list_) - 1, -1, -1): + output = mux([output, list_[i]], i == select) return output diff --git a/tests/test_primitives/gold/test_mux_list_lookup_default_none.mlir b/tests/test_primitives/gold/test_mux_list_lookup_default_none.mlir new file mode 100644 index 000000000..188c00dc4 --- /dev/null +++ b/tests/test_primitives/gold/test_mux_list_lookup_default_none.mlir @@ -0,0 +1,14 @@ +hw.module @test_mux_list_lookup_default_none(%S: i2) -> (O: i5) { + %0 = hw.constant 2 : i5 + %1 = hw.constant 1 : i5 + %2 = hw.constant 1 : i2 + %3 = comb.icmp eq %S, %2 : i2 + %5 = hw.array_create %1, %0 : i5 + %4 = hw.array_get %5[%3] : !hw.array<2xi5> + %6 = hw.constant 0 : i5 + %7 = hw.constant 0 : i2 + %8 = comb.icmp eq %S, %7 : i2 + %10 = hw.array_create %6, %4 : i5 + %9 = hw.array_get %10[%8] : !hw.array<2xi5> + hw.output %9 : i5 +} diff --git a/tests/test_primitives/test_mux.py b/tests/test_primitives/test_mux.py index 752924006..8cde7492e 100644 --- a/tests/test_primitives/test_mux.py +++ b/tests/test_primitives/test_mux.py @@ -5,6 +5,7 @@ import hwtypes as ht import magma as m from magma.testing import check_files_equal, SimpleMagmaProtocol +from magma.testing.utils import check_gold, update_gold def test_basic_mux(): @@ -350,3 +351,31 @@ class _(m.Circuit): vec = [T() for _ in range(n)] sel = T_sel(0) if sel_bits == 1 else T_sel(0) m.mux(vec, sel) + + +def test_mux_list_lookup_default_none(): + class test_mux_list_lookup_default_none(m.Circuit): + io = m.IO(S=m.In(m.Bits[2]), O=m.Out(m.Bits[5])) + + list_ = [ht.BitVector[5](0), ht.BitVector[5](1), ht.BitVector[5](2)] + io.O @= m.list_lookup(list_, io.S) + + m.compile("build/test_mux_list_lookup_default_none", + test_mux_list_lookup_default_none, + output="mlir") + if check_gold(__file__, "test_mux_list_lookup_default_none.mlir"): + return + + tester = fault.Tester(test_mux_list_lookup_default_none) + for i in range(4): + tester.circuit.S = i + tester.eval() + if i < 3: + tester.circuit.O.expect(i) + else: + tester.circuit.O.expect(2) + + tester.compile_and_run("verilator", magma_output="mlir-verilog", + directory=os.path.join(os.path.dirname(__file__), + "build")) + update_gold(__file__, "test_mux_list_lookup_default_none.mlir") From 7320ea4b537638391406b8f745bcb516d819cc28 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Fri, 14 Oct 2022 09:17:06 -0700 Subject: [PATCH 2/2] Add comment --- magma/primitives/mux.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/magma/primitives/mux.py b/magma/primitives/mux.py index 09cf9bf58..b225b3c70 100644 --- a/magma/primitives/mux.py +++ b/magma/primitives/mux.py @@ -197,6 +197,8 @@ def list_lookup(list_, select, default=None): if default is None: output = list_[-1] list_ = list_[:-1] + # We chain the muxes in reverse order so that the emitted Verilog is in + # forward order. for i in range(len(list_) - 1, -1, -1): output = mux([output, list_[i]], i == select) return output