From c4e3cb796a5c3e31972b5758730e08b2b16458d0 Mon Sep 17 00:00:00 2001 From: John Demme Date: Thu, 9 Jan 2025 22:03:30 +0000 Subject: [PATCH] [ESI] Add hostmem write support to cosim --- frontends/PyCDE/integration_test/esitester.py | 47 +++++++++++ frontends/PyCDE/src/pycde/bsp/common.py | 80 +++++++++++++++++-- frontends/PyCDE/src/pycde/bsp/cosim.py | 6 ++ frontends/PyCDE/src/pycde/esi.py | 40 ++++++---- frontends/PyCDE/src/pycde/signals.py | 5 +- frontends/PyCDE/src/pycde/types.py | 7 +- frontends/PyCDE/test/test_esi.py | 4 +- .../ESI/runtime/cpp/lib/backends/Cosim.cpp | 57 +++++++++++++ .../ESI/runtime/cpp/tools/esitester.cpp | 19 ++++- 9 files changed, 234 insertions(+), 31 deletions(-) diff --git a/frontends/PyCDE/integration_test/esitester.py b/frontends/PyCDE/integration_test/esitester.py index 463a7ab9096b..b42ad9d21d10 100644 --- a/frontends/PyCDE/integration_test/esitester.py +++ b/frontends/PyCDE/integration_test/esitester.py @@ -114,6 +114,52 @@ def construct(ports): mem_data_ce.assign(hostmem_read_resp_valid) +class WriteMem(Module): + """Writes a cycle count to the host memory address stored at MMIO offset 0 + upon each MMIO transaction.""" + clk = Clock() + rst = Reset() + + @generator + def construct(ports): + cmd_chan_wire = Wire(Channel(esi.MMIOReadWriteCmdType)) + resp_ready_wire = Wire(Bits(1)) + cmd, cmd_valid = cmd_chan_wire.unwrap(resp_ready_wire) + mmio_xact = cmd_valid & resp_ready_wire + + write_loc_ce = mmio_xact & cmd.write & (cmd.offset == UInt(32)(0)) + write_loc = Reg(UInt(64), + clk=ports.clk, + rst=ports.rst, + rst_value=0, + ce=write_loc_ce) + write_loc.assign(cmd.data.as_uint()) + + response_data = write_loc.as_bits() + response_chan, response_ready = Channel(Bits(64)).wrap( + response_data, cmd_valid) + resp_ready_wire.assign(response_ready) + + mmio_rw = esi.MMIO.read_write(appid=AppID("WriteMem")) + mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd'] + cmd_chan_wire.assign(mmio_rw_cmd_chan) + + tag = Counter(8)(clk=ports.clk, 
rst=ports.rst, increment=mmio_xact) + + cycle_counter = Counter(64)(clk=ports.clk, + rst=ports.rst, + increment=Bits(1)(1)) + + hostmem_write_req, _ = esi.HostMem.wrap_write_req( + write_loc, + cycle_counter.out.as_bits(), + tag.out, + valid=mmio_xact.reg(ports.clk, ports.rst)) + + hostmem_write_resp = esi.HostMem.write(appid=AppID("WriteMem_hostwrite"), + req=hostmem_write_req) + + class EsiTesterTop(Module): clk = Clock() rst = Reset() @@ -122,6 +168,7 @@ class EsiTesterTop(Module): def construct(ports): PrintfExample(clk=ports.clk, rst=ports.rst) ReadMem(clk=ports.clk, rst=ports.rst) + WriteMem(clk=ports.clk, rst=ports.rst) if __name__ == "__main__": diff --git a/frontends/PyCDE/src/pycde/bsp/common.py b/frontends/PyCDE/src/pycde/bsp/common.py index 78d441ad2da9..b61aa8360ded 100644 --- a/frontends/PyCDE/src/pycde/bsp/common.py +++ b/frontends/PyCDE/src/pycde/bsp/common.py @@ -5,13 +5,13 @@ from __future__ import annotations from ..common import Clock, Input, Output, Reset -from ..constructs import AssignableSignal, ControlReg, NamedWire, Wire +from ..constructs import AssignableSignal, NamedWire, Wire from .. 
import esi from ..module import Module, generator, modparams from ..signals import BitsSignal, BundleSignal, ChannelSignal from ..support import clog2 from ..types import (Array, Bits, Bundle, BundledChannel, Channel, - ChannelDirection, StructType, Type, UInt) + ChannelDirection, StructType, UInt) from typing import Dict, List, Tuple import typing @@ -266,14 +266,14 @@ class ChannelHostMemImpl(esi.ServiceImplementation): clk = Clock() rst = Reset() - UpstreamReq = StructType([ + UpstreamReadReq = StructType([ ("address", UInt(64)), ("length", UInt(32)), ("tag", UInt(8)), ]) read = Output( Bundle([ - BundledChannel("req", ChannelDirection.TO, UpstreamReq), + BundledChannel("req", ChannelDirection.TO, UpstreamReadReq), BundledChannel( "resp", ChannelDirection.FROM, StructType([ @@ -281,11 +281,77 @@ class ChannelHostMemImpl(esi.ServiceImplementation): ("data", Bits(read_width)), ])), ])) + UpstreamWriteReq = StructType([ + ("address", UInt(64)), + ("tag", UInt(8)), + ("data", Bits(write_width)), + ]) + write = Output( + Bundle([ + BundledChannel("req", ChannelDirection.TO, UpstreamWriteReq), + BundledChannel("ackTag", ChannelDirection.FROM, UInt(8)), + ])) @generator def generate(ports, bundles: esi._ServiceGeneratorBundles): read_reqs = [req for req in bundles.to_client_reqs if req.port == 'read'] ports.read = ChannelHostMemImpl.build_tagged_read_mux(ports, read_reqs) + write_reqs = [ + req for req in bundles.to_client_reqs if req.port == 'write' + ] + ports.write = ChannelHostMemImpl.build_tagged_write_mux(ports, write_reqs) + + @staticmethod + def build_tagged_write_mux( + ports, reqs: List[esi._OutputBundleSetter]) -> BundleSignal: + """Build the write side of the HostMem service.""" + + # If there are no write clients, just return a no-op write bundle. + if len(reqs) == 0: + req, _ = Channel(ChannelHostMemImpl.UpstreamWriteReq).wrap( + { + "address": 0, + "tag": 0, + "data": 0 + }, 0) + write_bundle, _ = ChannelHostMemImpl.write.type.pack(req=req) + return 
write_bundle + + # TODO: mux together multiple write clients. + assert len(reqs) == 1, "Only one write client supported for now." + + # Build the write request channels and ack wires. + write_channels: List[ChannelSignal] = [] + write_acks = [] + for req in reqs: + # Get the request channel and its data type. + reqch = [c.channel for c in req.type.channels if c.name == 'req'][0] + data_type = reqch.inner_type.data + assert data_type == Bits( + write_width + ), f"Gearboxing not yet supported. Client {req.client_name}" + + # Write acks to be filled in later. + write_ack = Wire(Channel(UInt(8))) + write_acks.append(write_ack) + + # Pack up the bundle and assign the request channel. + write_req_bundle_type = esi.HostMem.write_req_bundle_type(data_type) + bundle_sig, froms = write_req_bundle_type.pack(ackTag=write_ack) + tagged_client_req = froms["req"] + req.assign(bundle_sig) + write_channels.append(tagged_client_req) + + # TODO: re-write the tags and store the client and client tag. + + # Build a channel mux for the write requests. + tagged_write_channel = esi.ChannelMux(write_channels) + upstream_write_bundle, froms = ChannelHostMemImpl.write.type.pack( + req=tagged_write_channel) + ack_tag = froms["ackTag"] + # TODO: decode the ack tag and assign it to the correct client. + write_acks[0].assign(ack_tag) + return upstream_write_bundle @staticmethod def build_tagged_read_mux( @@ -293,7 +359,7 @@ def build_tagged_read_mux( """Build the read side of the HostMem service.""" if len(reqs) == 0: - req, req_ready = Channel(ChannelHostMemImpl.UpstreamReq).wrap( + req, req_ready = Channel(ChannelHostMemImpl.UpstreamReadReq).wrap( { "tag": 0, "length": 0, @@ -305,7 +371,7 @@ def build_tagged_read_mux( # TODO: mux together multiple read clients. assert len(reqs) == 1, "Only one read client supported for now." 
- req = Wire(Channel(ChannelHostMemImpl.UpstreamReq)) + req = Wire(Channel(ChannelHostMemImpl.UpstreamReadReq)) read_bundle, froms = ChannelHostMemImpl.read.type.pack(req=req) resp_chan_ready = Wire(Bits(1)) resp_data, resp_valid = froms["resp"].unwrap(resp_chan_ready) @@ -335,7 +401,7 @@ def build_tagged_read_mux( # Assign the multiplexed read request to the upstream request. req.assign( - client_req.transform(lambda r: ChannelHostMemImpl.UpstreamReq({ + client_req.transform(lambda r: ChannelHostMemImpl.UpstreamReadReq({ "address": r.address, "length": 1, "tag": r.tag diff --git a/frontends/PyCDE/src/pycde/bsp/cosim.py b/frontends/PyCDE/src/pycde/bsp/cosim.py index 2ac771c54de9..54e5586c3ba5 100644 --- a/frontends/PyCDE/src/pycde/bsp/cosim.py +++ b/frontends/PyCDE/src/pycde/bsp/cosim.py @@ -72,6 +72,12 @@ def build(ports): resp_wire.type) resp_wire.assign(data) + ack_wire = Wire(Channel(UInt(8))) + write_req = hostmem.write.unpack(ackTag=ack_wire)['req'] + ack_tag = esi.CallService.call(esi.AppID("__cosim_hostmem_write"), + write_req, UInt(8)) + ack_wire.assign(ack_tag) + class ESI_Cosim_Top(Module): clk = Clock() rst = Input(Bits(1)) diff --git a/frontends/PyCDE/src/pycde/esi.py b/frontends/PyCDE/src/pycde/esi.py index dd15c2b4fa45..a5ee4f2dd2f7 100644 --- a/frontends/PyCDE/src/pycde/esi.py +++ b/frontends/PyCDE/src/pycde/esi.py @@ -519,24 +519,34 @@ class _HostMem(ServiceDecl): ("tag", UInt(8)), ]) - WriteReqType = StructType([ - ("address", UInt(64)), - ("tag", UInt(8)), - ("data", Any()), - ]) - def __init__(self): super().__init__(self.__class__) + def write_req_bundle_type(self, data_type: Type) -> Bundle: + """Build a write request bundle type for the given data type.""" + write_req_type = StructType([ + ("address", UInt(64)), + ("tag", UInt(8)), + ("data", data_type), + ]) + return Bundle([ + BundledChannel("req", ChannelDirection.FROM, write_req_type), + BundledChannel("ackTag", ChannelDirection.TO, UInt(8)) + ]) + + def write_req_channel_type(self, 
data_type: Type) -> StructType: + """Return a write request struct type for 'data_type'.""" + return StructType([ + ("address", UInt(64)), + ("tag", UInt(8)), + ("data", data_type), + ]) + def wrap_write_req(self, address: UIntSignal, data: Type, tag: UIntSignal, valid: BitsSignal) -> Tuple[ChannelSignal, BitsSignal]: """Create the proper channel type for a write request and use it to wrap the given request arguments. Returns the Channel signal and a ready bit.""" - inner_type = StructType([ - ("address", UInt(64)), - ("tag", UInt(8)), - ("data", data.type), - ]) + inner_type = self.write_req_channel_type(data.type) return Channel(inner_type).wrap( inner_type({ "address": address, @@ -548,10 +558,10 @@ def write(self, appid: AppID, req: ChannelSignal) -> ChannelSignal: """Create a write request to the host memory out of a request channel.""" self._materialize_service_decl() - write_bundle_type = Bundle([ - BundledChannel("req", ChannelDirection.FROM, _HostMem.WriteReqType), - BundledChannel("ackTag", ChannelDirection.TO, UInt(8)) - ]) + # Extract the data type from the request channel and call the helper to get + # the write bundle type for the req channel. 
+ req_data_type = req.type.inner_type.data + write_bundle_type = self.write_req_bundle_type(req_data_type) bundle = cast( BundleSignal, diff --git a/frontends/PyCDE/src/pycde/signals.py b/frontends/PyCDE/src/pycde/signals.py index df521b672a59..c6de3f9175f9 100644 --- a/frontends/PyCDE/src/pycde/signals.py +++ b/frontends/PyCDE/src/pycde/signals.py @@ -809,8 +809,9 @@ def unpack(self, **kwargs: ChannelSignal) -> Dict[str, ChannelSignal]: raise ValueError( f"Missing channel values for {', '.join(from_channels.keys())}") - unpack_op = esi.UnpackBundleOp([bc.channel._type for bc in to_channels], - self.value, operands) + with get_user_loc(): + unpack_op = esi.UnpackBundleOp([bc.channel._type for bc in to_channels], + self.value, operands) to_channels_results = unpack_op.toChannels ret = { diff --git a/frontends/PyCDE/src/pycde/types.py b/frontends/PyCDE/src/pycde/types.py index 56bf89a070fd..40a8d7a25101 100644 --- a/frontends/PyCDE/src/pycde/types.py +++ b/frontends/PyCDE/src/pycde/types.py @@ -858,9 +858,10 @@ def pack( if len(to_channels) > 0: raise ValueError(f"Missing channels: {', '.join(to_channels.keys())}") - pack_op = esi.PackBundleOp(self._type, - [bc.channel._type for bc in from_channels], - operands) + with get_user_loc(): + pack_op = esi.PackBundleOp(self._type, + [bc.channel._type for bc in from_channels], + operands) return BundleSignal(pack_op.bundle, self), Bundle.PackSignalResults( [_FromCirctValue(c) for c in pack_op.fromChannels], self) diff --git a/frontends/PyCDE/test/test_esi.py b/frontends/PyCDE/test/test_esi.py index 8e783240b432..2644670b4691 100644 --- a/frontends/PyCDE/test/test_esi.py +++ b/frontends/PyCDE/test/test_esi.py @@ -308,8 +308,8 @@ def build(ports): # CHECK-NEXT: [[R5:%.+]] = hwarith.constant 0 : ui256 # CHECK-NEXT: [[R6:%.+]] = hw.struct_create ([[R0]], [[R4]], [[R5]]) : !hw.struct # CHECK-NEXT: %chanOutput_0, %ready_1 = esi.wrap.vr [[R6]], %false : !hw.struct -# CHECK-NEXT: [[R7:%.+]] = esi.service.req 
<@_HostMem::@write>(#esi.appid<"host_mem_write_req">) : !esi.bundle<[!esi.channel> from "req", !esi.channel to "ackTag"]> -# CHECK-NEXT: %ackTag = esi.bundle.unpack %chanOutput_0 from [[R7]] : !esi.bundle<[!esi.channel> from "req", !esi.channel to "ackTag"]> +# CHECK-NEXT: [[R7:%.+]] = esi.service.req <@_HostMem::@write>(#esi.appid<"host_mem_write_req">) : !esi.bundle<[!esi.channel> from "req", !esi.channel to "ackTag"]> +# CHECK-NEXT: %ackTag = esi.bundle.unpack %chanOutput_0 from [[R7]] : !esi.bundle<[!esi.channel> from "req", !esi.channel to "ackTag"]> # CHECK: esi.service.std.hostmem @_HostMem @unittestmodule(esi_sys=True) class HostMemReq(Module): diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp index afe42dd3a95e..71a24bfe7907 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp @@ -426,6 +426,14 @@ struct HostMemReadResp { uint64_t data; uint8_t tag; }; + +struct HostMemWriteReq { + uint64_t data; + uint8_t tag; + uint64_t address; +}; + +using HostMemWriteResp = uint8_t; #pragma pack(pop) class CosimHostMem : public HostMem { @@ -438,6 +446,9 @@ class CosimHostMem : public HostMem { // We have to locate the channels ourselves since this service might be used // to retrieve the manifest. + // TODO: The types here are WRONG. They need to be wrapped in Channels! Fix + // this in a subsequent PR. + // Setup the read side callback. ChannelDesc readArg, readResp; if (!rpcClient->getChannelDesc("__cosim_hostmem_read.arg", readArg) || @@ -465,6 +476,32 @@ class CosimHostMem : public HostMem { *readRespPort, *readReqPort)); read->connect([this](const MessageData &req) { return serviceRead(req); }, true); + + // Setup the write side callback. 
+ ChannelDesc writeArg, writeResp; + if (!rpcClient->getChannelDesc("__cosim_hostmem_write.arg", writeArg) || + !rpcClient->getChannelDesc("__cosim_hostmem_write.result", writeResp)) + throw std::runtime_error("Could not find HostMem channels"); + + const esi::Type *writeRespType = + getType(ctxt, new UIntType(writeResp.type(), 8)); + const esi::Type *writeReqType = + getType(ctxt, new StructType(writeArg.type(), + {{"address", new UIntType("ui64", 64)}, + {"tag", new UIntType("ui8", 8)}, + {"data", new BitsType("i64", 64)}})); + + // Get ports, create the function, then connect to it. + writeRespPort = std::make_unique( + rpcClient->stub.get(), writeResp, writeRespType, + "__cosim_hostmem_write.result"); + writeReqPort = std::make_unique( + rpcClient->stub.get(), writeArg, writeReqType, + "__cosim_hostmem_write.arg"); + write.reset(CallService::Callback::get(acc, AppID("__cosim_hostmem_write"), + *writeRespPort, *writeReqPort)); + write->connect([this](const MessageData &req) { return serviceWrite(req); }, + true); } // Service the read request as a callback. Simply reads the data from the @@ -491,6 +528,23 @@ class CosimHostMem : public HostMem { return MessageData::from(resp); } + // Service a write request as a callback. Simply write the data to the + // location specified. TODO: check that the memory has been mapped. 
+ MessageData serviceWrite(const MessageData &reqBytes) { + const HostMemWriteReq *req = reqBytes.as(); + acc.getLogger().debug( + [&](std::string &subsystem, std::string &msg, + std::unique_ptr> &details) { + subsystem = "HostMem"; + msg = "Write request: addr=0x" + toHex(req->address) + " data=0x" + + toHex(req->data) + " tag=" + std::to_string(req->tag); + }); + uint64_t *dataPtr = reinterpret_cast(req->address); + *dataPtr = req->data; + HostMemWriteResp resp = req->tag; + return MessageData::from(resp); + } + struct CosimHostMemRegion : public HostMemRegion { CosimHostMemRegion(std::size_t size) { ptr = malloc(size); @@ -530,6 +584,9 @@ class CosimHostMem : public HostMem { std::unique_ptr readRespPort; std::unique_ptr readReqPort; std::unique_ptr read; + std::unique_ptr writeRespPort; + std::unique_ptr writeReqPort; + std::unique_ptr write; }; } // namespace diff --git a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp index 7c65245f4113..fedf99fe86d3 100644 --- a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp +++ b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp @@ -103,11 +103,12 @@ void dmaTest(AcceleratorConnection *conn, Accelerator *acc) { // Enable the host memory service. auto hostmem = conn->getService(); hostmem->start(); + auto scratchRegion = hostmem->allocate(8, /*memOpts=*/{}); + uint64_t *dataPtr = static_cast(scratchRegion->getPtr()); // Initiate a test read. auto *readMem = acc->getPorts().at(AppID("ReadMem")).getAs(); - uint64_t *dataPtr = new uint64_t; *dataPtr = 0x12345678; readMem->write(8, (uint64_t)dataPtr); @@ -121,5 +122,19 @@ void dmaTest(AcceleratorConnection *conn, Accelerator *acc) { std::this_thread::sleep_for(std::chrono::microseconds(100)); } if (val != *dataPtr) - throw std::runtime_error("DMA test failed"); + throw std::runtime_error("DMA read test failed"); + + // Initiate a test write. 
+ auto *writeMem = + acc->getPorts().at(AppID("WriteMem")).getAs(); + *dataPtr = 0; + writeMem->write(0, (uint64_t)dataPtr); + // Wait for the accelerator to write. Timeout and fail after 10ms. + for (int i = 0; i < 100; ++i) { + if (*dataPtr != 0) + break; + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + if (*dataPtr == 0) + throw std::runtime_error("DMA write test failed"); }