Skip to content

Commit

Permalink
[MLIR] Add elaborate_magma_registers option
Browse files Browse the repository at this point in the history
  • Loading branch information
rsetaluri committed May 3, 2022
1 parent 768a4ca commit 69faecc
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 4 deletions.
1 change: 1 addition & 0 deletions magma/backend/mlir/compile_to_mlir_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ class CompileToMlirOpts:
verilog_prefix: Optional[str] = None
user_namespace: Optional[str] = None
disable_initial_blocks: bool = False
elaborate_magma_registers: bool = False
13 changes: 9 additions & 4 deletions magma/backend/mlir/hardware_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,10 @@ def visit_instance(self, module: ModuleWrapper) -> bool:
inst = module.module
assert isinstance(inst, AnonymousCircuitType)
defn = type(inst)
elaborate_magma_registers = self._ctx.opts.elaborate_magma_registers
if isinstance(defn, Mux):
return self.visit_magma_mux(module)
if isinstance(defn, Register):
if isinstance(defn, Register) and not elaborate_magma_registers:
return self.visit_magma_register(module)
if getattr(defn, "inline_verilog_strs", []):
return self.visit_inline_verilog(module)
Expand Down Expand Up @@ -691,14 +692,18 @@ def visit(self, module: MagmaModuleLike):
self.visit(inst)


def treat_as_primitive(defn_or_decl: CircuitKind) -> bool:
def treat_as_primitive(
defn_or_decl: CircuitKind,
ctx: 'HardwareModule'
) -> bool:
# NOTE(rsetaluri): This is a round-about way to mark new types as
# primitives. These definitions should actually be marked as primitives.
elaborate_magma_registers = ctx.opts.elaborate_magma_registers
if isprimitive(defn_or_decl):
return True
if isinstance(defn_or_decl, Mux):
return True
if isinstance(defn_or_decl, Register):
if isinstance(defn_or_decl, Register) and not elaborate_magma_registers:
return True
if getattr(defn_or_decl, "inline_verilog_strs", []):
return True
Expand Down Expand Up @@ -878,7 +883,7 @@ def compile(self):
self._add_module_parameters(self._hw_module)

def _compile(self) -> hw.ModuleOpBase:
if treat_as_primitive(self._magma_defn_or_decl):
if treat_as_primitive(self._magma_defn_or_decl, self):
return

def new_values(fn, ports):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
hw.module @Register(%I: !hw.struct<x: i8, y: i1>, %CE: i1, %CLK: i1, %ASYNCRESET: i1) -> (O: !hw.struct<x: i8, y: i1>) {
%1 = comb.extract %0 from 1 : (i9) -> i1
%2 = comb.extract %0 from 2 : (i9) -> i1
%3 = comb.extract %0 from 3 : (i9) -> i1
%4 = comb.extract %0 from 4 : (i9) -> i1
%5 = comb.extract %0 from 5 : (i9) -> i1
%6 = comb.extract %0 from 6 : (i9) -> i1
%7 = comb.extract %0 from 7 : (i9) -> i1
%9 = comb.concat %7, %6, %5, %4, %3, %2, %1, %8 : i1, i1, i1, i1, i1, i1, i1, i1
%10 = comb.extract %0 from 8 : (i9) -> i1
%11 = hw.struct_create (%9, %10) : !hw.struct<x: i8, y: i1>
%13 = hw.array_create %11, %I : !hw.struct<x: i8, y: i1>
%12 = hw.array_get %13[%CE] : !hw.array<2x!hw.struct<x: i8, y: i1>>
%14 = hw.struct_extract %12["x"] : !hw.struct<x: i8, y: i1>
%15 = comb.extract %14 from 0 : (i8) -> i1
%16 = comb.extract %14 from 1 : (i8) -> i1
%17 = comb.extract %14 from 2 : (i8) -> i1
%18 = comb.extract %14 from 3 : (i8) -> i1
%19 = comb.extract %14 from 4 : (i8) -> i1
%20 = comb.extract %14 from 5 : (i8) -> i1
%21 = comb.extract %14 from 6 : (i8) -> i1
%22 = comb.extract %14 from 7 : (i8) -> i1
%23 = hw.struct_extract %12["y"] : !hw.struct<x: i8, y: i1>
%24 = comb.concat %23, %22, %21, %20, %19, %18, %17, %16, %15 : i1, i1, i1, i1, i1, i1, i1, i1, i1
%25 = sv.reg {name = "reg_PR9_inst0"} : !hw.inout<i9>
sv.alwaysff(posedge %CLK) {
sv.passign %25, %24 : i9
} (asyncreset : posedge %ASYNCRESET) {
sv.passign %25, %26 : i9
}
%26 = hw.constant 266 : i9
sv.initial {
sv.bpassign %25, %26 : i9
}
%0 = sv.read_inout %25 : !hw.inout<i9>
%8 = comb.extract %0 from 0 : (i9) -> i1
%27 = comb.concat %7, %6, %5, %4, %3, %2, %1, %8 : i1, i1, i1, i1, i1, i1, i1, i1
%28 = hw.struct_create (%27, %10) : !hw.struct<x: i8, y: i1>
hw.output %28 : !hw.struct<x: i8, y: i1>
}
hw.module @complex_register_wrapper(%a: !hw.struct<x: i8, y: i1>, %b: !hw.array<6xi16>, %CLK: i1, %CE: i1, %ASYNCRESET: i1) -> (y: !hw.struct<u: !hw.struct<x: i8, y: i1>, v: !hw.array<6xi16>>) {
%0 = hw.instance "Register_inst0" @Register(I: %a: !hw.struct<x: i8, y: i1>, CE: %CE: i1, CLK: %CLK: i1, ASYNCRESET: %ASYNCRESET: i1) -> (O: !hw.struct<x: i8, y: i1>)
%1 = hw.instance "Register_inst1" @Register(I: %b: !hw.array<6xi16>, CE: %CLK: i1) -> (O: !hw.array<6xi16>)
%2 = hw.struct_create (%0, %1) : !hw.struct<u: !hw.struct<x: i8, y: i1>, v: !hw.array<6xi16>>
%3 = hw.struct_extract %a["x"] : !hw.struct<x: i8, y: i1>
%4 = hw.instance "Register_inst2" @Register(I: %3: i8, CE: %CE: i1, CLK: %CLK: i1) -> (O: i8)
hw.output %2 : !hw.struct<u: !hw.struct<x: i8, y: i1>, v: !hw.array<6xi16>>
}
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
hw.module @Register(%I: i8, %CLK: i1) -> (O: i8) {
%1 = sv.reg {name = "reg_P8_inst0"} : !hw.inout<i8>
sv.alwaysff(posedge %CLK) {
sv.passign %1, %I : i8
}
%2 = hw.constant 3 : i8
sv.initial {
sv.bpassign %1, %2 : i8
}
%0 = sv.read_inout %1 : !hw.inout<i8>
hw.output %0 : i8
}
hw.module @simple_register_wrapper(%a: i8, %CLK: i1) -> (y: i8) {
%0 = hw.instance "reg0" @Register(I: %a: i8, CLK: %CLK: i1) -> (O: i8)
hw.output %0 : i8
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module Register( // <stdin>:1:1
input [7:0] I,
input CLK,
output [7:0] O);

reg [7:0] reg_P8_inst0; // <stdin>:2:10

always_ff @(posedge CLK) // <stdin>:3:5
reg_P8_inst0 <= I; // <stdin>:4:9
initial // <stdin>:7:5
reg_P8_inst0 = 8'h3; // <stdin>:6:10, :8:9
assign O = reg_P8_inst0; // <stdin>:10:10, :11:5
endmodule

module simple_register_wrapper( // <stdin>:13:1
input [7:0] a,
input CLK,
output [7:0] y);

Register reg0 ( // <stdin>:14:10
.I (a),
.CLK (CLK),
.O (y)
);
endmodule

Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,18 @@ def test_compile_to_mlir_disable_initial_blocks(ckt):
"gold_name": f"{ckt.name}_disable_initial_blocks",
}
run_test_compile_to_mlir(ckt, **kwargs)


@pytest.mark.parametrize(
"ckt",
[
simple_register_wrapper,
complex_register_wrapper,
]
)
def test_compile_to_mlir_elaborate_magma_registers(ckt):
kwargs = {
"elaborate_magma_registers": True,
"gold_name": f"{ckt.name}_elaborate_magma_registers"
}
run_test_compile_to_mlir(ckt, **kwargs)

0 comments on commit 69faecc

Please sign in to comment.