Skip to content

Commit

Permalink
[OM] Support list concatenation in the Evaluator. (#7511)
Browse files Browse the repository at this point in the history
List concatenation is implemented by first ensuring all the sub-lists
are evaluated, and then appending the elements of each sub-list into
the final ListValue that represents the concatenation of the lists.
  • Loading branch information
mikeurbach authored Aug 13, 2024
1 parent ebb2429 commit 4cfcbda
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 4 deletions.
3 changes: 3 additions & 0 deletions include/circt/Dialect/OM/Evaluator/Evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,9 @@ struct Evaluator {
FailureOr<EvaluatorValuePtr> evaluateListCreate(ListCreateOp op,
ActualParameters actualParams,
Location loc);
FailureOr<EvaluatorValuePtr> evaluateListConcat(ListConcatOp op,
ActualParameters actualParams,
Location loc);
FailureOr<EvaluatorValuePtr>
evaluateTupleCreate(TupleCreateOp op, ActualParameters actualParams,
Location loc);
Expand Down
50 changes: 46 additions & 4 deletions lib/Dialect/OM/Evaluator/Evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ FailureOr<evaluator::EvaluatorValuePtr> circt::om::Evaluator::getOrCreateValue(
evaluator::PathValue::getEmptyPath(loc));
return success(result);
})
.Case<ListCreateOp, TupleCreateOp, MapCreateOp, ObjectFieldOp>(
[&](auto op) {
return getPartiallyEvaluatedValue(op.getType(), loc);
})
.Case<ListCreateOp, ListConcatOp, TupleCreateOp, MapCreateOp,
ObjectFieldOp>([&](auto op) {
return getPartiallyEvaluatedValue(op.getType(), loc);
})
.Case<ObjectOp>([&](auto op) {
return getPartiallyEvaluatedValue(op.getType(), op.getLoc());
})
Expand Down Expand Up @@ -360,6 +360,9 @@ circt::om::Evaluator::evaluateValue(Value value, ActualParameters actualParams,
.Case([&](ListCreateOp op) {
return evaluateListCreate(op, actualParams, loc);
})
.Case([&](ListConcatOp op) {
return evaluateListConcat(op, actualParams, loc);
})
.Case([&](TupleCreateOp op) {
return evaluateTupleCreate(op, actualParams, loc);
})
Expand Down Expand Up @@ -583,6 +586,45 @@ circt::om::Evaluator::evaluateListCreate(ListCreateOp op,
return list;
}

/// Evaluator dispatch function for List concatenation.
FailureOr<evaluator::EvaluatorValuePtr>
circt::om::Evaluator::evaluateListConcat(ListConcatOp op,
ActualParameters actualParams,
Location loc) {
// Evaluate the List concat op itself, in case it hasn't been evaluated yet.
SmallVector<evaluator::EvaluatorValuePtr> values;
auto list = getOrCreateValue(op, actualParams, loc);

// Extract the ListValue, either directly or through an object reference.
auto extractList = [](evaluator::EvaluatorValue *value) {
return std::move(
llvm::TypeSwitch<evaluator::EvaluatorValue *, evaluator::ListValue *>(
value)
.Case([](evaluator::ListValue *val) { return val; })
.Case([](evaluator::ReferenceValue *val) {
return cast<evaluator::ListValue>(val->getStrippedValue()->get());
}));
};

for (auto operand : op.getOperands()) {
auto result = evaluateValue(operand, actualParams, loc);
if (failed(result))
return result;
if (!result.value()->isFullyEvaluated())
return list;

// Append each EvaluatorValue from the sublist.
evaluator::ListValue *subList = extractList(result.value().get());
for (auto subValue : subList->getElements())
values.push_back(subValue);
}

// Return the concatenated list.
llvm::cast<evaluator::ListValue>(list.value().get())
->setElements(std::move(values));
return list;
}

/// Evaluator dispatch function for Tuple creation.
FailureOr<evaluator::EvaluatorValuePtr>
circt::om::Evaluator::evaluateTupleCreate(TupleCreateOp op,
Expand Down
111 changes: 111 additions & 0 deletions unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -919,4 +919,115 @@ TEST(EvaluatorTests, IntegerBinaryArithmeticWidthMismatch) {
.getValue());
}

TEST(EvaluatorTests, ListConcat) {
StringRef mod = "om.class @ListConcat() {"
" %0 = om.constant #om.integer<0 : i8> : !om.integer"
" %1 = om.constant #om.integer<1 : i8> : !om.integer"
" %2 = om.constant #om.integer<2 : i8> : !om.integer"
" %l0 = om.list_create %0, %1 : !om.integer"
" %l1 = om.list_create %2 : !om.integer"
" %concat = om.list_concat %l0, %l1 : !om.list<!om.integer>"
" om.class.field @result, %concat : !om.list<!om.integer>"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result =
evaluator.instantiate(StringAttr::get(&context, "ListConcat"), {});

ASSERT_TRUE(succeeded(result));

auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
->getField("result")
.value();

auto finalList =
llvm::cast<evaluator::ListValue>(fieldValue.get())->getElements();

ASSERT_EQ(3, finalList.size());

ASSERT_EQ(0, llvm::cast<evaluator::AttributeValue>(finalList[0].get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());

ASSERT_EQ(1, llvm::cast<evaluator::AttributeValue>(finalList[1].get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());

ASSERT_EQ(2, llvm::cast<evaluator::AttributeValue>(finalList[2].get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());
}

TEST(EvaluatorTests, ListConcatField) {
StringRef mod =
"om.class @ListField() {"
" %0 = om.constant #om.integer<2 : i8> : !om.integer"
" %1 = om.list_create %0 : !om.integer"
" om.class.field @value, %1 : !om.list<!om.integer>"
"}"
"om.class @ListConcatField() {"
" %listField = om.object @ListField() : () -> !om.class.type<@ListField>"
" %0 = om.constant #om.integer<0 : i8> : !om.integer"
" %1 = om.constant #om.integer<1 : i8> : !om.integer"
" %l0 = om.list_create %0, %1 : !om.integer"
" %l1 = om.object.field %listField, [@value] : "
"(!om.class.type<@ListField>) -> !om.list<!om.integer>"
" %concat = om.list_concat %l0, %l1 : !om.list<!om.integer>"
" om.class.field @result, %concat : !om.list<!om.integer>"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result =
evaluator.instantiate(StringAttr::get(&context, "ListConcatField"), {});

ASSERT_TRUE(succeeded(result));

auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
->getField("result")
.value();

auto finalList =
llvm::cast<evaluator::ListValue>(fieldValue.get())->getElements();

ASSERT_EQ(3, finalList.size());

ASSERT_EQ(0, llvm::cast<evaluator::AttributeValue>(finalList[0].get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());

ASSERT_EQ(1, llvm::cast<evaluator::AttributeValue>(finalList[1].get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());

ASSERT_EQ(2, llvm::cast<evaluator::AttributeValue>(finalList[2].get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());
}

} // namespace

0 comments on commit 4cfcbda

Please sign in to comment.