From ac0fef902e48432be762314638498b91ebb77751 Mon Sep 17 00:00:00 2001 From: John Demme Date: Fri, 9 Aug 2024 16:56:26 +0000 Subject: [PATCH 1/2] [ESI][Manifest] Embed constants in manifest Users want to know what constants were used in a design. Plumb them through using the manifest. No runtime support, no pycde support. Cleanups to the manifest as well. Update the runtime to support the schema changes to the manifest. --- include/circt/Dialect/ESI/ESIInterfaces.td | 8 + include/circt/Dialect/ESI/ESIManifest.td | 17 ++ lib/Dialect/ESI/ESIOps.cpp | 6 + lib/Dialect/ESI/Passes/ESIBuildManifest.cpp | 192 +++++++++++------- .../ESI/runtime/cpp/include/esi/Common.h | 12 +- lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp | 90 +++++--- test/Dialect/ESI/manifest.mlir | 167 +++++++-------- 7 files changed, 299 insertions(+), 193 deletions(-) diff --git a/include/circt/Dialect/ESI/ESIInterfaces.td b/include/circt/Dialect/ESI/ESIInterfaces.td index f37ccea5bccb..764a718ce6db 100644 --- a/include/circt/Dialect/ESI/ESIInterfaces.td +++ b/include/circt/Dialect/ESI/ESIInterfaces.td @@ -68,6 +68,14 @@ def IsManifestData : OpInterface<"IsManifestData"> { "Get the class name for this op.", "StringRef", "getManifestClass", (ins) >, + InterfaceMethod< + "Get the symbol to which this manifest data is referring, if any.", + "FlatSymbolRefAttr", "getSymbolRefAttr", (ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return FlatSymbolRefAttr(); + }] + >, InterfaceMethod< "Populate results with the manifest data.", "void", "getDetails", (ins "SmallVectorImpl&":$results), diff --git a/include/circt/Dialect/ESI/ESIManifest.td b/include/circt/Dialect/ESI/ESIManifest.td index 1fd3ef9089dc..94110e7d6919 100644 --- a/include/circt/Dialect/ESI/ESIManifest.td +++ b/include/circt/Dialect/ESI/ESIManifest.td @@ -169,6 +169,23 @@ def AppIDHierNodeOp : ESI_Op<"manifest.hier_node", [ }]; } +def SymbolConstantsOp : ESI_Op<"manifest.consts", [ + DeclareOpInterfaceMethods]> { + let summary = "Constant values associated with a symbol"; + + let arguments = (ins FlatSymbolRefAttr:$symbolRef, + DictionaryAttr:$constants); + let assemblyFormat = [{ + $symbolRef $constants attr-dict + }]; + + let extraClassDeclaration = [{ + // Get information which needs to appear in the manifest for the host to + // connect to this service. + void getDetails(SmallVectorImpl &results); + }]; +} + def SymbolMetadataOp : ESI_Op<"manifest.sym", [ DeclareOpInterfaceMethods]> { let summary = "Metadata about a symbol"; diff --git a/lib/Dialect/ESI/ESIOps.cpp b/lib/Dialect/ESI/ESIOps.cpp index 893f375f60fc..acf794dc679b 100644 --- a/lib/Dialect/ESI/ESIOps.cpp +++ b/lib/Dialect/ESI/ESIOps.cpp @@ -696,6 +696,12 @@ void ServiceRequestRecordOp::getDetails( StringRef SymbolMetadataOp::getManifestClass() { return "sym_info"; } +StringRef SymbolConstantsOp::getManifestClass() { return "sym_consts"; } +void SymbolConstantsOp::getDetails(SmallVectorImpl &results) { + for (auto &attr : getConstantsAttr()) + results.push_back(attr); +} + #define GET_OP_CLASSES #include "circt/Dialect/ESI/ESI.cpp.inc" diff --git a/lib/Dialect/ESI/Passes/ESIBuildManifest.cpp b/lib/Dialect/ESI/Passes/ESIBuildManifest.cpp index a50afdf48261..a4022cf75d07 100644 --- a/lib/Dialect/ESI/Passes/ESIBuildManifest.cpp +++ b/lib/Dialect/ESI/Passes/ESIBuildManifest.cpp @@ -39,10 +39,13 @@ struct ESIBuildManifestPass void gatherFilters(Operation *); void gatherFilters(Attribute); - /// Get a JSON representation of a type. - llvm::json::Value json(Operation *errorOp, Type); - /// Get a JSON representation of a type. - llvm::json::Value json(Operation *errorOp, Attribute); + /// Get a JSON representation of a type. 'useTable' indicates whether to use + /// the type table to determine if the type should be emitted as a reference + /// if it already exists in the type table. + llvm::json::Value json(Operation *errorOp, Type, bool useTable = true); + /// Get a JSON representation of a type. 'elideType' indicates to not print + /// the type if it would have been printed. + llvm::json::Value json(Operation *errorOp, Attribute, bool elideType = false); // Output a node in the appid hierarchy. void emitNode(llvm::json::OStream &, AppIDHierNodeOp nodeOp); @@ -88,6 +91,12 @@ void ESIBuildManifestPass::runOnOperation() { // scraping unnecessary types. appidRoot->walk([&](Operation *op) { gatherFilters(op); }); + // Also gather types from the manifest data. + for (Region ®ion : mod->getRegions()) + for (Block &block : region) + for (auto manifestInfo : block.getOps()) + gatherFilters(manifestInfo); + // JSONify the manifest. std::string jsonManifest = json(); @@ -165,15 +174,34 @@ std::string ESIBuildManifestPass::json() { j.attribute("api_version", esiApiVersion); j.attributeArray("symbols", [&]() { - for (auto symInfo : mod.getBody()->getOps()) { - if (!symbols.contains(symInfo.getSymbolRefAttr())) + // First, gather all of the manifest data for each symbol. + DenseMap> symbolInfoLookup; + for (auto symInfo : mod.getBody()->getOps()) { + FlatSymbolRefAttr sym = symInfo.getSymbolRefAttr(); + if (!sym || !symbols.contains(sym)) continue; + symbolInfoLookup[sym].push_back(symInfo); + } + + // Now, emit a JSON object for each symbol. + for (const auto &symNameInfo : symbolInfoLookup) { j.object([&] { - SmallVector attrs; - symInfo.getDetails(attrs); - for (auto attr : attrs) - j.attribute(attr.getName().getValue(), - json(symInfo, attr.getValue())); + j.attribute("symbol", json(symNameInfo.second.front(), + symNameInfo.first, /*elideType=*/true)); + for (auto symInfo : symNameInfo.second) { + j.attributeBegin(symInfo.getManifestClass()); + j.object([&] { + SmallVector attrs; + symInfo.getDetails(attrs); + for (auto attr : attrs) { + if (attr.getName().getValue() == "symbolRef") + continue; + j.attribute(attr.getName().getValue(), + json(symInfo, attr.getValue())); + } + }); + j.attributeEnd(); + } }); } }); @@ -215,7 +243,7 @@ std::string ESIBuildManifestPass::json() { j.attributeArray("types", [&]() { for (auto type : types) { - j.value(json(mod, type)); + j.value(json(mod, type, /*useTable=*/false)); } }); j.objectEnd(); @@ -245,6 +273,7 @@ void ESIBuildManifestPass::gatherFilters(Attribute attr) { // This is far from complete. Build out as necessary. TypeSwitch(attr) .Case([&](TypeAttr a) { addType(a.getValue()); }) + .Case([&](IntegerAttr a) { addType(a.getType()); }) .Case([&](FlatSymbolRefAttr a) { symbols.insert(a); }) .Case([&](hw::InnerRefAttr a) { symbols.insert(a.getModuleRef()); }) .Case([&](ArrayAttr a) { @@ -259,18 +288,28 @@ void ESIBuildManifestPass::gatherFilters(Attribute attr) { /// Get a JSON representation of a type. // NOLINTNEXTLINE(misc-no-recursion) -llvm::json::Value ESIBuildManifestPass::json(Operation *errorOp, Type type) { +llvm::json::Value ESIBuildManifestPass::json(Operation *errorOp, Type type, + bool useTable) { using llvm::json::Array; using llvm::json::Object; using llvm::json::Value; + if (useTable && typeLookup.contains(type)) { + // If the type is in the type table, it'll be present in the types + // section. Just give the circt type name, which is guaranteed to + // uniquely identify the type. + std::string typeName; + llvm::raw_string_ostream(typeName) << type; + return typeName; + } + std::string m; Object o = // This is not complete. Build out as necessary. TypeSwitch(type) .Case([&](ChannelType t) { m = "channel"; - return Object({{"inner", json(errorOp, t.getInner())}}); + return Object({{"inner", json(errorOp, t.getInner(), useTable)}}); }) .Case([&](ChannelBundleType t) { m = "bundle"; @@ -279,7 +318,7 @@ llvm::json::Value ESIBuildManifestPass::json(Operation *errorOp, Type type) { channels.push_back(Object( {{"name", field.name.getValue()}, {"direction", stringifyChannelDirection(field.direction)}, - {"type", json(errorOp, field.type)}})); + {"type", json(errorOp, field.type, useTable)}})); return Object({{"channels", Value(std::move(channels))}}); }) .Case([&](AnyType t) { @@ -288,25 +327,29 @@ llvm::json::Value ESIBuildManifestPass::json(Operation *errorOp, Type type) { }) .Case([&](ListType t) { m = "list"; - return Object({{"element", json(errorOp, t.getElementType())}}); + return Object( + {{"element", json(errorOp, t.getElementType(), useTable)}}); }) .Case([&](hw::ArrayType t) { m = "array"; - return Object({{"size", t.getNumElements()}, - {"element", json(errorOp, t.getElementType())}}); + return Object( + {{"size", t.getNumElements()}, + {"element", json(errorOp, t.getElementType(), useTable)}}); }) .Case([&](hw::StructType t) { m = "struct"; Array fields; for (auto field : t.getElements()) - fields.push_back(Object({{"name", field.name.getValue()}, - {"type", json(errorOp, field.type)}})); + fields.push_back( + Object({{"name", field.name.getValue()}, + {"type", json(errorOp, field.type, useTable)}})); return Object({{"fields", Value(std::move(fields))}}); }) .Case([&](hw::TypeAliasType t) { m = "alias"; - return Object({{"name", t.getTypeDecl(symCache).getPreferredName()}, - {"inner", json(errorOp, t.getInnerType())}}); + return Object( + {{"name", t.getTypeDecl(symCache).getPreferredName()}, + {"inner", json(errorOp, t.getInnerType(), useTable)}}); }) .Case([&](IntegerType t) { m = "int"; @@ -322,9 +365,9 @@ llvm::json::Value ESIBuildManifestPass::json(Operation *errorOp, Type type) { }); // Common metadata. - std::string circtName; - llvm::raw_string_ostream(circtName) << type; - o["circt_name"] = circtName; + std::string typeID; + llvm::raw_string_ostream(typeID) << type; + o["id"] = typeID; int64_t width = hw::getBitWidth(type); if (auto chanType = dyn_cast(type)) @@ -340,59 +383,58 @@ llvm::json::Value ESIBuildManifestPass::json(Operation *errorOp, Type type) { // Serialize an attribute to a JSON value. // NOLINTNEXTLINE(misc-no-recursion) -llvm::json::Value ESIBuildManifestPass::json(Operation *errorOp, - Attribute attr) { +llvm::json::Value ESIBuildManifestPass::json(Operation *errorOp, Attribute attr, + bool elideType) { + // This is far from complete. Build out as necessary. + using llvm::json::Object; using llvm::json::Value; - return TypeSwitch(attr) - .Case([&](StringAttr a) { return a.getValue(); }) - .Case([&](IntegerAttr a) { return a.getValue().getLimitedValue(); }) - .Case([&](TypeAttr a) { - Type t = a.getValue(); - - llvm::json::Object typeMD; - if (typeLookup.contains(t)) { - // If the type is in the type table, it'll be present in the types - // section. Just give the circt type name, which is guaranteed to - // uniquely identify the type. - std::string buff; - llvm::raw_string_ostream(buff) << a; - typeMD["circt_name"] = buff; - return typeMD; - } + Value value = + TypeSwitch(attr) + .Case([&](StringAttr a) { return a.getValue(); }) + .Case([&](IntegerAttr a) { return a.getValue().getLimitedValue(); }) + .Case([&](TypeAttr a) { return json(errorOp, a.getValue()); }) + .Case([&](ArrayAttr a) { + return llvm::json::Array(llvm::map_range( + a, [&](Attribute a) { return json(errorOp, a); })); + }) + .Case([&](DictionaryAttr a) { + llvm::json::Object dict; + for (const auto &entry : a.getValue()) + dict[entry.getName().getValue()] = + json(errorOp, entry.getValue()); + return dict; + }) + .Case([&](hw::InnerRefAttr ref) { + llvm::json::Object dict; + dict["outer_sym"] = ref.getModule().getValue(); + dict["inner"] = ref.getName().getValue(); + return dict; + }) + .Case([&](AppIDAttr appid) { + llvm::json::Object dict; + dict["name"] = appid.getName().getValue(); + auto idx = appid.getIndex(); + if (idx) + dict["index"] = *idx; + return dict; + }) + .Default([&](Attribute a) { + std::string value; + llvm::raw_string_ostream(value) << a; + return value; + }); - typeMD["type"] = json(errorOp, t); - return typeMD; - }) - .Case([&](ArrayAttr a) { - return llvm::json::Array( - llvm::map_range(a, [&](Attribute a) { return json(errorOp, a); })); - }) - .Case([&](DictionaryAttr a) { - llvm::json::Object dict; - for (const auto &entry : a.getValue()) - dict[entry.getName().getValue()] = json(errorOp, entry.getValue()); - return dict; - }) - .Case([&](hw::InnerRefAttr ref) { - llvm::json::Object dict; - dict["outer_sym"] = ref.getModule().getValue(); - dict["inner"] = ref.getName().getValue(); - return dict; - }) - .Case([&](AppIDAttr appid) { - llvm::json::Object dict; - dict["name"] = appid.getName().getValue(); - auto idx = appid.getIndex(); - if (idx) - dict["index"] = *idx; - return dict; - }) - .Default([&](Attribute a) { - std::string buff; - llvm::raw_string_ostream(buff) << a; - return buff; - }); + // Don't print the type if it's None or we're eliding it. + auto typedAttr = llvm::dyn_cast(attr); + if (elideType || !typedAttr || isa(typedAttr.getType())) + return value; + + // Otherwise, return an object with the value and type. + Object dict; + dict["value"] = value; + dict["type"] = json(errorOp, typedAttr.getType()); + return dict; } std::unique_ptr> diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h index 997fb53158e8..ee502c72f774 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h @@ -54,12 +54,12 @@ class AppIDPath : public std::vector { bool operator<(const AppIDPath &a, const AppIDPath &b); struct ModuleInfo { - const std::optional name; - const std::optional summary; - const std::optional version; - const std::optional repo; - const std::optional commitHash; - const std::map extra; + std::optional name; + std::optional summary; + std::optional version; + std::optional repo; + std::optional commitHash; + std::map extra; }; /// A description of a service port. Used pretty exclusively in setting up the diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp index 62d16feec419..0640538eb91f 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp @@ -19,7 +19,7 @@ #include #include -using namespace esi; +using namespace ::esi; // While building the design, keep around a std::map of active services indexed // by the service name. When a new service is encountered during descent, add it @@ -95,6 +95,10 @@ class Manifest::Impl { const Type *parseType(const nlohmann::json &typeJson); + const std::map &getSymbolInfo() const { + return symbolInfoCache; + } + private: Context &ctxt; std::vector _typeTable; @@ -106,7 +110,7 @@ class Manifest::Impl { // The parsed json. nlohmann::json manifestJson; // Cache the module info for each symbol. - std::map symbolInfoCache; + std::map symbolInfoCache; }; //===----------------------------------------------------------------------===// @@ -169,14 +173,12 @@ static std::any getAny(const nlohmann::json &value) { throw std::runtime_error("Unknown type in manifest: " + value.dump(2)); } -static ModuleInfo parseModuleInfo(const nlohmann::json &mod) { - - std::map extras; +static void parseModuleInfo(ModuleInfo &info, const nlohmann::json &mod) { for (auto &extra : mod.items()) if (extra.key() != "name" && extra.key() != "summary" && extra.key() != "version" && extra.key() != "repo" && - extra.key() != "commitHash" && extra.key() != "symbolRef") - extras[extra.key()] = getAny(extra.value()); + extra.key() != "commitHash") + info.extra[extra.key()] = getAny(extra.value()); auto value = [&](const std::string &key) -> std::optional { auto f = mod.find(key); @@ -184,8 +186,11 @@ static ModuleInfo parseModuleInfo(const nlohmann::json &mod) { return std::nullopt; return f.value(); }; - return ModuleInfo{value("name"), value("summary"), value("version"), - value("repo"), value("commitHash"), extras}; + info.name = value("name"); + info.summary = value("summary"); + info.version = value("version"); + info.repo = value("repo"); + info.commitHash = value("commitHash"); } //===----------------------------------------------------------------------===// @@ -196,10 +201,23 @@ Manifest::Impl::Impl(Context &ctxt, const std::string &manifestStr) : ctxt(ctxt) { manifestJson = nlohmann::ordered_json::parse(manifestStr); - for (auto &mod : manifestJson.at("symbols")) - symbolInfoCache.insert( - make_pair(mod.at("symbolRef"), parseModuleInfo(mod))); - populateTypes(manifestJson.at("types")); + try { + // Populate the types table first since anything else might need it. + populateTypes(manifestJson.at("types")); + + // Populate the symbol info cache. + for (auto &mod : manifestJson.at("symbols")) { + ModuleInfo info; + if (mod.contains("sym_info")) + parseModuleInfo(info, mod); + symbolInfoCache.insert(make_pair(mod.at("symbol"), info)); + } + } catch (const std::exception &e) { + std::string msg = "malformed manifest: " + std::string(e.what()); + if (manifestJson.at("api_version") == 0) + msg += " (schema version 0 is not considered stable)"; + throw std::runtime_error(msg); + } } std::unique_ptr @@ -377,7 +395,7 @@ Manifest::Impl::getBundlePorts(AcceleratorConnection &acc, AppIDPath idPath, } services::Service *svc = svcIter->second; - std::string typeName = content.at("bundleType").at("circt_name"); + std::string typeName = content.at("bundleType"); auto type = getType(typeName); if (!type) throw std::runtime_error( @@ -425,12 +443,12 @@ BundleType *parseBundleType(const nlohmann::json &typeJson, Context &cache) { channels.emplace_back(chanJson.at("name"), dir, parseType(chanJson["type"], cache)); } - return new BundleType(typeJson.at("circt_name"), channels); + return new BundleType(typeJson.at("id"), channels); } ChannelType *parseChannelType(const nlohmann::json &typeJson, Context &cache) { assert(typeJson.at("mnemonic") == "channel"); - return new ChannelType(typeJson.at("circt_name"), + return new ChannelType(typeJson.at("id"), parseType(typeJson.at("inner"), cache)); } @@ -438,7 +456,7 @@ Type *parseInt(const nlohmann::json &typeJson, Context &cache) { assert(typeJson.at("mnemonic") == "int"); std::string sign = typeJson.at("signedness"); uint64_t width = typeJson.at("hw_bitwidth"); - Type::ID id = typeJson.at("circt_name"); + Type::ID id = typeJson.at("id"); if (sign == "signed") return new SIntType(id, width); @@ -459,13 +477,13 @@ StructType *parseStruct(const nlohmann::json &typeJson, Context &cache) { for (auto &fieldJson : typeJson["fields"]) fields.emplace_back(fieldJson.at("name"), parseType(fieldJson["type"], cache)); - return new StructType(typeJson.at("circt_name"), fields); + return new StructType(typeJson.at("id"), fields); } ArrayType *parseArray(const nlohmann::json &typeJson, Context &cache) { assert(typeJson.at("mnemonic") == "array"); uint64_t size = typeJson.at("size"); - return new ArrayType(typeJson.at("circt_name"), + return new ArrayType(typeJson.at("id"), parseType(typeJson.at("element"), cache), size); } @@ -473,10 +491,8 @@ using TypeParser = std::function; const std::map typeParsers = { {"bundle", parseBundleType}, {"channel", parseChannelType}, - {"std::any", - [](const nlohmann::json &typeJson, Context &cache) { - return new AnyType(typeJson.at("circt_name")); - }}, + {"std::any", [](const nlohmann::json &typeJson, + Context &cache) { return new AnyType(typeJson.at("id")); }}, {"int", parseInt}, {"struct", parseStruct}, {"array", parseArray}, @@ -485,19 +501,24 @@ const std::map typeParsers = { // Parse a type if it doesn't already exist in the cache. const Type *parseType(const nlohmann::json &typeJson, Context &cache) { - // We use the circt type string as a unique ID. - std::string circt_name = typeJson.at("circt_name"); - if (std::optional t = cache.getType(circt_name)) + std::string id; + if (typeJson.is_string()) + id = typeJson.get(); + else + id = typeJson.at("id"); + if (std::optional t = cache.getType(id)) return *t; + if (typeJson.is_string()) + throw std::runtime_error("malformed manifest: unknown type '" + id + "'"); - std::string mnemonic = typeJson.at("mnemonic"); Type *t; + std::string mnemonic = typeJson.at("mnemonic"); auto f = typeParsers.find(mnemonic); if (f != typeParsers.end()) t = f->second(typeJson, cache); else // Types we don't know about are opaque. - t = new Type(circt_name); + t = new Type(id); // Insert into the cache. cache.registerType(t); @@ -529,13 +550,20 @@ uint32_t Manifest::getApiVersion() const { std::vector Manifest::getModuleInfos() const { std::vector ret; - for (auto &mod : impl->at("symbols")) - ret.push_back(parseModuleInfo(mod)); + for (auto &[symbol, info] : impl->getSymbolInfo()) + ret.push_back(info); return ret; } Accelerator *Manifest::buildAccelerator(AcceleratorConnection &acc) const { - return acc.takeOwnership(impl->buildAccelerator(acc)); + try { + return acc.takeOwnership(impl->buildAccelerator(acc)); + } catch (const std::exception &e) { + std::string msg = "malformed manifest: " + std::string(e.what()); + if (getApiVersion() == 0) + msg += " (schema version 0 is not considered stable)"; + throw std::runtime_error(msg); + } } const std::vector &Manifest::getTypeTable() const { diff --git a/test/Dialect/ESI/manifest.mlir b/test/Dialect/ESI/manifest.mlir index b32af1474641..b86f94f1de2b 100644 --- a/test/Dialect/ESI/manifest.mlir +++ b/test/Dialect/ESI/manifest.mlir @@ -33,6 +33,7 @@ hw.module @Loopback (in %clk: !seq.clock) { } esi.manifest.sym @Loopback name "LoopbackIP" version "v0.0" summary "IP which simply echos bytes" {foo=1} +esi.manifest.consts @Loopback {depth=5:ui32} esi.service.std.func @funcs @@ -95,13 +96,25 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK: { // CHECK-LABEL: "api_version": 0, + // CHECK-LABEL: "symbols": [ // CHECK-NEXT: { -// CHECK-NEXT: "foo": 1, -// CHECK-NEXT: "name": "LoopbackIP", -// CHECK-NEXT: "summary": "IP which simply echos bytes", -// CHECK-NEXT: "symbolRef": "@Loopback", -// CHECK-NEXT: "version": "v0.0" +// CHECK-NEXT: "symbol": "@Loopback", +// CHECK-NEXT: "sym_info": { +// CHECK-NEXT: "foo": { +// CHECK-NEXT: "type": "i64", +// CHECK-NEXT: "value": 1 +// CHECK-NEXT: }, +// CHECK-NEXT: "name": "LoopbackIP", +// CHECK-NEXT: "summary": "IP which simply echos bytes", +// CHECK-NEXT: "version": "v0.0" +// CHECK-NEXT: }, +// CHECK-NEXT: "sym_consts": { +// CHECK-NEXT: "depth": { +// CHECK-NEXT: "type": "ui32", +// CHECK-NEXT: "value": 5 +// CHECK-NEXT: } +// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: ], @@ -231,9 +244,7 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "appID": { // CHECK-NEXT: "name": "func1" // CHECK-NEXT: }, -// CHECK-NEXT: "bundleType": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel to \"arg\", !esi.channel from \"result\"]>" -// CHECK-NEXT: }, +// CHECK-NEXT: "bundleType": "!esi.bundle<[!esi.channel to \"arg\", !esi.channel from \"result\"]>" // CHECK-NEXT: "servicePort": { // CHECK-NEXT: "inner": "call", // CHECK-NEXT: "outer_sym": "funcs" @@ -254,9 +265,7 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "appID": { // CHECK-NEXT: "name": "loopback_tohw" // CHECK-NEXT: }, -// CHECK-NEXT: "bundleType": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel to \"recv\"]>" -// CHECK-NEXT: }, +// CHECK-NEXT: "bundleType": "!esi.bundle<[!esi.channel to \"recv\"]>" // CHECK-NEXT: "servicePort": { // CHECK-NEXT: "inner": "Recv", // CHECK-NEXT: "outer_sym": "HostComms" @@ -267,9 +276,7 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "appID": { // CHECK-NEXT: "name": "loopback_fromhw" // CHECK-NEXT: }, -// CHECK-NEXT: "bundleType": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel from \"send\"]>" -// CHECK-NEXT: }, +// CHECK-NEXT: "bundleType": "!esi.bundle<[!esi.channel from \"send\"]>" // CHECK-NEXT: "servicePort": { // CHECK-NEXT: "inner": "Send", // CHECK-NEXT: "outer_sym": "HostComms" @@ -280,9 +287,7 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "appID": { // CHECK-NEXT: "name": "loopback_fromhw_i0" // CHECK-NEXT: }, -// CHECK-NEXT: "bundleType": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel from \"send\"]>" -// CHECK-NEXT: }, +// CHECK-NEXT: "bundleType": "!esi.bundle<[!esi.channel from \"send\"]>" // CHECK-NEXT: "servicePort": { // CHECK-NEXT: "inner": "SendI0", // CHECK-NEXT: "outer_sym": "HostComms" @@ -303,9 +308,7 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "appID": { // CHECK-NEXT: "name": "loopback_tohw" // CHECK-NEXT: }, -// CHECK-NEXT: "bundleType": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel to \"recv\"]>" -// CHECK-NEXT: }, +// CHECK-NEXT: "bundleType": "!esi.bundle<[!esi.channel to \"recv\"]>" // CHECK-NEXT: "servicePort": { // CHECK-NEXT: "inner": "Recv", // CHECK-NEXT: "outer_sym": "HostComms" @@ -316,9 +319,7 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "appID": { // CHECK-NEXT: "name": "loopback_fromhw" // CHECK-NEXT: }, -// CHECK-NEXT: "bundleType": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel from \"send\"]>" -// CHECK-NEXT: }, +// CHECK-NEXT: "bundleType": "!esi.bundle<[!esi.channel from \"send\"]>" // CHECK-NEXT: "servicePort": { // CHECK-NEXT: "inner": "Send", // CHECK-NEXT: "outer_sym": "HostComms" @@ -329,9 +330,7 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "appID": { // CHECK-NEXT: "name": "loopback_fromhw_i0" // CHECK-NEXT: }, -// CHECK-NEXT: "bundleType": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel from \"send\"]>" -// CHECK-NEXT: }, +// CHECK-NEXT: "bundleType": "!esi.bundle<[!esi.channel from \"send\"]>" // CHECK-NEXT: "servicePort": { // CHECK-NEXT: "inner": "SendI0", // CHECK-NEXT: "outer_sym": "HostComms" @@ -349,21 +348,15 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "ports": [ // CHECK-NEXT: { // CHECK-NEXT: "name": "Send", -// CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel from \"send\"]>" -// CHECK-NEXT: } +// CHECK-NEXT: "type": "!esi.bundle<[!esi.channel from \"send\"]>" // CHECK-NEXT: }, // CHECK-NEXT: { // CHECK-NEXT: "name": "Recv", -// CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel to \"recv\"]>" -// CHECK-NEXT: } +// CHECK-NEXT: "type": "!esi.bundle<[!esi.channel to \"recv\"]>" // CHECK-NEXT: }, // CHECK-NEXT: { // CHECK-NEXT: "name": "SendI0", -// CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel from \"send\"]>" -// CHECK-NEXT: } +// CHECK-NEXT: "type": "!esi.bundle<[!esi.channel from \"send\"]>" // CHECK-NEXT: } // CHECK-NEXT: ] // CHECK-NEXT: }, @@ -374,41 +367,39 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: { // CHECK-NEXT: "name": "call", // CHECK-NEXT: "type": { -// CHECK-NEXT: "type": { -// CHECK-NEXT: "channels": [ -// CHECK-NEXT: { -// CHECK-NEXT: "direction": "to", -// CHECK-NEXT: "name": "arg", -// CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.channel", +// CHECK-NEXT: "channels": [ +// CHECK-NEXT: { +// CHECK-NEXT: "direction": "to", +// CHECK-NEXT: "name": "arg", +// CHECK-NEXT: "type": { +// CHECK-NEXT: "dialect": "esi", +// CHECK-NEXT: "id": "!esi.channel", +// CHECK-NEXT: "inner": { // CHECK-NEXT: "dialect": "esi", -// CHECK-NEXT: "inner": { -// CHECK-NEXT: "circt_name": "!esi.any", -// CHECK-NEXT: "dialect": "esi", -// CHECK-NEXT: "mnemonic": "any" -// CHECK-NEXT: }, -// CHECK-NEXT: "mnemonic": "channel" -// CHECK-NEXT: } -// CHECK-NEXT: }, -// CHECK-NEXT: { -// CHECK-NEXT: "direction": "from", -// CHECK-NEXT: "name": "result", -// CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.channel", +// CHECK-NEXT: "id": "!esi.any", +// CHECK-NEXT: "mnemonic": "any" +// CHECK-NEXT: }, +// CHECK-NEXT: "mnemonic": "channel" +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "direction": "from", +// CHECK-NEXT: "name": "result", +// CHECK-NEXT: "type": { +// CHECK-NEXT: "dialect": "esi", +// CHECK-NEXT: "id": "!esi.channel", +// CHECK-NEXT: "inner": { // CHECK-NEXT: "dialect": "esi", -// CHECK-NEXT: "inner": { -// CHECK-NEXT: "circt_name": "!esi.any", -// CHECK-NEXT: "dialect": "esi", -// CHECK-NEXT: "mnemonic": "any" -// CHECK-NEXT: }, -// CHECK-NEXT: "mnemonic": "channel" -// CHECK-NEXT: } +// CHECK-NEXT: "id": "!esi.any", +// CHECK-NEXT: "mnemonic": "any" +// CHECK-NEXT: }, +// CHECK-NEXT: "mnemonic": "channel" // CHECK-NEXT: } -// CHECK-NEXT: ], -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel to \"arg\", !esi.channel from \"result\"]>", -// CHECK-NEXT: "dialect": "esi", -// CHECK-NEXT: "mnemonic": "bundle" -// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: ], +// CHECK-NEXT: "dialect": "esi", +// CHECK-NEXT: "id": "!esi.bundle<[!esi.channel to \"arg\", !esi.channel from \"result\"]>", +// CHECK-NEXT: "mnemonic": "bundle" // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: ] @@ -422,13 +413,13 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "direction": "to", // CHECK-NEXT: "name": "recv", // CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.channel", // CHECK-NEXT: "dialect": "esi", // CHECK-NEXT: "hw_bitwidth": 8, +// CHECK-NEXT: "id": "!esi.channel", // CHECK-NEXT: "inner": { -// CHECK-NEXT: "circt_name": "i8", // CHECK-NEXT: "dialect": "builtin", // CHECK-NEXT: "hw_bitwidth": 8, +// CHECK-NEXT: "id": "i8", // CHECK-NEXT: "mnemonic": "int", // CHECK-NEXT: "signedness": "signless" // CHECK-NEXT: }, @@ -436,8 +427,8 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: ], -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel to \"recv\"]>", // CHECK-NEXT: "dialect": "esi", +// CHECK-NEXT: "id": "!esi.bundle<[!esi.channel to \"recv\"]>", // CHECK-NEXT: "mnemonic": "bundle" // CHECK-NEXT: }, // CHECK-NEXT: { @@ -446,13 +437,13 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "direction": "from", // CHECK-NEXT: "name": "send", // CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.channel", // CHECK-NEXT: "dialect": "esi", // CHECK-NEXT: "hw_bitwidth": 8, +// CHECK-NEXT: "id": "!esi.channel", // CHECK-NEXT: "inner": { -// CHECK-NEXT: "circt_name": "i8", // CHECK-NEXT: "dialect": "builtin", // CHECK-NEXT: "hw_bitwidth": 8, +// CHECK-NEXT: "id": "i8", // CHECK-NEXT: "mnemonic": "int", // CHECK-NEXT: "signedness": "signless" // CHECK-NEXT: }, @@ -460,8 +451,8 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: ], -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel from \"send\"]>", // CHECK-NEXT: "dialect": "esi", +// CHECK-NEXT: "id": "!esi.bundle<[!esi.channel from \"send\"]>", // CHECK-NEXT: "mnemonic": "bundle" // CHECK-NEXT: }, // CHECK-NEXT: { @@ -470,13 +461,13 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "direction": "from", // CHECK-NEXT: "name": "send", // CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.channel", // CHECK-NEXT: "dialect": "esi", // CHECK-NEXT: "hw_bitwidth": 0, +// CHECK-NEXT: "id": "!esi.channel", // CHECK-NEXT: "inner": { -// CHECK-NEXT: "circt_name": "i0", // CHECK-NEXT: "dialect": "builtin", // CHECK-NEXT: "hw_bitwidth": 0, +// CHECK-NEXT: "id": "i0", // CHECK-NEXT: "mnemonic": "int", // CHECK-NEXT: "signedness": "signless" // CHECK-NEXT: }, @@ -484,8 +475,8 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: ], -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel from \"send\"]>", // CHECK-NEXT: "dialect": "esi", +// CHECK-NEXT: "id": "!esi.bundle<[!esi.channel from \"send\"]>", // CHECK-NEXT: "mnemonic": "bundle" // CHECK-NEXT: }, // CHECK-NEXT: { @@ -494,13 +485,13 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "direction": "to", // CHECK-NEXT: "name": "arg", // CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.channel", // CHECK-NEXT: "dialect": "esi", // CHECK-NEXT: "hw_bitwidth": 16, +// CHECK-NEXT: "id": "!esi.channel", // CHECK-NEXT: "inner": { -// CHECK-NEXT: "circt_name": "i16", // CHECK-NEXT: "dialect": "builtin", // CHECK-NEXT: "hw_bitwidth": 16, +// CHECK-NEXT: "id": "i16", // CHECK-NEXT: "mnemonic": "int", // CHECK-NEXT: "signedness": "signless" // CHECK-NEXT: }, @@ -511,13 +502,13 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: "direction": "from", // CHECK-NEXT: "name": "result", // CHECK-NEXT: "type": { -// CHECK-NEXT: "circt_name": "!esi.channel", // CHECK-NEXT: "dialect": "esi", // CHECK-NEXT: "hw_bitwidth": 16, +// CHECK-NEXT: "id": "!esi.channel", // CHECK-NEXT: "inner": { -// CHECK-NEXT: "circt_name": "i16", // CHECK-NEXT: "dialect": "builtin", // CHECK-NEXT: "hw_bitwidth": 16, +// CHECK-NEXT: "id": "i16", // CHECK-NEXT: "mnemonic": "int", // CHECK-NEXT: "signedness": "signless" // CHECK-NEXT: }, @@ -525,9 +516,23 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: ], -// CHECK-NEXT: "circt_name": "!esi.bundle<[!esi.channel to \"arg\", !esi.channel from \"result\"]>", // CHECK-NEXT: "dialect": "esi", +// CHECK-NEXT: "id": "!esi.bundle<[!esi.channel to \"arg\", !esi.channel from \"result\"]>", // CHECK-NEXT: "mnemonic": "bundle" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "dialect": "builtin", +// CHECK-NEXT: "hw_bitwidth": 64, +// CHECK-NEXT: "id": "i64", +// CHECK-NEXT: "mnemonic": "int", +// CHECK-NEXT: "signedness": "signless" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "dialect": "builtin", +// CHECK-NEXT: "hw_bitwidth": 32, +// CHECK-NEXT: "id": "ui32", +// CHECK-NEXT: "mnemonic": "int", +// CHECK-NEXT: "signedness": "unsigned" // CHECK-NEXT: } // CHECK-NEXT: ] // CHECK-NEXT: } From b5ce94b7230fecda0c0aa92ab83ed1ce681c3db8 Mon Sep 17 00:00:00 2001 From: John Demme Date: Fri, 9 Aug 2024 19:27:05 +0000 Subject: [PATCH 2/2] [ESI][Runtime] Parse and expose manifest constants Access per-module constants from the ModuleInfo class. --- .../Dialect/ESI/runtime/loopback.mlir | 12 +++ .../Dialect/ESI/runtime/loopback.mlir.py | 7 ++ .../ESI/runtime/cpp/include/esi/Common.h | 7 ++ lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp | 85 ++++++++++++++----- .../runtime/python/esiaccel/esiCppAccel.cpp | 45 ++++++++++ 5 files changed, 133 insertions(+), 23 deletions(-) diff --git a/integration_test/Dialect/ESI/runtime/loopback.mlir b/integration_test/Dialect/ESI/runtime/loopback.mlir index f5eec8e5e390..d1fceb5e3a8f 100644 --- a/integration_test/Dialect/ESI/runtime/loopback.mlir +++ b/integration_test/Dialect/ESI/runtime/loopback.mlir @@ -3,6 +3,7 @@ // RUN: circt-opt %s --esi-connect-services --esi-appid-hier=top=top --esi-build-manifest=top=top --esi-clean-metadata > %t4.mlir // RUN: circt-opt %t4.mlir --lower-esi-to-physical --lower-esi-bundles --lower-esi-ports --lower-esi-to-hw=platform=cosim --lower-seq-to-sv --lower-hwarith-to-hw --canonicalize --export-split-verilog -o %t3.mlir // RUN: cd .. +// RUN: esiquery trace w:%t6/esi_system_manifest.json info | FileCheck %s --check-prefix=QUERY-INFO // RUN: esiquery trace w:%t6/esi_system_manifest.json hier | FileCheck %s --check-prefix=QUERY-HIER // RUN: %python %s.py trace w:%t6/esi_system_manifest.json // RUN: esi-cosim.py --source %t6 --top top -- %python %s.py cosim env @@ -95,6 +96,7 @@ hw.module @CallableFunc1() { } esi.manifest.sym @Loopback name "LoopbackIP" version "v0.0" summary "IP which simply echos bytes" {foo=1} +esi.manifest.consts @Loopback {depth=5:ui32} hw.module @top(in %clk: !seq.clock, in %rst: i1) { esi.service.instance #esi.appid<"cosim"> svc @HostComms impl as "cosim" (%clk, %rst) : (!seq.clock, i1) -> () @@ -107,6 +109,16 @@ hw.module @top(in %clk: !seq.clock, in %rst: i1) { hw.instance "loopback_array" @LoopbackArray() -> () } +// QUERY-INFO: API version: 0 +// QUERY-INFO: ******************************** +// QUERY-INFO: * Module information +// QUERY-INFO: ******************************** +// QUERY-INFO: - LoopbackIP v0.0 : IP which simply echos bytes +// QUERY-INFO: Constants: +// QUERY-INFO: depth: 5 +// QUERY-INFO: Extra metadata: +// QUERY-INFO: foo: 1 + // QUERY-HIER: ******************************** // QUERY-HIER: * Design hierarchy // QUERY-HIER: ******************************** diff --git a/integration_test/Dialect/ESI/runtime/loopback.mlir.py b/integration_test/Dialect/ESI/runtime/loopback.mlir.py index a942a45b232b..091b02511599 100644 --- a/integration_test/Dialect/ESI/runtime/loopback.mlir.py +++ b/integration_test/Dialect/ESI/runtime/loopback.mlir.py @@ -21,6 +21,13 @@ for esiType in m.type_table: print(f"{esiType}") +for info in m.module_infos: + print(f"{info.name}") + for const_name, const in info.constants.items(): + print(f" {const_name}: {const.value} {const.type}") + if info.name == "LoopbackIP" and const_name == "depth": + assert const.value == 5 + d = acc.build_accelerator() loopback = d.children[esiaccel.AppID("loopback_inst", 0)] diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h index ee502c72f774..a47e2c0e76dd 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h @@ -25,6 +25,7 @@ #include namespace esi { +class Type; //===----------------------------------------------------------------------===// // Common accelerator description types. @@ -53,12 +54,18 @@ class AppIDPath : public std::vector { }; bool operator<(const AppIDPath &a, const AppIDPath &b); +struct Constant { + std::any value; + std::optional type; +}; + struct ModuleInfo { std::optional name; std::optional summary; std::optional version; std::optional repo; std::optional commitHash; + std::map constants; std::map extra; }; diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp index 0640538eb91f..56e7986179e4 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp @@ -107,6 +107,10 @@ class Manifest::Impl { return ctxt.getType(id); } + std::any getAny(const nlohmann::json &value) const; + void parseModuleMetadata(ModuleInfo &info, const nlohmann::json &mod) const; + void parseModuleConsts(ModuleInfo &info, const nlohmann::json &mod) const; + // The parsed json. nlohmann::json manifestJson; // Cache the module info for each symbol. @@ -138,42 +142,50 @@ static ServicePortDesc parseServicePort(const nlohmann::json &jsonPort) { /// Convert the json value to a 'std::any', which can be exposed outside of this /// file. -static std::any getAny(const nlohmann::json &value) { - auto getObject = [](const nlohmann::json &json) { +std::any Manifest::Impl::getAny(const nlohmann::json &value) const { + auto getObject = [this](const nlohmann::json &json) -> std::any { std::map ret; for (auto &e : json.items()) ret[e.key()] = getAny(e.value()); return ret; }; - auto getArray = [](const nlohmann::json &json) { + auto getArray = [this](const nlohmann::json &json) -> std::any { std::vector ret; for (auto &e : json) ret.push_back(getAny(e)); return ret; }; - if (value.is_string()) - return value.get(); - else if (value.is_number_integer()) - return value.get(); - else if (value.is_number_unsigned()) - return value.get(); - else if (value.is_number_float()) - return value.get(); - else if (value.is_boolean()) - return value.get(); - else if (value.is_null()) - return value.get(); - else if (value.is_object()) - return getObject(value); - else if (value.is_array()) - return getArray(value); - else - throw std::runtime_error("Unknown type in manifest: " + value.dump(2)); + auto getValue = [&](const nlohmann::json &innerValue) -> std::any { + if (innerValue.is_string()) + return innerValue.get(); + else if (innerValue.is_number_integer()) + return innerValue.get(); + else if (innerValue.is_number_unsigned()) + return innerValue.get(); + else if (innerValue.is_number_float()) + return innerValue.get(); + else if (innerValue.is_boolean()) + return innerValue.get(); + else if (innerValue.is_null()) + return innerValue.get(); + else if (innerValue.is_object()) + return getObject(innerValue); + else if (innerValue.is_array()) + return getArray(innerValue); + else + throw std::runtime_error("Unknown type in manifest: " + + innerValue.dump(2)); + }; + + if (!value.is_object() || !value.contains("type") || !value.contains("value")) + return getValue(value); + return Constant{getValue(value.at("value")), getType(value.at("type"))}; } -static void parseModuleInfo(ModuleInfo &info, const nlohmann::json &mod) { +void Manifest::Impl::parseModuleMetadata(ModuleInfo &info, + const nlohmann::json &mod) const { for (auto &extra : mod.items()) if (extra.key() != "name" && extra.key() != "summary" && extra.key() != "version" && extra.key() != "repo" && @@ -193,6 +205,19 @@ static void parseModuleInfo(ModuleInfo &info, const nlohmann::json &mod) { info.commitHash = value("commitHash"); } +void Manifest::Impl::parseModuleConsts(ModuleInfo &info, + const nlohmann::json &mod) const { + for (auto &item : mod.items()) { + std::any value = getAny(item.value()); + auto *c = std::any_cast(&value); + if (c) + info.constants[item.key()] = *c; + else + // If the value isn't a "proper" constant, present it as one with no type. + info.constants[item.key()] = Constant{value, std::nullopt}; + } +} + //===----------------------------------------------------------------------===// // Manifest::Impl class implementation. //===----------------------------------------------------------------------===// @@ -209,7 +234,9 @@ Manifest::Impl::Impl(Context &ctxt, const std::string &manifestStr) for (auto &mod : manifestJson.at("symbols")) { ModuleInfo info; if (mod.contains("sym_info")) - parseModuleInfo(info, mod); + parseModuleMetadata(info, mod.at("sym_info")); + if (mod.contains("sym_consts")) + parseModuleConsts(info, mod.at("sym_consts")); symbolInfoCache.insert(make_pair(mod.at("symbol"), info)); } } catch (const std::exception &e) { @@ -577,6 +604,9 @@ const std::vector &Manifest::getTypeTable() const { // Print a module info, including the extra metadata. std::ostream &operator<<(std::ostream &os, const ModuleInfo &m) { auto printAny = [&os](std::any a) { + if (auto *c = std::any_cast(&a)) + a = std::any_cast(a).value; + const std::type_info &t = a.type(); if (t == typeid(std::string)) os << std::any_cast(a); @@ -610,6 +640,15 @@ std::ostream &operator<<(std::ostream &os, const ModuleInfo &m) { os << ": " << *m.summary; os << "\n"; + if (!m.constants.empty()) { + os << " Constants:\n"; + for (auto &e : m.constants) { + os << " " << e.first << ": "; + printAny(e.second); + os << "\n"; + } + } + if (!m.extra.empty()) { os << " Extra metadata:\n"; for (auto &e : m.extra) { diff --git a/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp b/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp index ab8b94e81d67..051106ecc5e2 100644 --- a/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp +++ b/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp @@ -42,8 +42,45 @@ struct polymorphic_type_hook { return port; } }; + +namespace detail { +/// Pybind11 doesn't have a built-in type caster for std::any +/// (https://github.com/pybind/pybind11/issues/1590). We must provide one which +/// knows about all of the potential types which the any might be. +template <> +struct type_caster { +public: + PYBIND11_TYPE_CASTER(std::any, const_name("object")); + + static handle cast(std::any src, return_value_policy /* policy */, + handle /* parent */) { + const std::type_info &t = src.type(); + if (t == typeid(std::string)) + return py::str(std::any_cast(src)); + else if (t == typeid(int64_t)) + return py::int_(std::any_cast(src)); + else if (t == typeid(uint64_t)) + return py::int_(std::any_cast(src)); + else if (t == typeid(double)) + return py::float_(std::any_cast(src)); + else if (t == typeid(bool)) + return py::bool_(std::any_cast(src)); + else if (t == typeid(std::nullptr_t)) + return py::none(); + return py::none(); + } +}; +} // namespace detail } // namespace pybind11 +/// Resolve a Type to the Python wrapper object. +py::object getPyType(std::optional t) { + py::object typesModule = py::module_::import("esiaccel.types"); + if (!t) + return py::none(); + return typesModule.attr("_get_esi_type")(*t); +} + // NOLINTNEXTLINE(readability-identifier-naming) PYBIND11_MODULE(esiCppAccel, m) { py::class_(m, "Type") @@ -75,6 +112,12 @@ PYBIND11_MODULE(esiCppAccel, m) { py::return_value_policy::reference) .def_property_readonly("size", &ArrayType::getSize); + py::class_(m, "Constant") + .def_property_readonly("value", [](Constant &c) { return c.value; }) + .def_property_readonly( + "type", [](Constant &c) { return getPyType(*c.type); }, + py::return_value_policy::reference); + py::class_(m, "ModuleInfo") .def_property_readonly("name", [](ModuleInfo &info) { return info.name; }) .def_property_readonly("summary", @@ -84,6 +127,8 @@ PYBIND11_MODULE(esiCppAccel, m) { .def_property_readonly("repo", [](ModuleInfo &info) { return info.repo; }) .def_property_readonly("commit_hash", [](ModuleInfo &info) { return info.commitHash; }) + .def_property_readonly("constants", + [](ModuleInfo &info) { return info.constants; }) // TODO: "extra" field. .def("__repr__", [](ModuleInfo &info) { std::string ret;