diff --git a/magma/primitives/mux.py b/magma/primitives/mux.py index 84c6e1e6b..b225b3c70 100644 --- a/magma/primitives/mux.py +++ b/magma/primitives/mux.py @@ -185,15 +185,20 @@ def dict_lookup(dict_, select, default=0): return output -def list_lookup(list_, select, default=0): +def list_lookup(list_, select, default=None): """ Use `select` as an index into `list` (similar to a case statement) `default` is used when `select` does not match any of the indices (e.g. - when the select width is longer than the list) and has a default value of - 0. + when the select width is longer than the list). If it is `None`, the last + element of the list will be used. """ output = default - for i, elem in enumerate(list_): - output = mux([output, elem], i == select) + if default is None: + output = list_[-1] + list_ = list_[:-1] + # We chain the muxes in reverse order so that the emitted Verilog is in + # forward order. + for i in range(len(list_) - 1, -1, -1): + output = mux([output, list_[i]], i == select) return output diff --git a/tests/test_primitives/gold/test_mux_list_lookup_default_none.mlir b/tests/test_primitives/gold/test_mux_list_lookup_default_none.mlir new file mode 100644 index 000000000..188c00dc4 --- /dev/null +++ b/tests/test_primitives/gold/test_mux_list_lookup_default_none.mlir @@ -0,0 +1,14 @@ +hw.module @test_mux_list_lookup_default_none(%S: i2) -> (O: i5) { + %0 = hw.constant 2 : i5 + %1 = hw.constant 1 : i5 + %2 = hw.constant 1 : i2 + %3 = comb.icmp eq %S, %2 : i2 + %5 = hw.array_create %1, %0 : i5 + %4 = hw.array_get %5[%3] : !hw.array<2xi5> + %6 = hw.constant 0 : i5 + %7 = hw.constant 0 : i2 + %8 = comb.icmp eq %S, %7 : i2 + %10 = hw.array_create %6, %4 : i5 + %9 = hw.array_get %10[%8] : !hw.array<2xi5> + hw.output %9 : i5 +} diff --git a/tests/test_primitives/test_mux.py b/tests/test_primitives/test_mux.py index 752924006..8cde7492e 100644 --- a/tests/test_primitives/test_mux.py +++ b/tests/test_primitives/test_mux.py @@ -5,6 +5,7 @@ import hwtypes as ht import magma as m from magma.testing import check_files_equal, SimpleMagmaProtocol +from magma.testing.utils import check_gold, update_gold def test_basic_mux(): @@ -350,3 +351,31 @@ class _(m.Circuit): vec = [T() for _ in range(n)] sel = T_sel(0) if sel_bits == 1 else T_sel(0) m.mux(vec, sel) + + +def test_mux_list_lookup_default_none(): + class test_mux_list_lookup_default_none(m.Circuit): + io = m.IO(S=m.In(m.Bits[2]), O=m.Out(m.Bits[5])) + + list_ = [ht.BitVector[5](0), ht.BitVector[5](1), ht.BitVector[5](2)] + io.O @= m.list_lookup(list_, io.S) + + m.compile("build/test_mux_list_lookup_default_none", + test_mux_list_lookup_default_none, + output="mlir") + if check_gold(__file__, "test_mux_list_lookup_default_none.mlir"): + return + + tester = fault.Tester(test_mux_list_lookup_default_none) + for i in range(4): + tester.circuit.S = i + tester.eval() + if i < 3: + tester.circuit.O.expect(i) + else: + tester.circuit.O.expect(2) + + tester.compile_and_run("verilator", magma_output="mlir-verilog", + directory=os.path.join(os.path.dirname(__file__), + "build")) + update_gold(__file__, "test_mux_list_lookup_default_none.mlir")