Skip to content

Commit

Permalink
Add Block::ReplaceInstantiationWith
Browse files Browse the repository at this point in the history
This replaces one instantiation with another and removes the old one.

PiperOrigin-RevId: 649132476
  • Loading branch information
allight authored and copybara-github committed Jul 3, 2024
1 parent 4a05b19 commit b18ba45
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 0 deletions.
25 changes: 25 additions & 0 deletions xls/ir/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,31 @@ absl::StatusOr<Instantiation*> Block::AddInstantiation(
return instantiation_ptr;
}

absl::Status Block::ReplaceInstantiationWith(Instantiation* old_inst,
Instantiation* new_inst) {
XLS_RET_CHECK(IsOwned(old_inst));
XLS_RET_CHECK(IsOwned(new_inst)) << "must add instantiation to this block "
"before replacing uses of another.";
std::vector<InstantiationInput*> inps(instantiation_inputs_.at(old_inst));
std::vector<InstantiationOutput*> outs(instantiation_outputs_.at(old_inst));
XLS_ASSIGN_OR_RETURN(InstantiationType old_type, old_inst->type());
XLS_ASSIGN_OR_RETURN(InstantiationType new_type, new_inst->type());
XLS_RET_CHECK(old_type == new_type) << "Type mismatch of instantiations";
for (InstantiationInput* inp : inps) {
XLS_RETURN_IF_ERROR(inp->ReplaceUsesWithNew<InstantiationInput>(
inp->data(), new_inst, inp->port_name())
.status());
XLS_RETURN_IF_ERROR(RemoveNode(inp));
}
for (InstantiationOutput* out : outs) {
XLS_RETURN_IF_ERROR(
out->ReplaceUsesWithNew<InstantiationOutput>(new_inst, out->port_name())
.status());
XLS_RETURN_IF_ERROR(RemoveNode(out));
}
return RemoveInstantiation(old_inst);
}

absl::Status Block::RemoveInstantiation(Instantiation* instantiation) {
if (!IsOwned(instantiation)) {
return absl::InvalidArgumentError("Instantiation is not owned by block.");
Expand Down
5 changes: 5 additions & 0 deletions xls/ir/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ class Block : public FunctionBase {
// to calling this method
absl::Status RemoveInstantiation(Instantiation* instantiation);

// Replaces all uses of old_isnt with new_inst and removes old_inst. Both must
// be currently owned by this block.
absl::Status ReplaceInstantiationWith(Instantiation* old_inst,
Instantiation* new_inst);

// Returns all instantiations owned by this block.
absl::Span<Instantiation* const> GetInstantiations() const {
return instantiation_vec_;
Expand Down
97 changes: 97 additions & 0 deletions xls/ir/block_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -905,5 +905,102 @@ TEST_F(BlockTest, BlockInstantiation) {
)");
}

TEST_F(BlockTest, ReplaceInstantiation) {
auto p = CreatePackage();
Type* u32 = p->GetBitsType(32);

BlockBuilder sub_bb("sub_block", p.get());
{
BValue a = sub_bb.InputPort("a", u32);
BValue b = sub_bb.InputPort("b", u32);
sub_bb.OutputPort("x", a);
sub_bb.OutputPort("y", b);
}
XLS_ASSERT_OK_AND_ASSIGN(Block * sub_block, sub_bb.Build());
BlockBuilder add_bb("add_block", p.get());
{
BValue a = add_bb.InputPort("a", u32);
BValue b = add_bb.InputPort("b", u32);
add_bb.OutputPort("x", add_bb.Add(a, b));
add_bb.OutputPort("y", b);
}
XLS_ASSERT_OK_AND_ASSIGN(Block * add_block, add_bb.Build());

BlockBuilder add2_bb("add2_block", p.get());
{
BValue a = add2_bb.InputPort("a2", u32);
BValue b = add2_bb.InputPort("b2", u32);
add2_bb.OutputPort("x", add2_bb.Add(a, b));
add2_bb.OutputPort("y", b);
}
XLS_ASSERT_OK_AND_ASSIGN(Block * add2_block, add2_bb.Build());

BlockBuilder bb("my_block", p.get());
XLS_ASSERT_OK_AND_ASSIGN(
Instantiation * instantiation,
bb.block()->AddBlockInstantiation("inst", sub_block));
{
BValue in0 = bb.InputPort("in0", u32);
BValue in1 = bb.InputPort("in1", u32);
bb.InstantiationInput(instantiation, "a", in0);
BValue out0 = bb.InstantiationOutput(instantiation, "x");
bb.InstantiationInput(instantiation, "b", in1);
BValue out1 = bb.InstantiationOutput(instantiation, "y");
bb.OutputPort("out0", out0);
bb.OutputPort("out1", out1);
}
XLS_ASSERT_OK_AND_ASSIGN(Block * block, bb.Build());

XLS_ASSERT_OK_AND_ASSIGN(auto inst_add,
block->AddBlockInstantiation("inst_add", add_block));
XLS_ASSERT_OK_AND_ASSIGN(
auto inst2_add, block->AddBlockInstantiation("inst2_add", add2_block));

EXPECT_THAT(
block->ReplaceInstantiationWith(instantiation, inst2_add),
status_testing::StatusIs(absl::StatusCode::kInternal,
testing::ContainsRegex("Type mismatch")));
XLS_ASSERT_OK(block->RemoveInstantiation(inst2_add));
EXPECT_THAT(block->ReplaceInstantiationWith(instantiation, inst_add),
status_testing::IsOk());
EXPECT_EQ(p->DumpIr(), R"(package ReplaceInstantiation
block sub_block(a: bits[32], b: bits[32], x: bits[32], y: bits[32]) {
a: bits[32] = input_port(name=a, id=1)
b: bits[32] = input_port(name=b, id=2)
x: () = output_port(a, name=x, id=3)
y: () = output_port(b, name=y, id=4)
}
block add_block(a: bits[32], b: bits[32], x: bits[32], y: bits[32]) {
a: bits[32] = input_port(name=a, id=5)
b: bits[32] = input_port(name=b, id=6)
add.7: bits[32] = add(a, b, id=7)
x: () = output_port(add.7, name=x, id=8)
y: () = output_port(b, name=y, id=9)
}
block add2_block(a2: bits[32], b2: bits[32], x: bits[32], y: bits[32]) {
a2: bits[32] = input_port(name=a2, id=10)
b2: bits[32] = input_port(name=b2, id=11)
add.12: bits[32] = add(a2, b2, id=12)
x: () = output_port(add.12, name=x, id=13)
y: () = output_port(b2, name=y, id=14)
}
block my_block(in0: bits[32], in1: bits[32], out0: bits[32], out1: bits[32]) {
instantiation inst_add(block=add_block, kind=block)
in0: bits[32] = input_port(name=in0, id=15)
in1: bits[32] = input_port(name=in1, id=16)
instantiation_input.23: () = instantiation_input(in0, instantiation=inst_add, port_name=a, id=23)
instantiation_input.24: () = instantiation_input(in1, instantiation=inst_add, port_name=b, id=24)
instantiation_output.25: bits[32] = instantiation_output(instantiation=inst_add, port_name=x, id=25)
instantiation_output.26: bits[32] = instantiation_output(instantiation=inst_add, port_name=y, id=26)
out0: () = output_port(instantiation_output.25, name=out0, id=21)
out1: () = output_port(instantiation_output.26, name=out1, id=22)
}
)");
}

} // namespace
} // namespace xls

0 comments on commit b18ba45

Please sign in to comment.