Skip to content

Commit

Permalink
AOT compile proc networks.
Browse files Browse the repository at this point in the history
This adds support to aot_compiler_main to generate proc network AOTs and allows ProcRuntimes to be built using these compiled entrypoints.

Bug: #1403
PiperOrigin-RevId: 638069672
  • Loading branch information
allight authored and copybara-github committed May 28, 2024
1 parent 96653bd commit 5a561aa
Show file tree
Hide file tree
Showing 20 changed files with 760 additions and 180 deletions.
1 change: 1 addition & 0 deletions xls/examples/dslx_module/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ xls_dslx_opt_ir(
dslx_top = "manual_chan_caps_specialized",
ir_file = "manual_chan_caps_streaming_configured.ir",
library = ":some_caps_streaming_configured",
visibility = ["//xls:xls_internal"],
)

cc_xls_ir_jit_wrapper(
Expand Down
80 changes: 74 additions & 6 deletions xls/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ cc_binary(
deps = [
":aot_compiler",
":aot_entrypoint_cc_proto",
":function_base_jit",
":function_jit",
":jit_proc_runtime",
":llvm_type_converter",
":type_layout_cc_proto",
"//xls/common:init_xls",
Expand All @@ -95,6 +97,7 @@ cc_binary(
"//xls/common/status:status_macros",
"//xls/ir",
"//xls/ir:ir_parser",
"//xls/ir:type",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log:check",
Expand Down Expand Up @@ -170,15 +173,13 @@ cc_library(
hdrs = ["ir_builder_visitor.h"],
deps = [
":jit_callbacks",
":jit_channel_queue",
":llvm_compiler",
":llvm_type_converter",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:bits_ops",
"//xls/ir:elaboration",
"//xls/ir:format_preference",
"//xls/ir:format_strings",
"//xls/ir:op",
Expand All @@ -189,7 +190,6 @@ cc_library(
"@com_google_absl//absl/base:config",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -535,6 +535,43 @@ cc_test(
],
)

cc_test(
name = "proc_jit_aot_test",
srcs = [
"proc_jit_aot_test.cc",
],
data = [
":some_caps_no_idents.ir",
":specialized_caps_aot",
],
deps = [
":aot_entrypoint_cc_proto",
":function_base_jit",
":jit_callbacks",
":jit_channel_queue",
":jit_proc_runtime",
":jit_runtime",
":specialized_caps_aot", # build_cleaner: keep
"//xls/common:xls_gunit_main",
"//xls/common/file:filesystem",
"//xls/common/file:get_runfile_path",
"//xls/common/status:matchers",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/interpreter:channel_queue",
"//xls/ir",
"//xls/ir:events",
"//xls/ir:type_manager",
"//xls/ir:value",
"//xls/ir:value_builder",
"//xls/ir:value_utils",
"//xls/public:ir_parser",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest",
],
)

cc_library(
name = "jit_channel_queue",
srcs = ["jit_channel_queue.cc"],
Expand Down Expand Up @@ -749,8 +786,8 @@ cc_library(
srcs = ["proc_jit.cc"],
hdrs = ["proc_jit.h"],
deps = [
":aot_entrypoint_cc_proto",
":function_base_jit",
":ir_builder_visitor",
":jit_buffer",
":jit_callbacks",
":jit_channel_queue",
Expand All @@ -762,7 +799,6 @@ cc_library(
"//xls/interpreter:proc_evaluator",
"//xls/ir",
"//xls/ir:channel",
"//xls/ir:elaboration",
"//xls/ir:events",
"//xls/ir:proc_elaboration",
"//xls/ir:value",
Expand Down Expand Up @@ -836,17 +872,29 @@ cc_library(
srcs = ["jit_proc_runtime.cc"],
hdrs = ["jit_proc_runtime.h"],
deps = [
":aot_compiler",
":aot_entrypoint_cc_proto",
":function_base_jit",
":jit_channel_queue",
":llvm_compiler",
":proc_jit",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/interpreter:channel_queue",
"//xls/interpreter:proc_evaluator",
"//xls/interpreter:serial_proc_runtime",
"//xls/ir",
"//xls/ir:channel",
"//xls/ir:proc_elaboration",
"//xls/ir:value",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:ir_headers",
],
)

Expand Down Expand Up @@ -1074,3 +1122,23 @@ xls_aot_generate(
top = "multi_function_one",
with_msan = XLS_IS_MSAN_BUILD,
)

# Procs get the file name embedded in their identifier when they have generics.
# This means that writing a test which also works in OSS would be difficult. To
# avoid this problem we just remove all identifiers. Normal users should never
# need to do this since the generated wrappers handle all of this for you.
genrule(
name = "some_caps_no_idents",
srcs = ["//xls/examples/dslx_module:manaul_chan_caps_streaming_configured_opt_ir.opt.ir"],
outs = ["some_caps_no_idents.ir"],
cmd = """
$(location //xls/tools:remove_identifiers_main) $< > $@
""",
tools = ["//xls/tools:remove_identifiers_main"],
)

xls_aot_generate(
name = "specialized_caps_aot",
src = ":some_caps_no_idents",
with_msan = XLS_IS_MSAN_BUILD,
)
13 changes: 9 additions & 4 deletions xls/jit/aot_basic_function_entrypoint_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,19 @@ def main(argv: Sequence[str]) -> None:
if len(argv) != 2:
raise app.UsageError(f"Usage: {argv[0]} [flags] AotEntrypointProto")
if _READ_TEXTPROTO.value:
entrypoint = aot_entrypoint_pb2.AotEntrypointProto()
all_entrypoints = aot_entrypoint_pb2.AotPackageEntrypointsProto()
with open(argv[1], "rt") as proto:
text_format.Parse(proto.read(), entrypoint)
text_format.Parse(proto.read(), all_entrypoints)
else:
with open(argv[1], "rb") as proto:
entrypoint = aot_entrypoint_pb2.AotEntrypointProto.FromString(
proto.read()
all_entrypoints = (
aot_entrypoint_pb2.AotPackageEntrypointsProto.FromString(proto.read())
)
if len(all_entrypoints.entrypoint) != 1:
raise app.UsageError("Multiple entrypoints are not supported.")
entrypoint = all_entrypoints.entrypoint[0]
if entrypoint.type != aot_entrypoint_pb2.AotEntrypointProto.FUNCTION:
raise app.UsageError("Only functions are supported!")
params = []
for name, size, align in zip(
entrypoint.inputs_names,
Expand Down
Loading

0 comments on commit 5a561aa

Please sign in to comment.