Skip to content

Commit

Permalink
Add 'InstantiationType' to hold information about the types of input/…
Browse files Browse the repository at this point in the history
…output ports on an instantiation

PiperOrigin-RevId: 649131832
  • Loading branch information
allight authored and copybara-github committed Jul 3, 2024
1 parent d67ce25 commit 4a05b19
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 0 deletions.
5 changes: 5 additions & 0 deletions xls/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ cc_library(
deps = [
":xls_type_cc_proto",
"//xls/common:casts",
"//xls/common/status:ret_check",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -563,6 +565,7 @@ cc_library(
"//xls/common:visitor",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/data_structures:leaf_type_tree",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -1059,8 +1062,10 @@ cc_test(
srcs = ["type_test.cc"],
deps = [
":type",
":type_manager",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
],
Expand Down
63 changes: 63 additions & 0 deletions xls/ir/instantiation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,27 @@
#include "xls/ir/instantiation.h"

#include <cstdint>
#include <initializer_list>
#include <optional>
#include <ostream>
#include <string>
#include <string_view>
#include <utility>
#include <variant>

#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "xls/common/status/status_macros.h"
#include "xls/data_structures/leaf_type_tree.h"
#include "xls/ir/block.h"
#include "xls/ir/channel.h"
#include "xls/ir/function.h"
Expand Down Expand Up @@ -119,6 +126,19 @@ absl::StatusOr<InstantiationPort> BlockInstantiation::GetOutputPort(
return absl::NotFoundError(absl::StrFormat("No such output port `%s`", name));
}

absl::StatusOr<InstantiationType> BlockInstantiation::type() const {
absl::flat_hash_map<std::string, Type*> input_ports;
absl::flat_hash_map<std::string, Type*> output_ports;
for (InputPort* p : instantiated_block()->GetInputPorts()) {
input_ports[p->name()] = p->GetType();
}
for (OutputPort* p : instantiated_block()->GetOutputPorts()) {
output_ports[p->name()] =
p->operand(OutputPort::kOperandOperand)->GetType();
}
return InstantiationType(std::move(input_ports), std::move(output_ports));
}

// Note: these are tested in ffi_instantiation_pass_test
static absl::StatusOr<InstantiationPort> ExtractNested(
std::string_view fn_name, std::string_view full_parameter_name,
Expand Down Expand Up @@ -202,6 +222,34 @@ std::string ExternInstantiation::ToString() const {
name(), function_->name());
}

absl::StatusOr<InstantiationType> ExternInstantiation::type() const {
absl::flat_hash_map<std::string, Type*> input_ports;
absl::flat_hash_map<std::string, Type*> output_ports;
for (Param* p : function_->params()) {
LeafTypeTree<std::monostate> ltt(p->GetType(), std::monostate{});
XLS_RETURN_IF_ERROR(leaf_type_tree::ForEachIndex(
ltt.AsView(),
[&](Type* type, std::monostate v,
absl::Span<const int64_t> idx) -> absl::Status {
std::string name =
absl::StrFormat("%s.%s", p->GetName(), absl::StrJoin(idx, "."));
input_ports[name] = type;
return absl::OkStatus();
}));
}
LeafTypeTree<std::monostate> result(function_->return_value()->GetType(),
std::monostate{});
XLS_RETURN_IF_ERROR(leaf_type_tree::ForEachIndex(
result.AsView(),
[&](Type* type, std::monostate v,
absl::Span<const int64_t> idx) -> absl::Status {
std::string name = absl::StrCat("return.", absl::StrJoin(idx, "."));
output_ports[name] = type;
return absl::OkStatus();
}));
return InstantiationType(std::move(input_ports), std::move(output_ports));
}

FifoInstantiation::FifoInstantiation(
std::string_view inst_name, FifoConfig fifo_config, Type* data_type,
std::optional<std::string_view> channel_name, Package* package)
Expand Down Expand Up @@ -233,6 +281,21 @@ absl::StatusOr<InstantiationPort> FifoInstantiation::GetInputPort(
name));
}

absl::StatusOr<InstantiationType> FifoInstantiation::type() const {
Type* u1 = package_->GetBitsType(1);
return InstantiationType(/*input_types=*/
{
{std::string(kPushValidPortName), u1},
{std::string(kPopReadyPortName), u1},
{std::string(kPushDataPortName), data_type()},
},
/*output_types=*/{
{std::string(kPopValidPortName), u1},
{std::string(kPushReadyPortName), u1},
{std::string(kPopDataPortName), data_type()},
});
}

absl::StatusOr<InstantiationPort> FifoInstantiation::GetOutputPort(
std::string_view name) {
if (name == kPopDataPortName) {
Expand Down
6 changes: 6 additions & 0 deletions xls/ir/instantiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <string>
#include <string_view>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "xls/ir/channel.h"
#include "xls/ir/package.h"
Expand Down Expand Up @@ -76,6 +77,8 @@ class Instantiation {

virtual std::string ToString() const = 0;

virtual absl::StatusOr<InstantiationType> type() const = 0;

virtual absl::StatusOr<InstantiationPort> GetInputPort(
std::string_view name) = 0;
virtual absl::StatusOr<InstantiationPort> GetOutputPort(
Expand All @@ -100,6 +103,7 @@ class BlockInstantiation : public Instantiation {
Block* instantiated_block() const { return instantiated_block_; }

std::string ToString() const override;
absl::StatusOr<InstantiationType> type() const override;

absl::StatusOr<InstantiationPort> GetInputPort(
std::string_view name) override;
Expand All @@ -125,6 +129,7 @@ class ExternInstantiation : public Instantiation {
absl::StatusOr<InstantiationPort> GetOutputPort(std::string_view name) final;

std::string ToString() const final;
absl::StatusOr<InstantiationType> type() const override;

absl::StatusOr<ExternInstantiation*> AsExternInstantiation() override {
return this;
Expand All @@ -150,6 +155,7 @@ class FifoInstantiation : public Instantiation {

absl::StatusOr<InstantiationPort> GetInputPort(std::string_view name) final;
absl::StatusOr<InstantiationPort> GetOutputPort(std::string_view name) final;
absl::StatusOr<InstantiationType> type() const override;

const FifoConfig& fifo_config() const { return fifo_config_; }
Type* data_type() const { return data_type_; }
Expand Down
13 changes: 13 additions & 0 deletions xls/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cstdint>
#include <ostream>
#include <string>
#include <string_view>
#include <vector>

#include "absl/log/check.h"
Expand All @@ -25,6 +26,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "xls/common/status/ret_check.h"
#include "xls/ir/xls_type.pb.h"

namespace xls {
Expand Down Expand Up @@ -261,4 +263,15 @@ bool TypeHasToken(Type* type) {
return false;
}

absl::StatusOr<Type*> InstantiationType::GetOutputPortType(
std::string_view name) const {
XLS_RET_CHECK(output_types_.contains(name));
return output_types_.at(name);
}
absl::StatusOr<Type*> InstantiationType::GetInputPortType(
std::string_view name) const {
XLS_RET_CHECK(input_types_.contains(name));
return input_types_.at(name);
}

} // namespace xls
37 changes: 37 additions & 0 deletions xls/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
#include <cstdint>
#include <ostream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
Expand Down Expand Up @@ -230,6 +233,40 @@ class FunctionType {
Type* return_type_;
};

// Represents a type that is an instantiation with input and output ports.
class InstantiationType {
public:
explicit InstantiationType(
absl::flat_hash_map<std::string, Type*> input_types,
absl::flat_hash_map<std::string, Type*> output_types)
: input_types_(std::move(input_types)),
output_types_(std::move(output_types)) {}

InstantiationType(const InstantiationType&) = default;
InstantiationType(InstantiationType&&) = default;
InstantiationType& operator=(const InstantiationType&) = default;
InstantiationType& operator=(InstantiationType&&) = default;

absl::StatusOr<Type*> GetInputPortType(std::string_view name) const;
absl::StatusOr<Type*> GetOutputPortType(std::string_view name) const;

const absl::flat_hash_map<std::string, Type*>& input_types() const {
return input_types_;
}
const absl::flat_hash_map<std::string, Type*>& output_types() const {
return output_types_;
}

bool operator==(const InstantiationType& o) const {
return input_types_ == o.input_types_ && output_types_ == o.output_types_;
}
bool operator!=(const InstantiationType& it) const { return !(*this == it); }

private:
absl::flat_hash_map<std::string, Type*> input_types_;
absl::flat_hash_map<std::string, Type*> output_types_;
};

// -- Inlines

inline const BitsType* Type::AsBitsOrDie() const {
Expand Down
34 changes: 34 additions & 0 deletions xls/ir/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "xls/common/status/matchers.h"
#include "xls/ir/type_manager.h"

namespace xls {
namespace {
Expand Down Expand Up @@ -205,5 +207,37 @@ TEST(TypeTest, AsXTypeCallsWork) {
HasSubstr("Type is not a tuple: bits[32][7]")));
}

TEST(TypeTest, InstantiationType) {
TypeManager man;
InstantiationType it1(/*input_types=*/{{"foo", man.GetBitsType(32)}},
/*output_types=*/{{"bar", man.GetBitsType(32)}});
InstantiationType it2(/*input_types=*/{{"foo", man.GetBitsType(32)}},
/*output_types=*/{{"bar", man.GetBitsType(32)}});
EXPECT_EQ(it1, it2);
InstantiationType it3(/*input_types=*/{{"bar", man.GetBitsType(32)}},
/*output_types=*/{{"foo", man.GetBitsType(32)}});
EXPECT_NE(it1, it3);
InstantiationType it4(/*input_types=*/{{"fooooooo", man.GetBitsType(32)}},
/*output_types=*/{{"bar", man.GetBitsType(32)}});
EXPECT_NE(it1, it4);
InstantiationType it5(/*input_types=*/{{"foo", man.GetBitsType(32)}},
/*output_types=*/{{"baaaaar", man.GetBitsType(32)}});
EXPECT_NE(it1, it5);
InstantiationType it6(/*input_types=*/{{"foo", man.GetBitsType(32)},
{"more", man.GetBitsType(32)}},
/*output_types=*/{{"bar", man.GetBitsType(32)}});
EXPECT_NE(it1, it6);
InstantiationType it7(/*input_types=*/{{"foo", man.GetBitsType(32)}},
/*output_types=*/{{"bar", man.GetBitsType(32)},
{"more", man.GetBitsType(32)}});
EXPECT_NE(it1, it7);
InstantiationType it8(/*input_types=*/{{"foo", man.GetBitsType(32)}},
/*output_types=*/{{"bar", man.GetBitsType(3)}});
EXPECT_NE(it1, it8);
InstantiationType it9(/*input_types=*/{{"foo", man.GetBitsType(3)}},
/*output_types=*/{{"bar", man.GetBitsType(32)}});
EXPECT_NE(it1, it9);
}

} // namespace
} // namespace xls

0 comments on commit 4a05b19

Please sign in to comment.