Skip to content

Commit

Permalink
Add support for exploding one-bit MUXes written as one-hot and priori…
Browse files Browse the repository at this point in the history
…ty selects

PiperOrigin-RevId: 649229590
  • Loading branch information
ericastor authored and copybara-github committed Jul 3, 2024
1 parent 3add73f commit cfa9a53
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 22 deletions.
74 changes: 53 additions & 21 deletions xls/passes/select_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,8 @@ absl::StatusOr<bool> SimplifyNode(Node* node, const QueryEngine& query_engine,
// OneHotSelect with identical cases can be replaced with a select between one
// of the identical case and the default value where the selector is: original
// selector == 0
if (node->Is<OneHotSelect>() && node->GetType()->IsBits()) {
if (node->Is<OneHotSelect>() && node->GetType()->IsBits() &&
node->BitCountOrDie() > 1) {
Node* selector = node->As<OneHotSelect>()->selector();
absl::Span<Node* const> cases = node->As<OneHotSelect>()->cases();
if (absl::c_all_of(cases, [&](Node* c) { return c == cases[0]; })) {
Expand Down Expand Up @@ -922,37 +923,68 @@ absl::StatusOr<bool> SimplifyNode(Node* node, const QueryEngine& query_engine,
// * At least one of the selected values is a constant.
// * One of the selected values is also the selector.
//
// TODO(meheff): Handle one-hot select and priority-select here as well.
// If the one-bit MUX is a one-hot select, one of the selected values is
// always a constant, since the default value is always zero.
auto is_one_bit_mux = [&] {
return node->Is<Select>() && node->GetType()->IsBits() &&
node->BitCountOrDie() == 1 && node->operand(0)->BitCountOrDie() == 1;
if (!node->GetType()->IsBits() || node->BitCountOrDie() != 1) {
return false;
}
if (node->Is<Select>()) {
return node->As<Select>()->selector()->BitCountOrDie() == 1;
}
if (node->Is<PrioritySelect>()) {
return node->As<PrioritySelect>()->selector()->BitCountOrDie() == 1;
}
if (node->Is<OneHotSelect>()) {
return node->As<OneHotSelect>()->selector()->BitCountOrDie() == 1;
}
return false;
};
if (NarrowingEnabled(opt_level) && is_one_bit_mux() &&
(query_engine.IsFullyKnown(node->operand(1)) ||
(node->Is<OneHotSelect>() ||
query_engine.IsFullyKnown(node->operand(1)) ||
query_engine.IsFullyKnown(node->operand(2)) ||
(node->operand(0) == node->operand(1) ||
node->operand(0) == node->operand(2)))) {
node->operand(0) == node->operand(1) ||
node->operand(0) == node->operand(2))) {
FunctionBase* f = node->function_base();
Select* select = node->As<Select>();
XLS_RET_CHECK(!select->default_value().has_value()) << select->ToString();
Node* s = select->operand(0);
Node* on_false = select->get_case(0);
Node* on_true = select->get_case(1);
Node* s;
Node* on_true;
std::optional<Node*> on_false;
if (node->Is<Select>()) {
Select* select = node->As<Select>();
XLS_RET_CHECK(!select->default_value().has_value()) << select->ToString();
s = select->selector();
on_true = select->get_case(1);
on_false = select->get_case(0);
} else if (node->Is<PrioritySelect>()) {
s = node->As<PrioritySelect>()->selector();
on_true = node->As<PrioritySelect>()->get_case(0);
on_false = node->As<PrioritySelect>()->default_value();
} else {
XLS_RET_CHECK(node->Is<OneHotSelect>());
s = node->As<OneHotSelect>()->selector();
on_true = node->As<OneHotSelect>()->get_case(0);
on_false = std::nullopt;
}
VLOG(3) << absl::StrFormat("Decomposing single-bit select: %s",
node->ToString());
if (!on_false.has_value()) {
XLS_RETURN_IF_ERROR(node->ReplaceUsesWithNew<NaryOp>(
std::vector<Node*>{s, on_true}, Op::kAnd)
.status());
return true;
}
XLS_ASSIGN_OR_RETURN(
Node * lhs,
f->MakeNode<NaryOp>(select->loc(), std::vector<Node*>{s, on_true},
Op::kAnd));
Node * lhs, f->MakeNode<NaryOp>(
node->loc(), std::vector<Node*>{s, on_true}, Op::kAnd));
XLS_ASSIGN_OR_RETURN(Node * s_not,
f->MakeNode<UnOp>(select->loc(), s, Op::kNot));
f->MakeNode<UnOp>(node->loc(), s, Op::kNot));
XLS_ASSIGN_OR_RETURN(
Node * rhs,
f->MakeNode<NaryOp>(select->loc(), std::vector<Node*>{s_not, on_false},
f->MakeNode<NaryOp>(node->loc(), std::vector<Node*>{s_not, *on_false},
Op::kAnd));
VLOG(2) << absl::StrFormat("Decomposing single-bit select: %s",
node->ToString());
XLS_RETURN_IF_ERROR(
select
->ReplaceUsesWithNew<NaryOp>(std::vector<Node*>{lhs, rhs}, Op::kOr)
node->ReplaceUsesWithNew<NaryOp>(std::vector<Node*>{lhs, rhs}, Op::kOr)
.status());
return true;
}
Expand Down
33 changes: 32 additions & 1 deletion xls/passes/select_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ TEST_F(SelectSimplificationPassTest, MeaningfulArrayTyped3ArySelectViaDefault) {
m::Literal(default_value)));
}

TEST_F(SelectSimplificationPassTest, OneBitMux) {
TEST_F(SelectSimplificationPassTest, OneBitMuxSel) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
fn func(s: bits[1], a: bits[1]) -> bits[1] {
Expand All @@ -555,6 +555,37 @@ TEST_F(SelectSimplificationPassTest, OneBitMux) {
m::And(m::Not(m::Param("s")), m::Param("s"))));
}

TEST_F(SelectSimplificationPassTest, OneBitMuxPrioritySel) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
fn func(s: bits[1], a: bits[1]) -> bits[1] {
ret priority_sel.3: bits[1] = priority_sel(s, cases=[s], default=a)
}
)",
p.get()));
EXPECT_EQ(f->return_value()->op(), Op::kPrioritySel);

EXPECT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(),
m::Or(m::And(m::Param("s"), m::Param("s")),
m::And(m::Not(m::Param("s")), m::Param("a"))));
}

TEST_F(SelectSimplificationPassTest, OneBitMuxOneHotSel) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
fn func(s: bits[1], a: bits[1]) -> bits[1] {
ret one_hot_sel.3: bits[1] = one_hot_sel(s, cases=[a])
}
)",
p.get()));
EXPECT_EQ(f->return_value()->op(), Op::kOneHotSel);

EXPECT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::And(m::Param("s"), m::Param("a")));
}

TEST_F(SelectSimplificationPassTest, SelSqueezing) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
Expand Down

0 comments on commit cfa9a53

Please sign in to comment.