From ff40aabb49203a12998575e7549cd2e803875445 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 13 Aug 2024 06:17:00 -0700 Subject: [PATCH] [PyCDE] Support for adding module constants to manifests (#7510) --- frontends/PyCDE/integration_test/esi_test.py | 11 ++++++---- .../test_software/esi_test.py | 16 ++++++++++---- frontends/PyCDE/src/pycde/common.py | 17 +++++++++++++++ frontends/PyCDE/src/pycde/module.py | 21 ++++++++++++++++--- frontends/PyCDE/src/pycde/support.py | 13 ++++++++++++ frontends/PyCDE/test/test_esi.py | 5 ++++- 6 files changed, 71 insertions(+), 12 deletions(-) diff --git a/frontends/PyCDE/integration_test/esi_test.py b/frontends/PyCDE/integration_test/esi_test.py index 2c78c15a41c1..ddc1cd4ced6a 100644 --- a/frontends/PyCDE/integration_test/esi_test.py +++ b/frontends/PyCDE/integration_test/esi_test.py @@ -7,6 +7,7 @@ import pycde from pycde import (AppID, Clock, Module, Reset, modparams, generator) from pycde.bsp import cosim +from pycde.common import Constant from pycde.constructs import Reg, Wire from pycde.esi import FuncService, MMIO, MMIOReadWriteCmdType from pycde.types import (Bits, Channel, UInt) @@ -15,21 +16,23 @@ import sys -class LoopbackInOutAdd7(Module): +class LoopbackInOutAdd(Module): """Loopback the request from the host, adding 7 to the first 15 bits.""" clk = Clock() rst = Reset() + add_amt = Constant(UInt(16), 11) + @generator def construct(ports): loopback = Wire(Channel(UInt(16))) - args = FuncService.get_call_chans(AppID("loopback_add7"), + args = FuncService.get_call_chans(AppID("add"), arg_type=UInt(24), result=loopback) ready = Wire(Bits(1)) data, valid = args.unwrap(ready) - plus7 = data + 7 + plus7 = data + LoopbackInOutAdd.add_amt.value data_chan, data_ready = loopback.type.wrap(plus7.as_uint(16), valid) data_chan_buffered = data_chan.buffer(ports.clk, ports.rst, 5) ready.assign(data_ready) @@ -96,7 +99,7 @@ class Top(Module): @generator def construct(ports): - LoopbackInOutAdd7(clk=ports.clk, rst=ports.rst) + LoopbackInOutAdd(clk=ports.clk, rst=ports.rst, appid=AppID("loopback")) for i in range(4, 18, 5): MMIOClient(i)() MMIOReadWriteClient(clk=ports.clk, rst=ports.rst) diff --git a/frontends/PyCDE/integration_test/test_software/esi_test.py b/frontends/PyCDE/integration_test/test_software/esi_test.py index f264cf923e06..4eeeabbe8cb1 100644 --- a/frontends/PyCDE/integration_test/test_software/esi_test.py +++ b/frontends/PyCDE/integration_test/test_software/esi_test.py @@ -67,13 +67,21 @@ def read_offset_check(i: int, add_amt: int): print(m.type_table) d = acc.build_accelerator() - -recv = d.ports[esi.AppID("loopback_add7")].read_port("result") +loopback = d.children[esi.AppID("loopback")] +recv = loopback.ports[esi.AppID("add")].read_port("result") recv.connect() -send = d.ports[esi.AppID("loopback_add7")].write_port("arg") +send = loopback.ports[esi.AppID("add")].write_port("arg") send.connect() +loopback_info = None +for mod_info in m.module_infos: + if mod_info.name == "LoopbackInOutAdd": + loopback_info = mod_info + break +assert loopback_info is not None +add_amt = mod_info.constants["add_amt"].value + ################################################################################ # Loopback add 7 tests ################################################################################ @@ -85,6 +93,6 @@ def read_offset_check(i: int, add_amt: int): print(f"data: {data}") print(f"resp: {resp}") -assert resp == data + 7 +assert resp == data + add_amt print("PASS") diff --git a/frontends/PyCDE/src/pycde/common.py b/frontends/PyCDE/src/pycde/common.py index b2e993164e8a..2fce095417e0 100644 --- a/frontends/PyCDE/src/pycde/common.py +++ b/frontends/PyCDE/src/pycde/common.py @@ -135,6 +135,23 @@ def __repr__(self) -> str: return f"{self.name}[{self.index}]" +class Constant: + """A constant value associated with a module. Gets added to the ESI system + manifest so it is accessible at runtime. + + Example usage: + + ``` + def ExampleModule(Module): + const_name = Constant(UInt(16), 42) + ``` + """ + + def __init__(self, type: Type, value: object): + self.type = type + self.value = value + + class _PyProxy: """Parent class for a Python object which has a corresponding IR op (i.e. a proxy class).""" diff --git a/frontends/PyCDE/src/pycde/module.py b/frontends/PyCDE/src/pycde/module.py index 46904adbbeb9..5f76cc9f8da1 100644 --- a/frontends/PyCDE/src/pycde/module.py +++ b/frontends/PyCDE/src/pycde/module.py @@ -6,9 +6,9 @@ from dataclasses import dataclass from typing import Any, List, Optional, Set, Tuple, Dict -from .common import (AppID, Clock, Input, ModuleDecl, Output, PortError, - _PyProxy, Reset) -from .support import (get_user_loc, _obj_to_attribute, create_type_string, +from .common import (AppID, Clock, Constant, Input, ModuleDecl, Output, + PortError, _PyProxy, Reset) +from .support import (get_user_loc, _obj_to_attribute, obj_to_typed_attribute, create_const_zero) from .signals import ClockSignal, Signal, _FromCirctValue from .types import ClockType, Type, _FromCirctType @@ -237,6 +237,7 @@ def scan_cls(self): clock_ports = set() reset_ports = set() generators = {} + constants = {} num_inputs = 0 num_outputs = 0 for attr_name, attr in self.cls_dct.items(): @@ -273,11 +274,14 @@ def scan_cls(self): ports.append(attr) elif isinstance(attr, Generator): generators[attr_name] = attr + elif isinstance(attr, Constant): + constants[attr_name] = attr self.ports = ports self.clocks = clock_ports self.resets = reset_ports self.generators = generators + self.constants = constants def create_port_proxy(self) -> PortProxyBase: """Create a proxy class for generators to use in order to access module @@ -475,6 +479,17 @@ def create_op(self, sys, symbol): else: self.add_metadata(sys, symbol, None) + # If there are associated constants, add them to the manifest. + if len(self.constants) > 0: + constants_dict: Dict[str, ir.Attribute] = {} + for name, constant in self.constants.items(): + constant_attr = obj_to_typed_attribute(constant.value, constant.type) + constants_dict[name] = constant_attr + with ir.InsertionPoint(sys.mod.body): + from .dialects.esi import esi + esi.SymbolConstantsOp(symbolRef=ir.FlatSymbolRefAttr.get(symbol), + constants=ir.DictAttr.get(constants_dict)) + if len(self.generators) > 0: if hasattr(self, "parameters") and self.parameters is not None: self.attributes["pycde.parameters"] = self.parameters diff --git a/frontends/PyCDE/src/pycde/support.py b/frontends/PyCDE/src/pycde/support.py index a276f444d680..b29db4634a2a 100644 --- a/frontends/PyCDE/src/pycde/support.py +++ b/frontends/PyCDE/src/pycde/support.py @@ -1,6 +1,12 @@ +from __future__ import annotations + from .circt import support from .circt import ir +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .types import Type + import os @@ -43,6 +49,13 @@ def _obj_to_attribute(obj) -> ir.Attribute: "This is required for parameters.") +def obj_to_typed_attribute(obj: object, type: Type) -> ir.Attribute: + from .types import BitVectorType + if isinstance(type, BitVectorType): + return ir.IntegerAttr.get(type._type, obj) + raise ValueError(f"Type '{type}' conversion to attribute not supported yet.") + + __dir__ = os.path.dirname(__file__) _local_files = set([os.path.join(__dir__, x) for x in os.listdir(__dir__)]) _hidden_filenames = set(["functools.py"]) diff --git a/frontends/PyCDE/test/test_esi.py b/frontends/PyCDE/test/test_esi.py index cf4766897304..84f7878cc267 100644 --- a/frontends/PyCDE/test/test_esi.py +++ b/frontends/PyCDE/test/test_esi.py @@ -4,7 +4,7 @@ from pycde import (Clock, Input, InputChannel, Output, OutputChannel, Module, Reset, generator, types) from pycde import esi -from pycde.common import AppID, RecvBundle, SendBundle +from pycde.common import AppID, Constant, RecvBundle, SendBundle from pycde.constructs import Wire from pycde.esi import MMIO from pycde.module import Metadata @@ -36,6 +36,7 @@ class HostComms: # CHECK: esi.manifest.sym @LoopbackInOutTop name "LoopbackInOut" {{.*}}version "0.1" {bar = "baz", foo = 1 : i64} +# CHECK: esi.manifest.constants @LoopbackInOutTop {c1 = 54 : ui8} # CHECK-LABEL: hw.module @LoopbackInOutTop(in %clk : !seq.clock, in %rst : i1) @@ -59,6 +60,8 @@ class LoopbackInOutTop(Module): }, ) + c1 = Constant(UInt(8), 54) + @generator def construct(self): # Use Cosim to implement the 'HostComms' service.