From 6431091409a79f3cddf1d62c1c040e9569169694 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Wed, 7 Aug 2024 12:23:18 -0700 Subject: [PATCH] [OM] Support list concatenation in the Evaluator. List concatenation is implemented by first ensuring all the sub-lists are evaluated, and then appending the elements of each sub-list into the final ListValue that represents the concatenation of the lists. --- .../circt/Dialect/OM/Evaluator/Evaluator.h | 3 + lib/Dialect/OM/Evaluator/Evaluator.cpp | 50 +++++++- .../Dialect/OM/Evaluator/EvaluatorTests.cpp | 111 ++++++++++++++++++ 3 files changed, 160 insertions(+), 4 deletions(-) diff --git a/include/circt/Dialect/OM/Evaluator/Evaluator.h b/include/circt/Dialect/OM/Evaluator/Evaluator.h index 3834e626c108..2592fbdc7e2b 100644 --- a/include/circt/Dialect/OM/Evaluator/Evaluator.h +++ b/include/circt/Dialect/OM/Evaluator/Evaluator.h @@ -481,6 +481,9 @@ struct Evaluator { FailureOr evaluateListCreate(ListCreateOp op, ActualParameters actualParams, Location loc); + FailureOr evaluateListConcat(ListConcatOp op, + ActualParameters actualParams, + Location loc); FailureOr evaluateTupleCreate(TupleCreateOp op, ActualParameters actualParams, Location loc); diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index d029cfc9df55..e100a0e8d086 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -167,10 +167,10 @@ FailureOr circt::om::Evaluator::getOrCreateValue( evaluator::PathValue::getEmptyPath(loc)); return success(result); }) - .Case( - [&](auto op) { - return getPartiallyEvaluatedValue(op.getType(), loc); - }) + .Case([&](auto op) { + return getPartiallyEvaluatedValue(op.getType(), loc); + }) .Case([&](auto op) { return getPartiallyEvaluatedValue(op.getType(), op.getLoc()); }) @@ -360,6 +360,9 @@ circt::om::Evaluator::evaluateValue(Value value, ActualParameters actualParams, .Case([&](ListCreateOp op) { return evaluateListCreate(op, actualParams, loc); }) + .Case([&](ListConcatOp op) { + return evaluateListConcat(op, actualParams, loc); + }) .Case([&](TupleCreateOp op) { return evaluateTupleCreate(op, actualParams, loc); }) @@ -583,6 +586,45 @@ circt::om::Evaluator::evaluateListCreate(ListCreateOp op, return list; } +/// Evaluator dispatch function for List concatenation. +FailureOr +circt::om::Evaluator::evaluateListConcat(ListConcatOp op, + ActualParameters actualParams, + Location loc) { + // Evaluate the List concat op itself, in case it hasn't been evaluated yet. + SmallVector values; + auto list = getOrCreateValue(op, actualParams, loc); + + // Extract the ListValue, either directly or through an object reference. + auto extractList = [](evaluator::EvaluatorValue *value) { + return std::move( + llvm::TypeSwitch( + value) + .Case([](evaluator::ListValue *val) { return val; }) + .Case([](evaluator::ReferenceValue *val) { + return cast(val->getStrippedValue()->get()); + })); + }; + + for (auto operand : op.getOperands()) { + auto result = evaluateValue(operand, actualParams, loc); + if (failed(result)) + return result; + if (!result.value()->isFullyEvaluated()) + return list; + + // Append each EvaluatorValue from the sublist. + evaluator::ListValue *subList = extractList(result.value().get()); + for (auto subValue : subList->getElements()) + values.push_back(subValue); + } + + // Return the concatenated list. + llvm::cast(list.value().get()) + ->setElements(std::move(values)); + return list; +} + /// Evaluator dispatch function for Tuple creation. FailureOr circt::om::Evaluator::evaluateTupleCreate(TupleCreateOp op, diff --git a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp index 8b9e24b0a86b..abd9547c4632 100644 --- a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp +++ b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp @@ -919,4 +919,115 @@ TEST(EvaluatorTests, IntegerBinaryArithmeticWidthMismatch) { .getValue()); } +TEST(EvaluatorTests, ListConcat) { + StringRef mod = "om.class @ListConcat() {" + " %0 = om.constant #om.integer<0 : i8> : !om.integer" + " %1 = om.constant #om.integer<1 : i8> : !om.integer" + " %2 = om.constant #om.integer<2 : i8> : !om.integer" + " %l0 = om.list_create %0, %1 : !om.integer" + " %l1 = om.list_create %2 : !om.integer" + " %concat = om.list_concat %l0, %l1 : !om.list" + " om.class.field @result, %concat : !om.list" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = + evaluator.instantiate(StringAttr::get(&context, "ListConcat"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + auto finalList = + llvm::cast(fieldValue.get())->getElements(); + + ASSERT_EQ(3, finalList.size()); + + ASSERT_EQ(0, llvm::cast(finalList[0].get()) + ->getAs() + .getValue() + .getValue()); + + ASSERT_EQ(1, llvm::cast(finalList[1].get()) + ->getAs() + .getValue() + .getValue()); + + ASSERT_EQ(2, llvm::cast(finalList[2].get()) + ->getAs() + .getValue() + .getValue()); +} + +TEST(EvaluatorTests, ListConcatField) { + StringRef mod = + "om.class @ListField() {" + " %0 = om.constant #om.integer<2 : i8> : !om.integer" + " %1 = om.list_create %0 : !om.integer" + " om.class.field @value, %1 : !om.list" + "}" + "om.class @ListConcatField() {" + " %listField = om.object @ListField() : () -> !om.class.type<@ListField>" + " %0 = om.constant #om.integer<0 : i8> : !om.integer" + " %1 = om.constant #om.integer<1 : i8> : !om.integer" + " %l0 = om.list_create %0, %1 : !om.integer" + " %l1 = om.object.field %listField, [@value] : " + "(!om.class.type<@ListField>) -> !om.list" + " %concat = om.list_concat %l0, %l1 : !om.list" + " om.class.field @result, %concat : !om.list" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = + evaluator.instantiate(StringAttr::get(&context, "ListConcatField"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + auto finalList = + llvm::cast(fieldValue.get())->getElements(); + + ASSERT_EQ(3, finalList.size()); + + ASSERT_EQ(0, llvm::cast(finalList[0].get()) + ->getAs() + .getValue() + .getValue()); + + ASSERT_EQ(1, llvm::cast(finalList[1].get()) + ->getAs() + .getValue() + .getValue()); + + ASSERT_EQ(2, llvm::cast(finalList[2].get()) + ->getAs() + .getValue() + .getValue()); +} + } // namespace