From 41841ebb862215d310712803e53e1791146e141c Mon Sep 17 00:00:00 2001 From: wenhu1024 <96782898+wenhu1024@users.noreply.github.com> Date: Wed, 31 Jul 2024 13:52:11 +0800 Subject: [PATCH 1/6] [ImportVerilog] Support for String Types, String Literals (#7403) Co-authored-by: itaras20 --- include/circt/Dialect/Moore/MooreOps.td | 15 +++++++++++++++ lib/Conversion/ImportVerilog/Expressions.cpp | 6 ++++++ lib/Conversion/ImportVerilog/Types.cpp | 4 ++++ test/Conversion/ImportVerilog/basic.sv | 13 +++++++++++++ test/Conversion/ImportVerilog/types.sv | 6 ++++++ 5 files changed, 44 insertions(+) diff --git a/include/circt/Dialect/Moore/MooreOps.td b/include/circt/Dialect/Moore/MooreOps.td index 6d494ff88f81..b4af35ad13f5 100644 --- a/include/circt/Dialect/Moore/MooreOps.td +++ b/include/circt/Dialect/Moore/MooreOps.td @@ -454,6 +454,21 @@ def NamedConstantOp : MooreOp<"named_constant", [ }]; } +def StringConstantOp : MooreOp<"string_constant", [Pure, ConstantLike]> { + let summary = "Produce a constant string value"; + let description = [{ + Produces a constant value of string type. + + Example: + ```mlir + %0 = moore.string_constant "hello world" : i88 + ``` + }]; + let arguments = (ins StrAttr:$value); + let results = (outs IntType:$result); + let assemblyFormat = "$value attr-dict `:` type($result)"; +} + def ConversionOp : MooreOp<"conversion", [Pure]> { let summary = "A type conversion"; let description = [{ diff --git a/lib/Conversion/ImportVerilog/Expressions.cpp b/lib/Conversion/ImportVerilog/Expressions.cpp index acf6959d661a..48b268bb7649 100644 --- a/lib/Conversion/ImportVerilog/Expressions.cpp +++ b/lib/Conversion/ImportVerilog/Expressions.cpp @@ -712,6 +712,12 @@ struct RvalueExprVisitor { return {}; } + /// Handle string literals. + Value visit(const slang::ast::StringLiteral &expr) { + auto type = context.convertType(*expr.type); + return builder.create(loc, type, expr.getValue()); + } + /// Emit an error for all other expressions. 
template Value visit(T &&node) { diff --git a/lib/Conversion/ImportVerilog/Types.cpp b/lib/Conversion/ImportVerilog/Types.cpp index 8ea0bcd5e490..0e99577fde88 100644 --- a/lib/Conversion/ImportVerilog/Types.cpp +++ b/lib/Conversion/ImportVerilog/Types.cpp @@ -147,6 +147,10 @@ struct TypeVisitor { return moore::UnpackedUnionType::get(context.getContext(), members); } + Type visit(const slang::ast::StringType &type) { + return moore::StringType::get(context.getContext()); + } + /// Emit an error for all other types. template Type visit(T &&node) { diff --git a/test/Conversion/ImportVerilog/basic.sv b/test/Conversion/ImportVerilog/basic.sv index c038dec746ff..8a795d8cbeab 100644 --- a/test/Conversion/ImportVerilog/basic.sv +++ b/test/Conversion/ImportVerilog/basic.sv @@ -201,6 +201,19 @@ module Basic; // CHECK: %ev2 = moore.variable [[VARIANT_B]] MyEnum ev1 = VariantA; MyEnum ev2 = VariantB; + + // CHECK: [[STR_WELCOME:%.+]] = moore.string_constant "Welcome to Moore" : i128 + // CHECK: [[CONV_WELCOME:%.+]] = moore.conversion [[STR_WELCOME]] : !moore.i128 -> !moore.string + // CHECK: [[VAR_S:%.+]] = moore.variable [[CONV_WELCOME]] : + string s = "Welcome to Moore"; + + // CHECK: [[VAR_S1:%.+]] = moore.variable : + // CHECK: [[STR_HELLO:%.+]] = moore.string_constant "Hello World" : i88 + // CHECK: [[CONV_HELLO:%.+]] = moore.conversion [[STR_HELLO]] : !moore.i88 -> !moore.string + // CHECK: moore.assign [[VAR_S1]], [[CONV_HELLO]] : string + string s1; + assign s1 = "Hello World"; + endmodule // CHECK-LABEL: moore.module @Statements diff --git a/test/Conversion/ImportVerilog/types.sv b/test/Conversion/ImportVerilog/types.sv index daa4db7e2d74..10f201ca5540 100644 --- a/test/Conversion/ImportVerilog/types.sv +++ b/test/Conversion/ImportVerilog/types.sv @@ -146,3 +146,9 @@ module Typedefs; myType1 v0; myType2 v1; endmodule + +// CHECK-LABEL: moore.module @String +module String; + // CHECK-NEXT: %s = moore.variable : + string s; +endmodule From 
05136f0cac997b2395b01d12c22b7849c6b79ee2 Mon Sep 17 00:00:00 2001 From: John Demme Date: Wed, 31 Jul 2024 05:42:57 -0700 Subject: [PATCH 2/6] [ESI] MMIO: add read/write port so service (#7407) Replaces write with read_write. --- frontends/PyCDE/integration_test/esi_test.py | 39 +++++++- .../test_software/esi_test.py | 20 ++++ frontends/PyCDE/src/pycde/bsp/common.py | 92 ++++++++++--------- frontends/PyCDE/src/pycde/bsp/cosim.py | 6 +- frontends/PyCDE/src/pycde/esi.py | 12 ++- lib/Dialect/ESI/ESIStdServices.cpp | 22 ++++- .../ESI/runtime/cpp/lib/backends/Cosim.cpp | 54 +++++++---- test/Dialect/ESI/services.mlir | 8 +- 8 files changed, 178 insertions(+), 75 deletions(-) diff --git a/frontends/PyCDE/integration_test/esi_test.py b/frontends/PyCDE/integration_test/esi_test.py index 48966add3e68..2c78c15a41c1 100644 --- a/frontends/PyCDE/integration_test/esi_test.py +++ b/frontends/PyCDE/integration_test/esi_test.py @@ -7,9 +7,10 @@ import pycde from pycde import (AppID, Clock, Module, Reset, modparams, generator) from pycde.bsp import cosim -from pycde.constructs import Wire -from pycde.esi import FuncService, MMIO +from pycde.constructs import Reg, Wire +from pycde.esi import FuncService, MMIO, MMIOReadWriteCmdType from pycde.types import (Bits, Channel, UInt) +from pycde.behavioral import If, Else, EndIf import sys @@ -48,7 +49,7 @@ def build(ports): address_chan_wire = Wire(Channel(UInt(32))) address, address_valid = address_chan_wire.unwrap(1) - response_data = (address.as_uint() + add_amt).as_bits(64) + response_data = (address + add_amt).as_bits(64) response_chan, response_ready = Channel(Bits(64)).wrap( response_data, address_valid) @@ -58,6 +59,37 @@ def build(ports): return MMIOClient +class MMIOReadWriteClient(Module): + clk = Clock() + rst = Reset() + + @generator + def build(ports): + mmio_read_write_bundle = MMIO.read_write(appid=AppID("mmio_rw_client")) + + cmd_chan_wire = Wire(Channel(MMIOReadWriteCmdType)) + resp_ready_wire = Wire(Bits(1)) + cmd, 
cmd_valid = cmd_chan_wire.unwrap(resp_ready_wire) + + add_amt = Reg(UInt(64), + clk=ports.clk, + rst=ports.rst, + rst_value=0, + ce=cmd_valid & cmd.write & (cmd.offset == 0x8).as_bits()) + add_amt.assign(cmd.data.as_uint()) + with If(cmd.write): + response_data = Bits(64)(0) + with Else(): + response_data = (cmd.offset + add_amt).as_bits(64) + EndIf() + response_chan, response_ready = Channel(Bits(64)).wrap( + response_data, cmd_valid) + resp_ready_wire.assign(response_ready) + + cmd_chan = mmio_read_write_bundle.unpack(data=response_chan)['cmd'] + cmd_chan_wire.assign(cmd_chan) + + class Top(Module): clk = Clock() rst = Reset() @@ -67,6 +99,7 @@ def construct(ports): LoopbackInOutAdd7(clk=ports.clk, rst=ports.rst) for i in range(4, 18, 5): MMIOClient(i)() + MMIOReadWriteClient(clk=ports.clk, rst=ports.rst) if __name__ == "__main__": diff --git a/frontends/PyCDE/integration_test/test_software/esi_test.py b/frontends/PyCDE/integration_test/test_software/esi_test.py index b9377e43b747..f264cf923e06 100644 --- a/frontends/PyCDE/integration_test/test_software/esi_test.py +++ b/frontends/PyCDE/integration_test/test_software/esi_test.py @@ -37,6 +37,26 @@ def read_offset(mmio_offset: int, offset: int, add_amt: int): read_offset(mmio_client_14_offset, 0, 14) read_offset(mmio_client_14_offset, 13, 14) +################################################################################ +# MMIOReadWriteClient tests +################################################################################ + +mmio_rw_client_offset = 262144 + + +def read_offset_check(i: int, add_amt: int): + d = mmio.read(mmio_rw_client_offset + i) + if d == i + 9: + print(f"PASS: read_offset_check({mmio_rw_client_offset} + {i}: {d}") + else: + assert False, f": read_offset_check({mmio_rw_client_offset} + {i}: {d}" + + +mmio.write(mmio_rw_client_offset + 8, 9) +read_offset_check(0, 9) +read_offset_check(12, 9) +read_offset_check(0x1400, 9) + 
################################################################################ # Manifest tests ################################################################################ diff --git a/frontends/PyCDE/src/pycde/bsp/common.py b/frontends/PyCDE/src/pycde/bsp/common.py index 5a03cd148acf..9373bee5738f 100644 --- a/frontends/PyCDE/src/pycde/bsp/common.py +++ b/frontends/PyCDE/src/pycde/bsp/common.py @@ -117,7 +117,7 @@ class ChannelMMIO(esi.ServiceImplementation): clk = Clock() rst = Input(Bits(1)) - read = Input(esi.MMIO.read.type) + cmd = Input(esi.MMIOReadWriteCmdType) # Amount of register space each client gets. This is a GIANT HACK and needs to # be replaced by parameterizable services. @@ -140,70 +140,75 @@ class ChannelMMIO(esi.ServiceImplementation): @generator def generate(ports, bundles: esi._ServiceGeneratorBundles): - read_table, write_table, manifest_loc = ChannelMMIO.build_table( - ports, bundles) - ChannelMMIO.build_read(ports, manifest_loc, read_table) - ChannelMMIO.build_write(ports, write_table) + table, manifest_loc = ChannelMMIO.build_table(bundles) + ChannelMMIO.build_read(ports, manifest_loc, table) return True @staticmethod - def build_table( - ports, bundles - ) -> Tuple[Dict[int, AssignableSignal], Dict[int, AssignableSignal], int]: + def build_table(bundles) -> Tuple[Dict[int, AssignableSignal], int]: """Build a table of read and write addresses to BundleSignals.""" offset = ChannelMMIO.initial_offset - read_table: Dict[int, AssignableSignal] = {} - write_table: Dict[int, AssignableSignal] = {} + table: Dict[int, AssignableSignal] = {} for bundle in bundles.to_client_reqs: if bundle.port == 'read': - read_table[offset] = bundle - bundle.add_record({"offset": offset}) + table[offset] = bundle + bundle.add_record({"offset": offset, "type": "ro"}) + offset += ChannelMMIO.RegisterSpace + elif bundle.port == 'read_write': + table[offset] = bundle + bundle.add_record({"offset": offset, "type": "rw"}) offset += ChannelMMIO.RegisterSpace else: 
assert False, "Unrecognized port name." manifest_loc = offset - return read_table, write_table, manifest_loc + return table, manifest_loc @staticmethod - def build_read(ports, manifest_loc: int, read_table: Dict[int, - AssignableSignal]): + def build_read(ports, manifest_loc: int, table: Dict[int, AssignableSignal]): """Builds the read side of the MMIO service.""" # Instantiate the header and manifest ROM. Fill in the read_table with # bundle wires to be assigned identically to the other MMIO clients. header_bundle_wire = Wire(esi.MMIO.read.type) - read_table[0] = header_bundle_wire + table[0] = header_bundle_wire HeaderMMIO(manifest_loc)(clk=ports.clk, rst=ports.rst, read=header_bundle_wire) mani_bundle_wire = Wire(esi.MMIO.read.type) - read_table[manifest_loc] = mani_bundle_wire + table[manifest_loc] = mani_bundle_wire ESI_Manifest_ROM_Wrapper(clk=ports.clk, read=mani_bundle_wire) - # Unpack the read bundle. + # Unpack the cmd bundle. data_resp_channel = Wire(Channel(esi.MMIODataType)) counted_output = Wire(Channel(esi.MMIODataType)) - read_addr_channel = ports.read.unpack(data=counted_output)["offset"] + cmd_channel = ports.cmd.unpack(data=counted_output)["cmd"] counted_output.assign(data_resp_channel) # Get the selection index and the address to hand off to the clients. - sel_bits, client_address_chan = ChannelMMIO.build_addr_read( - read_addr_channel) + sel_bits, client_cmd_chan = ChannelMMIO.build_addr_read(cmd_channel) # Build the demux/mux and assign the results of each appropriately. 
- read_clients_clog2 = clog2(len(read_table)) - client_addr_channels = esi.ChannelDemux( + read_clients_clog2 = clog2(len(table)) + client_cmd_channels = esi.ChannelDemux( sel=sel_bits.pad_or_truncate(read_clients_clog2), - input=client_address_chan, - num_outs=len(read_table)) + input=client_cmd_chan, + num_outs=len(table)) client_data_channels = [] - for (idx, offset) in enumerate(sorted(read_table.keys())): - bundle, bundle_froms = esi.MMIO.read.type.pack( - offset=client_addr_channels[idx]) + for (idx, offset) in enumerate(sorted(table.keys())): + bundle_wire = table[offset] + bundle_type = bundle_wire.type + if bundle_type == esi.MMIO.read.type: + offset = client_cmd_channels[idx].transform(lambda cmd: cmd.offset) + bundle, bundle_froms = esi.MMIO.read.type.pack(offset=offset) + elif bundle_type == esi.MMIO.read_write.type: + bundle, bundle_froms = esi.MMIO.read_write.type.pack( + cmd=client_cmd_channels[idx]) + else: + assert False, "Unrecognized bundle type." + bundle_wire.assign(bundle) client_data_channels.append(bundle_froms["data"]) - read_table[offset].assign(bundle) resp_channel = esi.ChannelMux(client_data_channels) data_resp_channel.assign(resp_channel) @@ -218,18 +223,21 @@ def build_addr_read( # change to support more flexibility in addressing. Not clear if what we're # doing now it sufficient or not. 
- addr_ready_wire = Wire(Bits(1)) - addr, addr_valid = read_addr_chan.unwrap(addr_ready_wire) - addr = addr.as_bits() + cmd_ready_wire = Wire(Bits(1)) + cmd, cmd_valid = read_addr_chan.unwrap(cmd_ready_wire) sel_bits = NamedWire(Bits(32 - ChannelMMIO.RegisterSpaceBits), "sel_bits") - sel_bits.assign(addr[ChannelMMIO.RegisterSpaceBits:]) - client_addr = NamedWire(Bits(32), "client_addr") - client_addr.assign(addr & Bits(32)(ChannelMMIO.AddressMask)) - client_addr_chan, client_addr_ready = Channel(UInt(32)).wrap( - client_addr.as_uint(), addr_valid) - addr_ready_wire.assign(client_addr_ready) + sel_bits.assign(cmd.offset.as_bits()[ChannelMMIO.RegisterSpaceBits:]) + client_cmd = NamedWire(esi.MMIOReadWriteCmdType, "client_cmd") + client_cmd.assign( + esi.MMIOReadWriteCmdType({ + "write": + cmd.write, + "offset": (cmd.offset.as_bits() & + Bits(32)(ChannelMMIO.AddressMask)).as_uint(), + "data": + cmd.data + })) + client_addr_chan, client_addr_ready = Channel( + esi.MMIOReadWriteCmdType).wrap(client_cmd, cmd_valid) + cmd_ready_wire.assign(client_addr_ready) return sel_bits, client_addr_chan - - def build_write(self, bundles): - # TODO: this. 
- pass diff --git a/frontends/PyCDE/src/pycde/bsp/cosim.py b/frontends/PyCDE/src/pycde/bsp/cosim.py index fa3eb07ada8a..62360afa7d83 100644 --- a/frontends/PyCDE/src/pycde/bsp/cosim.py +++ b/frontends/PyCDE/src/pycde/bsp/cosim.py @@ -38,13 +38,13 @@ class ESI_Cosim_UserTopWrapper(Module): def build(ports): user_module(clk=ports.clk, rst=ports.rst) - mmio_read = esi.FuncService.get_coerced(esi.AppID("__cosim_mmio_read"), - esi.MMIO.read.type) + mmio_read_write = esi.FuncService.get_coerced( + esi.AppID("__cosim_mmio_read_write"), esi.MMIO.read_write.type) ChannelMMIO(esi.MMIO, appid=esi.AppID("__cosim_mmio"), clk=ports.clk, rst=ports.rst, - read=mmio_read) + cmd=mmio_read_write) class ESI_Cosim_Top(Module): clk = Clock() diff --git a/frontends/PyCDE/src/pycde/esi.py b/frontends/PyCDE/src/pycde/esi.py index f6f3873f9dc7..170a7b4b8c8e 100644 --- a/frontends/PyCDE/src/pycde/esi.py +++ b/frontends/PyCDE/src/pycde/esi.py @@ -10,7 +10,7 @@ from .support import get_user_loc from .system import System from .types import (Bits, Bundle, BundledChannel, Channel, ChannelDirection, - Type, UInt, types, _FromCirctType) + StructType, Type, UInt, types, _FromCirctType) from .circt import ir from .circt.dialects import esi as raw_esi, hw, msft @@ -469,6 +469,11 @@ def param(name: str, type: Type = None): MMIODataType = Bits(64) +MMIOReadWriteCmdType = StructType([ + ("write", Bits(1)), + ("offset", UInt(32)), + ("data", MMIODataType), +]) @ServiceDecl @@ -482,6 +487,11 @@ class MMIO: BundledChannel("data", ChannelDirection.FROM, MMIODataType) ]) + read_write = Bundle([ + BundledChannel("cmd", ChannelDirection.TO, MMIOReadWriteCmdType), + BundledChannel("data", ChannelDirection.FROM, MMIODataType) + ]) + @staticmethod def _op(sym_name: ir.StringAttr): return raw_esi.MMIOServiceDeclOp(sym_name) diff --git a/lib/Dialect/ESI/ESIStdServices.cpp b/lib/Dialect/ESI/ESIStdServices.cpp index 7bb11e259a74..c5b4dba24742 100644 --- a/lib/Dialect/ESI/ESIStdServices.cpp +++ 
b/lib/Dialect/ESI/ESIStdServices.cpp @@ -108,14 +108,26 @@ void MMIOServiceDeclOp::getPortList(SmallVectorImpl &ports) { BundledChannel{StringAttr::get(ctxt, "data"), ChannelDirection::from, ChannelType::get(ctxt, IntegerType::get(ctxt, 64))}}, /*resettable=*/UnitAttr())}); - // Write only port. + // Read-write port. + auto cmdType = hw::StructType::get( + ctxt, { + hw::StructType::FieldInfo{StringAttr::get(ctxt, "write"), + IntegerType::get(ctxt, 1)}, + hw::StructType::FieldInfo{ + StringAttr::get(ctxt, "offset"), + IntegerType::get( + ctxt, 32, IntegerType::SignednessSemantics::Unsigned)}, + hw::StructType::FieldInfo{StringAttr::get(ctxt, "data"), + IntegerType::get(ctxt, 64)}, + }); ports.push_back(ServicePortInfo{ - hw::InnerRefAttr::get(getSymNameAttr(), StringAttr::get(ctxt, "write")), + hw::InnerRefAttr::get(getSymNameAttr(), + StringAttr::get(ctxt, "read_write")), ChannelBundleType::get( ctxt, - {BundledChannel{StringAttr::get(ctxt, "offset"), ChannelDirection::to, - ChannelType::get(ctxt, IntegerType::get(ctxt, 32))}, - BundledChannel{StringAttr::get(ctxt, "data"), ChannelDirection::to, + {BundledChannel{StringAttr::get(ctxt, "cmd"), ChannelDirection::to, + ChannelType::get(ctxt, cmdType)}, + BundledChannel{StringAttr::get(ctxt, "data"), ChannelDirection::from, ChannelType::get(ctxt, IntegerType::get(ctxt, 64))}}, /*resettable=*/UnitAttr())}); } diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp index 2533f68213b8..86a4685fd58b 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp @@ -232,7 +232,7 @@ class ReadCosimChannelPort getType()->getID() + ", got " + desc.type()); if (desc.dir() != ChannelDesc::Direction::ChannelDesc_Direction_TO_CLIENT) throw std::runtime_error("Channel '" + name + - "' is not a to server channel"); + "' is not a to client channel"); assert(desc.name() == name); // Initiate a stream of messages from 
the server. @@ -349,35 +349,49 @@ class CosimMMIO : public MMIO { CosimMMIO(Context &ctxt, StubContainer *rpcClient) { // We have to locate the channels ourselves since this service might be used // to retrieve the manifest. - ChannelDesc readArg, readResp; - if (!rpcClient->getChannelDesc("__cosim_mmio_read.arg", readArg) || - !rpcClient->getChannelDesc("__cosim_mmio_read.result", readResp)) + ChannelDesc cmdArg, cmdResp; + if (!rpcClient->getChannelDesc("__cosim_mmio_read_write.arg", cmdArg) || + !rpcClient->getChannelDesc("__cosim_mmio_read_write.result", cmdResp)) throw std::runtime_error("Could not find MMIO channels"); - const esi::Type *i32Type = getType(ctxt, new UIntType(readArg.type(), 32)); - const esi::Type *i64Type = getType(ctxt, new UIntType(readResp.type(), 64)); + const esi::Type *i64Type = getType(ctxt, new UIntType(cmdResp.type(), 64)); + const esi::Type *cmdType = + getType(ctxt, new StructType(cmdArg.type(), + {{"write", new BitsType("i1", 1)}, + {"offset", new UIntType("ui32", 32)}, + {"data", new BitsType("i64", 64)}})); // Get ports, create the function, then connect to it. - readArgPort = std::make_unique( - rpcClient->stub.get(), readArg, i32Type, "__cosim_mmio_read.arg"); - readRespPort = std::make_unique( - rpcClient->stub.get(), readResp, i64Type, "__cosim_mmio_read.result"); - readMMIO.reset(FuncService::Function::get(AppID("__cosim_mmio_read"), - *readArgPort, *readRespPort)); - readMMIO->connect(); + cmdArgPort = std::make_unique( + rpcClient->stub.get(), cmdArg, cmdType, "__cosim_mmio_read_write.arg"); + cmdRespPort = std::make_unique( + rpcClient->stub.get(), cmdResp, i64Type, + "__cosim_mmio_read_write.result"); + cmdMMIO.reset(FuncService::Function::get(AppID("__cosim_mmio"), *cmdArgPort, + *cmdRespPort)); + cmdMMIO->connect(); } + struct MMIOCmd { + uint64_t data; + uint32_t offset; + bool write; + } __attribute__((packed)); + // Call the read function and wait for a response. 
uint64_t read(uint32_t addr) const override { - auto arg = MessageData::from(addr); - std::future result = readMMIO->call(arg); + MMIOCmd cmd{.offset = addr, .write = false}; + auto arg = MessageData::from(cmd); + std::future result = cmdMMIO->call(arg); result.wait(); return *result.get().as(); } void write(uint32_t addr, uint64_t data) override { - // TODO: this. - throw std::runtime_error("Cosim MMIO write not implemented"); + MMIOCmd cmd{.data = data, .offset = addr, .write = true}; + auto arg = MessageData::from(cmd); + std::future result = cmdMMIO->call(arg); + result.wait(); } private: @@ -389,9 +403,9 @@ class CosimMMIO : public MMIO { ctxt.registerType(type); return type; } - std::unique_ptr readArgPort; - std::unique_ptr readRespPort; - std::unique_ptr readMMIO; + std::unique_ptr cmdArgPort; + std::unique_ptr cmdRespPort; + std::unique_ptr cmdMMIO; }; class CosimHostMem : public HostMem { diff --git a/test/Dialect/ESI/services.mlir b/test/Dialect/ESI/services.mlir index bc432cf64f5b..719116eba911 100644 --- a/test/Dialect/ESI/services.mlir +++ b/test/Dialect/ESI/services.mlir @@ -206,17 +206,23 @@ hw.module @CallableAccel1(in %clk: !seq.clock, in %rst: i1) { esi.service.std.mmio @mmio !mmioReq = !esi.bundle<[!esi.channel to "offset", !esi.channel from "data"]> +!mmioRWReq = !esi.bundle<[!esi.channel> to "cmd", !esi.channel from "data"]> -// CONN-LABEL: hw.module @MMIOManifest(in %clk : !seq.clock, in %rst : i1, in %manifest : !esi.bundle<[!esi.channel to "offset", !esi.channel from "data"]>) { +// CONN-LABEL: hw.module @MMIOManifest(in %clk : !seq.clock, in %rst : i1, in %manifest : !esi.bundle<[!esi.channel to "offset", !esi.channel from "data"]>, in %manifestRW : !esi.bundle<[!esi.channel> to "cmd", !esi.channel from "data"]>) { // CONN-NEXT: %true = hw.constant true // CONN-NEXT: %c0_i64 = hw.constant 0 : i64 // CONN-NEXT: esi.manifest.req #esi.appid<"manifest">, <@mmio::@read> std "esi.service.std.mmio", !esi.bundle<[!esi.channel to "offset", 
!esi.channel from "data"]> // CONN-NEXT: %chanOutput, %ready = esi.wrap.vr %c0_i64, %true : i64 // CONN-NEXT: %offset = esi.bundle.unpack %chanOutput from %manifest : !esi.bundle<[!esi.channel to "offset", !esi.channel from "data"]> +// CONN-NEXT: esi.manifest.req #esi.appid<"manifestRW">, <@mmio::@read_write> std "esi.service.std.mmio", !esi.bundle<[!esi.channel> to "cmd", !esi.channel from "data"]> hw.module @MMIOManifest(in %clk: !seq.clock, in %rst: i1) { %req = esi.service.req <@mmio::@read> (#esi.appid<"manifest">) : !mmioReq %data = hw.constant 0 : i64 %valid = hw.constant 1 : i1 %data_ch, %ready = esi.wrap.vr %data, %valid : i64 %addr = esi.bundle.unpack %data_ch from %req : !mmioReq + + %reqRW = esi.service.req <@mmio::@read_write> (#esi.appid<"manifestRW">) : !mmioRWReq + %dataChannel, %dataChannelReady = esi.wrap.vr %data, %valid: i64 + %cmdChannel = esi.bundle.unpack %dataChannel from %reqRW : !mmioRWReq } From 3c12682a1aff12249d89fb01cfa0554c0ea28f1b Mon Sep 17 00:00:00 2001 From: Jiahan Xie <88367305+jiahanxie353@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:47:59 -0400 Subject: [PATCH 3/6] Support `scf.if` Op Lowering to Calyx (#6256) * support lowering scf if op and add a corresponding test --- lib/Conversion/SCFToCalyx/SCFToCalyx.cpp | 234 ++++++++++++++++-- .../SCFToCalyx/convert_controlflow.mlir | 69 ++++++ 2 files changed, 284 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index c780492bc36a..7411c3a4dd3b 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -95,6 +95,10 @@ class ScfForOp : public calyx::RepeatOpInterface { // Lowering state classes //===----------------------------------------------------------------------===// +struct IfScheduleable { + scf::IfOp ifOp; +}; + struct WhileScheduleable { /// While operation to schedule. 
ScfWhileOp whileOp; @@ -115,8 +119,63 @@ struct CallScheduleable { }; /// A variant of types representing scheduleable operations. -using Scheduleable = std::variant; +using Scheduleable = + std::variant; + +class IfLoweringStateInterface { +public: + void setThenGroup(scf::IfOp op, calyx::GroupOp group) { + Operation *operation = op.getOperation(); + assert(thenGroup.count(operation) == 0 && + "A then group was already set for this scf::IfOp!\n"); + thenGroup[operation] = group; + } + + calyx::GroupOp getThenGroup(scf::IfOp op) { + auto it = thenGroup.find(op.getOperation()); + assert(it != thenGroup.end() && + "No then group was set for this scf::IfOp!\n"); + return it->second; + } + + void setElseGroup(scf::IfOp op, calyx::GroupOp group) { + Operation *operation = op.getOperation(); + assert(elseGroup.count(operation) == 0 && + "An else group was already set for this scf::IfOp!\n"); + elseGroup[operation] = group; + } + + calyx::GroupOp getElseGroup(scf::IfOp op) { + auto it = elseGroup.find(op.getOperation()); + assert(it != elseGroup.end() && + "No else group was set for this scf::IfOp!\n"); + return it->second; + } + + void setResultRegs(scf::IfOp op, calyx::RegisterOp reg, unsigned idx) { + assert(resultRegs[op.getOperation()].count(idx) == 0 && + "A register was already registered for the given yield result.\n"); + assert(idx < op->getNumOperands()); + resultRegs[op.getOperation()][idx] = reg; + } + + const DenseMap &getResultRegs(scf::IfOp op) { + return resultRegs[op.getOperation()]; + } + + calyx::RegisterOp getResultRegs(scf::IfOp op, unsigned idx) { + auto regs = getResultRegs(op); + auto it = regs.find(idx); + assert(it != regs.end() && "resultReg not found"); + return it->second; + } + +private: + DenseMap thenGroup; + DenseMap elseGroup; + DenseMap> resultRegs; +}; class WhileLoopLoweringStateInterface : calyx::LoopLoweringStateInterface { @@ -187,6 +246,7 @@ class ForLoopLoweringStateInterface class ComponentLoweringState : public 
calyx::ComponentLoweringStateInterface, public WhileLoopLoweringStateInterface, public ForLoopLoweringStateInterface, + public IfLoweringStateInterface, public calyx::SchedulerInterface { public: ComponentLoweringState(calyx::ComponentOp component) @@ -213,7 +273,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { TypeSwitch(_op) .template Case(yieldOp->getParentOp()); - if (!whileOp) { - return yieldOp.getOperation()->emitError() - << "Currently only support yield operations inside for and while " - "loops."; - } - ScfWhileOp whileOpInterface(whileOp); - - auto assignGroup = - getState().buildWhileLoopIterArgAssignments( - rewriter, whileOpInterface, - getState().getComponentOp(), - getState().getUniqueName(whileOp) + "_latch", - yieldOp->getOpOperands()); - getState().setWhileLoopLatchGroup(whileOpInterface, - assignGroup); + if (auto whileOp = dyn_cast(yieldOp->getParentOp())) { + ScfWhileOp whileOpInterface(whileOp); + + auto assignGroup = + getState().buildWhileLoopIterArgAssignments( + rewriter, whileOpInterface, + getState().getComponentOp(), + getState().getUniqueName(whileOp) + + "_latch", + yieldOp->getOpOperands()); + getState().setWhileLoopLatchGroup(whileOpInterface, + assignGroup); + return success(); + } + + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + auto resultRegs = getState().getResultRegs(ifOp); + + if (yieldOp->getParentRegion() == &ifOp.getThenRegion()) { + auto thenGroup = getState().getThenGroup(ifOp); + for (auto op : enumerate(yieldOp.getOperands())) { + auto resultReg = + getState().getResultRegs(ifOp, op.index()); + buildAssignmentsForRegisterWrite( + rewriter, thenGroup, + getState().getComponentOp(), resultReg, + op.value()); + getState().registerEvaluatingGroup( + ifOp.getResult(op.index()), thenGroup); + } + } + + if (!ifOp.getElseRegion().empty() && + (yieldOp->getParentRegion() == &ifOp.getElseRegion())) { + auto elseGroup = getState().getElseGroup(ifOp); + for (auto op : 
enumerate(yieldOp.getOperands())) { + auto resultReg = + getState().getResultRegs(ifOp, op.index()); + buildAssignmentsForRegisterWrite( + rewriter, elseGroup, + getState().getComponentOp(), resultReg, + op.value()); + getState().registerEvaluatingGroup( + ifOp.getResult(op.index()), elseGroup); + } + } + } return success(); } @@ -945,6 +1037,13 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, return success(); } +LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, + scf::IfOp ifOp) const { + getState().addBlockScheduleable( + ifOp.getOperation()->getBlock(), IfScheduleable{ifOp}); + return success(); +} + LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, CallOp callOp) const { std::string instanceName = calyx::getInstanceName(callOp); @@ -1291,6 +1390,51 @@ class BuildForGroups : public calyx::FuncOpPartialLoweringPattern { } }; +class BuildIfGroups : public calyx::FuncOpPartialLoweringPattern { + using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern; + + LogicalResult + partiallyLowerFuncToComp(FuncOp funcOp, + PatternRewriter &rewriter) const override { + LogicalResult res = success(); + funcOp.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + + auto scfIfOp = cast(op); + + calyx::ComponentOp componentOp = + getState().getComponentOp(); + + std::string thenGroupName = + getState().getUniqueName("then_br"); + auto thenGroupOp = calyx::createGroup( + rewriter, componentOp, scfIfOp.getLoc(), thenGroupName); + getState().setThenGroup(scfIfOp, thenGroupOp); + + if (!scfIfOp.getElseRegion().empty()) { + std::string elseGroupName = + getState().getUniqueName("else_br"); + auto elseGroupOp = calyx::createGroup( + rewriter, componentOp, scfIfOp.getLoc(), elseGroupName); + getState().setElseGroup(scfIfOp, elseGroupOp); + } + + for (auto ifOpRes : scfIfOp.getResults()) { + auto reg = createRegister( + scfIfOp.getLoc(), rewriter, getComponent(), + ifOpRes.getType().getIntOrFloatBitWidth(), + 
getState().getUniqueName("if_res")); + getState().setResultRegs( + scfIfOp, reg, ifOpRes.getResultNumber()); + } + + return WalkResult::advance(); + }); + return res; + } +}; + /// Builds a control schedule by traversing the CFG of the function and /// associating this with the previously created groups. /// For simplicity, the generated control flow is expanded for all possible @@ -1384,6 +1528,50 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { forLatchGroup.getName()); if (res.failed()) return res; + } else if (auto *ifSchedPtr = std::get_if(&group); + ifSchedPtr) { + auto ifOp = ifSchedPtr->ifOp; + + Location loc = ifOp->getLoc(); + + auto cond = ifOp.getCondition(); + auto condGroup = getState() + .getEvaluatingGroup(cond); + + auto symbolAttr = FlatSymbolRefAttr::get( + StringAttr::get(getContext(), condGroup.getSymName())); + + bool initElse = !ifOp.getElseRegion().empty(); + auto ifCtrlOp = rewriter.create( + loc, cond, symbolAttr, /*initializeElseBody=*/initElse); + + rewriter.setInsertionPointToEnd(ifCtrlOp.getBodyBlock()); + + auto thenSeqOp = + rewriter.create(ifOp.getThenRegion().getLoc()); + auto *thenSeqOpBlock = thenSeqOp.getBodyBlock(); + + rewriter.setInsertionPointToEnd(thenSeqOpBlock); + + calyx::GroupOp thenGroup = + getState().getThenGroup(ifOp); + rewriter.create(thenGroup.getLoc(), + thenGroup.getName()); + + if (!ifOp.getElseRegion().empty()) { + rewriter.setInsertionPointToEnd(ifCtrlOp.getElseBody()); + + auto elseSeqOp = + rewriter.create(ifOp.getElseRegion().getLoc()); + auto *elseSeqOpBlock = elseSeqOp.getBodyBlock(); + + rewriter.setInsertionPointToEnd(elseSeqOpBlock); + + calyx::GroupOp elseGroup = + getState().getElseGroup(ifOp); + rewriter.create(elseGroup.getLoc(), + elseGroup.getName()); + } } else if (auto *callSchedPtr = std::get_if(&group)) { auto instanceOp = callSchedPtr->instanceOp; OpBuilder::InsertionGuard g(rewriter); @@ -1540,6 +1728,12 @@ class LateSSAReplacement : public 
calyx::FuncOpPartialLoweringPattern { LogicalResult partiallyLowerFuncToComp(FuncOp funcOp, PatternRewriter &) const override { + funcOp.walk([&](scf::IfOp op) { + for (auto res : getState().getResultRegs(op)) + op.getOperation()->getResults()[res.first].replaceAllUsesWith( + res.second.getOut()); + }); + funcOp.walk([&](scf::WhileOp op) { /// The yielded values returned from the while op will be present in the /// iterargs registers post execution of the loop. @@ -1790,6 +1984,8 @@ void SCFToCalyxPass::runOnOperation() { addOncePattern(loweringPatterns, patternState, funcMap, *loweringState); + addOncePattern(loweringPatterns, patternState, funcMap, + *loweringState); /// This pattern converts operations within basic blocks to Calyx library /// operators. Combinational operations are assigned inside a /// calyx::CombGroupOp, and sequential inside calyx::GroupOps. diff --git a/test/Conversion/SCFToCalyx/convert_controlflow.mlir b/test/Conversion/SCFToCalyx/convert_controlflow.mlir index 09e7ef4214fb..d4a87139f621 100644 --- a/test/Conversion/SCFToCalyx/convert_controlflow.mlir +++ b/test/Conversion/SCFToCalyx/convert_controlflow.mlir @@ -572,3 +572,72 @@ module { return } } + +// ----- + +// Test if op with else branch. 
+ +module { +// CHECK-LABEL: calyx.component @main( +// CHECK-SAME: %[[VAL_0:in0]]: i32, +// CHECK-SAME: %[[VAL_1:in1]]: i32, +// CHECK-SAME: %[[VAL_2:.*]]: i1 {clk}, +// CHECK-SAME: %[[VAL_3:.*]]: i1 {reset}, +// CHECK-SAME: %[[VAL_4:.*]]: i1 {go}) -> ( +// CHECK-SAME: %[[VAL_5:out0]]: i32, +// CHECK-SAME: %[[VAL_6:.*]]: i1 {done}) { +// CHECK: %[[VAL_7:.*]] = hw.constant true +// CHECK: %[[VAL_8:.*]], %[[VAL_9:.*]], %[[VAL_10:.*]] = calyx.std_add @std_add_0 : i32, i32, i32 +// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = calyx.std_slt @std_slt_0 : i32, i32, i1 +// CHECK: %[[VAL_14:.*]], %[[VAL_15:.*]], %[[VAL_16:.*]], %[[VAL_17:.*]], %[[VAL_18:.*]], %[[VAL_19:.*]] = calyx.register @if_res_0_reg : i32, i1, i1, i1, i32, i1 +// CHECK: %[[VAL_20:.*]], %[[VAL_21:.*]], %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]] = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1 +// CHECK: calyx.wires { +// CHECK: calyx.assign %[[VAL_5]] = %[[VAL_24]] : i32 +// CHECK: calyx.group @then_br_0 { +// CHECK: calyx.assign %[[VAL_14]] = %[[VAL_10]] : i32 +// CHECK: calyx.assign %[[VAL_15]] = %[[VAL_7]] : i1 +// CHECK: calyx.assign %[[VAL_8]] = %[[VAL_0]] : i32 +// CHECK: calyx.assign %[[VAL_9]] = %[[VAL_1]] : i32 +// CHECK: calyx.group_done %[[VAL_19]] : i1 +// CHECK: } +// CHECK: calyx.group @else_br_0 { +// CHECK: calyx.assign %[[VAL_14]] = %[[VAL_1]] : i32 +// CHECK: calyx.assign %[[VAL_15]] = %[[VAL_7]] : i1 +// CHECK: calyx.group_done %[[VAL_19]] : i1 +// CHECK: } +// CHECK: calyx.comb_group @bb0_0 { +// CHECK: calyx.assign %[[VAL_11]] = %[[VAL_0]] : i32 +// CHECK: calyx.assign %[[VAL_12]] = %[[VAL_1]] : i32 +// CHECK: } +// CHECK: calyx.group @ret_assign_0 { +// CHECK: calyx.assign %[[VAL_20]] = %[[VAL_18]] : i32 +// CHECK: calyx.assign %[[VAL_21]] = %[[VAL_7]] : i1 +// CHECK: calyx.group_done %[[VAL_25]] : i1 +// CHECK: } +// CHECK: } +// CHECK: calyx.control { +// CHECK: calyx.seq { +// CHECK: calyx.if %[[VAL_13]] with @bb0_0 { +// CHECK: calyx.seq 
{ +// CHECK: calyx.enable @then_br_0 +// CHECK: } +// CHECK: } else { +// CHECK: calyx.seq { +// CHECK: calyx.enable @else_br_0 +// CHECK: } +// CHECK: } +// CHECK: calyx.enable @ret_assign_0 +// CHECK: } +// CHECK: } +// CHECK: } {toplevel} + func.func @main(%arg0 : i32, %arg1 : i32) -> i32 { + %0 = arith.cmpi slt, %arg0, %arg1 : i32 + %1 = scf.if %0 -> i32 { + %3 = arith.addi %arg0, %arg1 : i32 + scf.yield %3 : i32 + } else { + scf.yield %arg1 : i32 + } + return %1 : i32 + } +} From 17c036f87c4c9e103160b207e18ad595911945e8 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Wed, 31 Jul 2024 12:19:12 -0600 Subject: [PATCH 4/6] [OM] Pass Python values back and forth, not Attributes. (#7417) Internally, primitive OM EvaluatorValues are represented as TypedAttributes. This is great internally, but when we pass these from C++ out to Python, we have to use a very inefficient method to pull the Python value out of the attribute. This updates how primitives are handled at the Python <> C++ interface to directly construct the appropriate Python values and return them. Similarly, for top-level inputs to the Evaluator, Python values are directly accepted and converted to Attributes internally. On large designs, this was shown to decrease single threaded CPU time to process large amounts of OM data by roughly 70%. There is no difference in the output. 
--- lib/Bindings/Python/OMModule.cpp | 139 ++++++++++++++++++++++++----- lib/Bindings/Python/dialects/om.py | 21 ++--- 2 files changed, 128 insertions(+), 32 deletions(-) diff --git a/lib/Bindings/Python/OMModule.cpp b/lib/Bindings/Python/OMModule.cpp index 5d9e40461862..5122c88e8f44 100644 --- a/lib/Bindings/Python/OMModule.cpp +++ b/lib/Bindings/Python/OMModule.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "DialectModules.h" +#include "circt-c/Dialect/HW.h" #include "circt-c/Dialect/OM.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" @@ -29,18 +30,27 @@ struct Map; struct BasePath; struct Path; -/// None is used to by pybind when default initializing a PythonValue. The order -/// of types in the variant matters here, and we want pybind to try casting to -/// the Python classes defined in this file first, before MlirAttribute and the -/// upstream MLIR type casters. If the MlirAttribute is tried first, then we -/// can hit an assert inside the MLIR codebase. +/// These are the Python types that are represented by the different primitive +/// OMEvaluatorValues as Attributes. +using PythonPrimitive = std::variant; + +/// None is used to by pybind when default initializing a PythonValue. The +/// order of types in the variant matters here, and we want pybind to try +/// casting to the Python classes defined in this file first, before +/// MlirAttribute and the upstream MLIR type casters. If the MlirAttribute +/// is tried first, then we can hit an assert inside the MLIR codebase. struct None {}; -using PythonValue = - std::variant; +using PythonValue = std::variant; /// Map an opaque OMEvaluatorValue into a python value. 
PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result); -OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result); +OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result, + MlirContext ctx); +static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr); +static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value, + MlirContext ctx); /// Provides a List class by simply wrapping the OMObject CAPI. struct List { @@ -79,13 +89,15 @@ struct Map { Map(OMEvaluatorValue value) : value(value) {} /// Return the keys. - std::vector getKeys() { + std::vector getKeys() { auto attr = omEvaluatorMapGetKeys(value); intptr_t numFieldNames = mlirArrayAttrGetNumElements(attr); - std::vector pyFieldNames; - for (intptr_t i = 0; i < numFieldNames; ++i) - pyFieldNames.emplace_back(mlirArrayAttrGetElement(attr, i)); + std::vector pyFieldNames; + for (intptr_t i = 0; i < numFieldNames; ++i) { + auto name = mlirStringAttrGetValue(mlirArrayAttrGetElement(attr, i)); + pyFieldNames.emplace_back(py::str(name.data, name.length)); + } return pyFieldNames; } @@ -224,7 +236,8 @@ struct Evaluator { std::vector actualParams) { std::vector values; for (auto &param : actualParams) - values.push_back(pythonValueToOMEvaluatorValue(param)); + values.push_back(pythonValueToOMEvaluatorValue( + param, mlirModuleGetContext(getModule()))); // Instantiate the Object via the CAPI. OMEvaluatorValue result = omEvaluatorInstantiate( @@ -288,7 +301,8 @@ class PyMapAttrIterator { throw py::stop_iteration(); MlirIdentifier key = omMapAttrGetElementKey(attr, nextIndex); - MlirAttribute value = omMapAttrGetElementValue(attr, nextIndex); + PythonValue value = + omPrimitiveToPythonValue(omMapAttrGetElementValue(attr, nextIndex)); nextIndex++; auto keyName = mlirIdentifierStr(key); @@ -349,6 +363,88 @@ Map::dunderGetItem(std::variant key) { return dunderGetItemAttr(std::get(key)); } +// Convert a generic MLIR Attribute to a PythonValue.
This is basically a C++ +// fast path of the parts of attribute_to_var that we use in the OM dialect. +static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) { + if (omAttrIsAIntegerAttr(attr)) { + auto strRef = omIntegerAttrToString(attr); + return py::int_(py::str(strRef.data, strRef.length)); + } + + if (mlirAttributeIsAFloat(attr)) { + return py::float_(mlirFloatAttrGetValueDouble(attr)); + } + + if (mlirAttributeIsAString(attr)) { + auto strRef = mlirStringAttrGetValue(attr); + return py::str(strRef.data, strRef.length); + } + + if (mlirAttributeIsABool(attr)) { + return py::bool_(mlirBoolAttrGetValue(attr)); + } + + if (omAttrIsAReferenceAttr(attr)) { + auto innerRef = omReferenceAttrGetInnerRef(attr); + auto moduleStrRef = + mlirStringAttrGetValue(hwInnerRefAttrGetModule(innerRef)); + auto nameStrRef = mlirStringAttrGetValue(hwInnerRefAttrGetName(innerRef)); + auto moduleStr = py::str(moduleStrRef.data, moduleStrRef.length); + auto nameStr = py::str(nameStrRef.data, nameStrRef.length); + return py::make_tuple(moduleStr, nameStr); + } + + if (omAttrIsAListAttr(attr)) { + py::list results; + for (intptr_t i = 0, e = omListAttrGetNumElements(attr); i < e; ++i) + results.append(omPrimitiveToPythonValue(omListAttrGetElement(attr, i))); + return results; + } + + if (omAttrIsAMapAttr(attr)) { + py::dict results; + for (intptr_t i = 0, e = omMapAttrGetNumElements(attr); i < e; ++i) { + auto keyStrRef = mlirIdentifierStr(omMapAttrGetElementKey(attr, i)); + auto key = py::str(keyStrRef.data, keyStrRef.length); + auto value = omPrimitiveToPythonValue(omMapAttrGetElementValue(attr, i)); + results[key] = value; + } + return results; + } + + mlirAttributeDump(attr); + throw py::type_error("Unexpected OM primitive attribute"); +} + +// Convert a primitive PythonValue to a generic MLIR Attribute. This is +// basically a C++ fast path of the parts of var_to_attribute that we use in the +// OM dialect. 
+static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value, + MlirContext ctx) { + if (auto *intValue = std::get_if(&value)) { + auto intType = mlirIntegerTypeGet(ctx, 64); + auto intAttr = mlirIntegerAttrGet(intType, intValue->cast()); + return omIntegerAttrGet(intAttr); + } + + if (auto *attr = std::get_if(&value)) { + auto floatType = mlirF64TypeGet(ctx); + return mlirFloatAttrDoubleGet(ctx, floatType, attr->cast()); + } + + if (auto *attr = std::get_if(&value)) { + auto str = attr->cast(); + auto strRef = mlirStringRefCreate(str.data(), str.length()); + return mlirStringAttrGet(ctx, strRef); + } + + if (auto *attr = std::get_if(&value)) { + return mlirBoolAttrGet(ctx, attr->cast()); + } + + throw py::type_error("Unexpected OM primitive value"); +} + PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result) { // If the result is null, something failed. Diagnostic handling is // implemented in pure Python, so nothing to do here besides throwing an @@ -386,13 +482,11 @@ PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result) { // If the field was a primitive, return the Attribute. 
assert(omEvaluatorValueIsAPrimitive(result)); - return omEvaluatorValueGetPrimitive(result); + return omPrimitiveToPythonValue(omEvaluatorValueGetPrimitive(result)); } -OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result) { - if (auto *attr = std::get_if(&result)) - return omEvaluatorValueFromPrimitive(*attr); - +OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result, + MlirContext ctx) { if (auto *list = std::get_if(&result)) return list->getValue(); @@ -408,7 +502,12 @@ OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result) { if (auto *path = std::get_if(&result)) return path->getValue(); - return std::get(result).getValue(); + if (auto *object = std::get_if(&result)) + return object->getValue(); + + auto primitive = std::get(result); + return omEvaluatorValueFromPrimitive( + omPythonValueToPrimitive(primitive, ctx)); } } // namespace diff --git a/lib/Bindings/Python/dialects/om.py b/lib/Bindings/Python/dialects/om.py index 2abeaffe0e28..bdf51eed84e1 100644 --- a/lib/Bindings/Python/dialects/om.py +++ b/lib/Bindings/Python/dialects/om.py @@ -21,9 +21,9 @@ # Wrap a base mlir object with high-level object. def wrap_mlir_object(value): - # For primitives, return a Python value. - if isinstance(value, Attribute): - return attribute_to_var(value) + # For primitives, return the Python value directly. + if isinstance(value, (int, float, str, bool, tuple, list, dict)): + return value if isinstance(value, BaseList): return List(value) @@ -52,12 +52,7 @@ def om_var_to_attribute(obj, none_on_fail: bool = False) -> ir.Attrbute: def unwrap_python_object(value): - # Check if the value is a Primitive. - try: - return om_var_to_attribute(value) - except: - pass - + # Check if the value is any of our container or custom types. if isinstance(value, List): return BaseList(value) @@ -73,9 +68,11 @@ def unwrap_python_object(value): if isinstance(value, Path): return value - # Otherwise, it must be an Object. Cast to the mlir object. 
- assert isinstance(value, Object) - return BaseObject(value) + if isinstance(value, Object): + return BaseObject(value) + + # Otherwise, it must be a primitive, so just return it. + return value class List(BaseList): From 2dbab26c890a43103b50fafff9555a4c19b31bbc Mon Sep 17 00:00:00 2001 From: Bea Healy <57840981+TaoBi22@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:06:29 +0100 Subject: [PATCH 5/6] [docs] Remove confusing reset in Seq docs SV example (#7419) --- docs/Dialects/Seq/RationaleSeq.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Dialects/Seq/RationaleSeq.md b/docs/Dialects/Seq/RationaleSeq.md index 463f2b4045d9..d2e8b24e01b4 100644 --- a/docs/Dialects/Seq/RationaleSeq.md +++ b/docs/Dialects/Seq/RationaleSeq.md @@ -146,7 +146,7 @@ Examples of registers: A register without a reset lowers directly to an always block: ``` -always @(posedge clk or [posedge reset]) begin +always @(posedge clk) begin a <= [%input] end ``` From 20cb546d18a28ebf7a2e8dc87102a53a8da33bc3 Mon Sep 17 00:00:00 2001 From: Will Dietz Date: Thu, 1 Aug 2024 07:02:53 -0500 Subject: [PATCH 6/6] [FIRRTL][Dedup] Rework hashing for perf and bug fixes. (#7420) Primary change is to only generate and populate mappings for sources of values, and not each value themselves. Values are identified using these base numberings plus an appropriate offset. The main benefit of this is to greatly reduce the number of entries in the `indices` map. When handling operations with many block arguments (module-like's with many ports) or with many results (instances of those module-like's) this greatly reduces the pressure on the `indices` map. For these designs, dedup now runs dramatically faster and uses significantly less memory. Also separates location of the value impl, such that if a Value's impl is storage inline into an Operation or Block such that there is aliasing, the two are given different numbers (and especially the numbering isn't changed). 
On a synthetic design containing a module with 2^20 ports and 256 instances of that module, this is the difference between completing in 20s and OOM'ing on my machine after running for 30 minutes. Functional changes: Fixes #7415. Fixes #7416. Also fixes deduping if block arg types are different (but unused). This is done by hashing the block count, and each block's numbering as well as the types of its arguments before that block's operations. Additionally fixes use of numberings (indices) before they were populated, where attribute processing for inner symbol ports hashed using the block argument's numbering before it was populated. --- lib/Dialect/FIRRTL/Transforms/Dedup.cpp | 79 ++++++++++++++-------- test/Dialect/FIRRTL/dedup-errors.mlir | 89 +++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 28 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp index 560aa0e3d7a0..1fb2ad2fe9aa 100644 --- a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp +++ b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp @@ -78,8 +78,17 @@ struct ModuleInfo { mlir::ArrayAttr referredModuleNames; }; -struct SymbolTarget { +/// Unique identifier for a value. All value sources are numbered by appearance, +/// and values are identified using this numbering (`index`) and an `offset`. +/// For BlockArgument's, this is the argument number. +/// For OpResult's, this is the result number. +struct ValueId { uint64_t index; + uint64_t offset; +}; + +struct SymbolTarget { + ValueId index; uint64_t fieldID; }; @@ -161,14 +170,19 @@ struct StructuralHasher { void record(void *address) { auto size = indices.size(); + assert(!indices.contains(address)); indices[address] = size; } - void update(BlockArgument arg) { record(arg.getAsOpaquePointer()); } + /// Get the unique id for the specified value.
+ ValueId getId(Value val) { + if (auto arg = dyn_cast(val)) + return {indices.at(arg.getOwner()), arg.getArgNumber()}; + auto result = cast(val); + return {indices.at(result.getOwner()), result.getResultNumber()}; + } void update(OpResult result) { - record(result.getAsOpaquePointer()); - // Like instance ops, don't use object ops' result types since they might be // replaced by dedup. Record the class names and lazily combine their hashes // using the same mechanism as instances and modules. @@ -180,23 +194,23 @@ struct StructuralHasher { update(result.getType()); } - void update(OpOperand &operand) { - // We hash the value's index as it apears in the block. - auto it = indices.find(operand.get().getAsOpaquePointer()); - assert(it != indices.end() && "op should have been previously hashed"); - update(it->second); + void update(ValueId index) { + update(index.index); + update(index.offset); } + void update(OpOperand &operand) { update(getId(operand.get())); } + void update(Operation *op, hw::InnerSymAttr attr) { for (auto props : attr) innerSymTargets[props.getName()] = - SymbolTarget{indices[op], props.getFieldID()}; + SymbolTarget{{indices.at(op), 0}, props.getFieldID()}; } void update(Value value, hw::InnerSymAttr attr) { for (auto props : attr) innerSymTargets[props.getName()] = - SymbolTarget{indices[value.getAsOpaquePointer()], props.getFieldID()}; + SymbolTarget{getId(value), props.getFieldID()}; } void update(const SymbolTarget &target) { @@ -281,15 +295,6 @@ struct StructuralHasher { } } - void update(Block &block) { - // Hash the block arguments. - for (auto arg : block.getArguments()) - update(arg); - // Hash the operations in the block. - for (auto &op : block) - update(&op); - } - void update(mlir::OperationName name) { // Operation names are interned. 
update(name.getAsOpaquePointer()); @@ -299,26 +304,44 @@ struct StructuralHasher { void update(Operation *op) { record(op); update(op->getName()); - update(op, op->getAttrDictionary()); + // Hash the operands. for (auto &operand : op->getOpOperands()) update(operand); + + // Number the block pointers, for use numbering their arguments. + for (auto &region : op->getRegions()) + for (auto &block : region.getBlocks()) + record(&block); + + // This happens after the numbering above, as it uses blockarg numbering + // for inner symbols. + update(op, op->getAttrDictionary()); + // Hash the regions. We need to make sure an empty region doesn't hash the // same as no region, so we include the number of regions. update(op->getNumRegions()); - for (auto &region : op->getRegions()) - for (auto &block : region.getBlocks()) - update(block); - // Record any op results. + for (auto &region : op->getRegions()) { + update(region.getBlocks().size()); + for (auto &block : region.getBlocks()) { + update(indices.at(&block)); + for (auto argType : block.getArgumentTypes()) + update(argType); + for (auto &op : block) + update(&op); + } + } + + // Record any op results (types). for (auto result : op->getResults()) update(result); } - // Every operation and value is assigned a unique id based on their order of - // appearance + // Every operation and block is assigned a unique id based on their order of + // appearance. All values are uniquely identified using these. DenseMap indices; - // Every value is assigned a unique id based on their order of appearance. + // Track inner symbol name -> target's unique identification. DenseMap innerSymTargets; // This keeps track of module names in the order of the appearance.
diff --git a/test/Dialect/FIRRTL/dedup-errors.mlir b/test/Dialect/FIRRTL/dedup-errors.mlir index f0a8639041d3..19355c0e4a38 100644 --- a/test/Dialect/FIRRTL/dedup-errors.mlir +++ b/test/Dialect/FIRRTL/dedup-errors.mlir @@ -290,6 +290,95 @@ firrtl.circuit "MustDedup" attributes {annotations = [{ // ----- +// Check same number of blocks but instructions across are same. +// https://github.com/llvm/circt/issues/7415 +// expected-error@below {{module "Test1" not deduplicated with "Test0"}} +firrtl.circuit "SameInstDiffBlock" attributes {annotations = [{ + class = "firrtl.transforms.MustDeduplicateAnnotation", + modules = ["~SameInstDiffBlock|Test0", "~SameInstDiffBlock|Test1"] + }]} { + firrtl.module private @Test0(in %a : !firrtl.uint<1>) { + "test"()({ + ^bb0: + // expected-note@below {{first block has more operations}} + "return"() : () -> () + }, { + ^bb0: + }) : () -> () + } + firrtl.module private @Test1(in %a : !firrtl.uint<1>) { + // expected-note@below {{second block here}} + "test"() ({ + ^bb0: + }, { + ^bb0: + "return"() : () -> () + }): () -> () + } + firrtl.module @SameInstDiffBlock() { + firrtl.instance test0 @Test0(in a : !firrtl.uint<1>) + firrtl.instance test1 @Test1(in a : !firrtl.uint<1>) + } +} + +// ----- + +// Check differences in block arguments. 
+// expected-error@below {{module "Test1" not deduplicated with "Test0"}} +firrtl.circuit "BlockArgTypes" attributes {annotations = [{ + class = "firrtl.transforms.MustDeduplicateAnnotation", + modules = ["~BlockArgTypes|Test0", "~BlockArgTypes|Test1"] + }]} { + firrtl.module private @Test0(in %a : !firrtl.uint<1>) { + // expected-note@below {{types don't match, first type is 'i32'}} + "test"()({ + ^bb0(%arg0 : i32): + "return"() : () -> () + }) : () -> () + } + firrtl.module private @Test1(in %a : !firrtl.uint<1>) { + // expected-note@below {{second type is 'i64'}} + "test"() ({ + ^bb0(%arg0 : i64): + "return"() : () -> () + }): () -> () + } + firrtl.module @BlockArgTypes() { + firrtl.instance test0 @Test0(in a : !firrtl.uint<1>) + firrtl.instance test1 @Test1(in a : !firrtl.uint<1>) + } +} + +// ----- + +// Check empty block not same as no block. +// https://github.com/llvm/circt/issues/7416 +// expected-error@below {{module "B" not deduplicated with "A"}} +firrtl.circuit "NoBlockEmptyBlock" attributes {annotations = [{ + class = "firrtl.transforms.MustDeduplicateAnnotation", + modules = ["~NoBlockEmptyBlock|A", "~NoBlockEmptyBlock|B"] + }]} { + firrtl.module private @A(in %x: !firrtl.uint<1>) { + // expected-note @below {{operation regions have different number of blocks}} + firrtl.when %x : !firrtl.uint<1> { + } + } + firrtl.module private @B(in %x: !firrtl.uint<1>) { + // expected-note @below {{second operation here}} + firrtl.when %x : !firrtl.uint<1> { + } else { + } + } + firrtl.module @NoBlockEmptyBlock(in %x: !firrtl.uint<1>) { + %a_x = firrtl.instance a @A(in x: !firrtl.uint<1>) + firrtl.matchingconnect %a_x, %x : !firrtl.uint<1> + %b_x = firrtl.instance b @B(in x: !firrtl.uint<1>) + firrtl.matchingconnect %b_x, %x : !firrtl.uint<1> + } +} + +// ----- + // expected-error@below {{module "Test1" not deduplicated with "Test0"}} firrtl.circuit "MustDedup" attributes {annotations = [{ class = "firrtl.transforms.MustDeduplicateAnnotation",