From 5a561aaf6388feaf9fa3068327cfbc9a093545e6 Mon Sep 17 00:00:00 2001 From: Alex Light Date: Tue, 28 May 2024 16:46:48 -0700 Subject: [PATCH] AOT compile proc networks. This adds support to aot_compiler_main to generate proc network AOTs and allows ProcRuntimes to be built using these compiled entrypoints. Bug: https://github.com/google/xls/issues/1403 PiperOrigin-RevId: 638069672 --- xls/examples/dslx_module/BUILD | 1 + xls/jit/BUILD | 80 ++++++- xls/jit/aot_basic_function_entrypoint_main.py | 13 +- xls/jit/aot_compiler_main.cc | 165 +++++++------ xls/jit/aot_entrypoint.proto | 14 ++ xls/jit/function_base_jit.cc | 30 ++- xls/jit/function_base_jit.h | 13 +- xls/jit/function_jit.cc | 25 +- xls/jit/function_jit.h | 8 - xls/jit/function_jit_aot_test.cc | 5 +- xls/jit/ir_builder_visitor.h | 9 +- xls/jit/jit_function_wrapper_cc.tmpl | 16 +- xls/jit/jit_proc_runtime.cc | 217 +++++++++++++++++- xls/jit/jit_proc_runtime.h | 38 +++ xls/jit/jit_wrapper_generator_main.py | 23 +- xls/jit/llvm_compiler.cc | 3 +- xls/jit/llvm_compiler.h | 14 +- xls/jit/proc_jit.cc | 57 +++-- xls/jit/proc_jit.h | 13 +- xls/jit/proc_jit_aot_test.cc | 196 ++++++++++++++++ 20 files changed, 760 insertions(+), 180 deletions(-) create mode 100644 xls/jit/proc_jit_aot_test.cc diff --git a/xls/examples/dslx_module/BUILD b/xls/examples/dslx_module/BUILD index 915b3be116..ab3fa07c3c 100644 --- a/xls/examples/dslx_module/BUILD +++ b/xls/examples/dslx_module/BUILD @@ -84,6 +84,7 @@ xls_dslx_opt_ir( dslx_top = "manual_chan_caps_specialized", ir_file = "manual_chan_caps_streaming_configured.ir", library = ":some_caps_streaming_configured", + visibility = ["//xls:xls_internal"], ) cc_xls_ir_jit_wrapper( diff --git a/xls/jit/BUILD b/xls/jit/BUILD index 246b0cbd96..1350ebbc42 100644 --- a/xls/jit/BUILD +++ b/xls/jit/BUILD @@ -86,7 +86,9 @@ cc_binary( deps = [ ":aot_compiler", ":aot_entrypoint_cc_proto", + ":function_base_jit", ":function_jit", + ":jit_proc_runtime", ":llvm_type_converter", ":type_layout_cc_proto", "//xls/common:init_xls", @@ -95,6 +97,7 @@ cc_binary( "//xls/common/status:status_macros", "//xls/ir", "//xls/ir:ir_parser", + "//xls/ir:type", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:check", @@ -170,7 +173,6 @@ cc_library( hdrs = ["ir_builder_visitor.h"], deps = [ ":jit_callbacks", - ":jit_channel_queue", ":llvm_compiler", ":llvm_type_converter", "//xls/common/status:ret_check", @@ -178,7 +180,6 @@ cc_library( "//xls/ir", "//xls/ir:bits", "//xls/ir:bits_ops", - "//xls/ir:elaboration", "//xls/ir:format_preference", "//xls/ir:format_strings", "//xls/ir:op", @@ -189,7 +190,6 @@ cc_library( "@com_google_absl//absl/base:config", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -535,6 +535,43 @@ cc_test( ], ) +cc_test( + name = "proc_jit_aot_test", + srcs = [ + "proc_jit_aot_test.cc", + ], + data = [ + ":some_caps_no_idents.ir", + ":specialized_caps_aot", + ], + deps = [ + ":aot_entrypoint_cc_proto", + ":function_base_jit", + ":jit_callbacks", + ":jit_channel_queue", + ":jit_proc_runtime", + ":jit_runtime", + ":specialized_caps_aot", # build_cleaner: keep + "//xls/common:xls_gunit_main", + "//xls/common/file:filesystem", + "//xls/common/file:get_runfile_path", + "//xls/common/status:matchers", + "//xls/common/status:ret_check", + "//xls/common/status:status_macros", + "//xls/interpreter:channel_queue", + "//xls/ir", + "//xls/ir:events", + "//xls/ir:type_manager", + "//xls/ir:value", + "//xls/ir:value_builder", + "//xls/ir:value_utils", + "//xls/public:ir_parser", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "jit_channel_queue", srcs = ["jit_channel_queue.cc"], @@ -749,8 +786,8 @@ cc_library( srcs = ["proc_jit.cc"], hdrs = ["proc_jit.h"], deps = [ + ":aot_entrypoint_cc_proto", ":function_base_jit", - ":ir_builder_visitor", ":jit_buffer", ":jit_callbacks", ":jit_channel_queue", @@ -762,7 +799,6 @@ cc_library( "//xls/interpreter:proc_evaluator", "//xls/ir", "//xls/ir:channel", - "//xls/ir:elaboration", "//xls/ir:events", "//xls/ir:proc_elaboration", "//xls/ir:value", @@ -836,17 +872,29 @@ cc_library( srcs = ["jit_proc_runtime.cc"], hdrs = ["jit_proc_runtime.h"], deps = [ + ":aot_compiler", + ":aot_entrypoint_cc_proto", + ":function_base_jit", ":jit_channel_queue", + ":llvm_compiler", ":proc_jit", + "//xls/common/status:ret_check", "//xls/common/status:status_macros", "//xls/interpreter:channel_queue", "//xls/interpreter:proc_evaluator", "//xls/interpreter:serial_proc_runtime", "//xls/ir", - "//xls/ir:channel", "//xls/ir:proc_elaboration", "//xls/ir:value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:ir_headers", ], ) @@ -1074,3 +1122,23 @@ xls_aot_generate( top = "multi_function_one", with_msan = XLS_IS_MSAN_BUILD, ) + +# Procs get the file name embedded in their identifier when they have generics. +# This means that writing a test which also works in OSS would be difficult. To +# avoid this problem we just remove all identifiers. Normal users should never +# need to do this since the generated wrappers handle all of this for you. +genrule( + name = "some_caps_no_idents", + srcs = ["//xls/examples/dslx_module:manaul_chan_caps_streaming_configured_opt_ir.opt.ir"], + outs = ["some_caps_no_idents.ir"], + cmd = """ + $(location //xls/tools:remove_identifiers_main) $< > $@ + """, + tools = ["//xls/tools:remove_identifiers_main"], +) + +xls_aot_generate( + name = "specialized_caps_aot", + src = ":some_caps_no_idents", + with_msan = XLS_IS_MSAN_BUILD, +) diff --git a/xls/jit/aot_basic_function_entrypoint_main.py b/xls/jit/aot_basic_function_entrypoint_main.py index 2da0898fc7..ee83a9de37 100644 --- a/xls/jit/aot_basic_function_entrypoint_main.py +++ b/xls/jit/aot_basic_function_entrypoint_main.py @@ -127,14 +127,19 @@ def main(argv: Sequence[str]) -> None: if len(argv) != 2: raise app.UsageError(f"Usage: {argv[0]} [flags] AotEntrypointProto") if _READ_TEXTPROTO.value: - entrypoint = aot_entrypoint_pb2.AotEntrypointProto() + all_entrypoints = aot_entrypoint_pb2.AotPackageEntrypointsProto() with open(argv[1], "rt") as proto: - text_format.Parse(proto.read(), entrypoint) + text_format.Parse(proto.read(), all_entrypoints) else: with open(argv[1], "rb") as proto: - entrypoint = aot_entrypoint_pb2.AotEntrypointProto.FromString( - proto.read() + all_entrypoints = ( + aot_entrypoint_pb2.AotPackageEntrypointsProto.FromString(proto.read()) ) + if len(all_entrypoints.entrypoint) != 1: + raise app.UsageError("Multiple entrypoints are not supported.") + entrypoint = all_entrypoints.entrypoint[0] + if entrypoint.type != aot_entrypoint_pb2.AotEntrypointProto.FUNCTION: + raise app.UsageError("Only functions are supported!") params = [] for name, size, align in zip( entrypoint.inputs_names, diff --git a/xls/jit/aot_compiler_main.cc b/xls/jit/aot_compiler_main.cc index b792a6e56d..205799e859 100644 --- a/xls/jit/aot_compiler_main.cc +++ b/xls/jit/aot_compiler_main.cc @@ -18,6 +18,7 @@ // disk. #include +#include // NOLINT #include #include #include @@ -40,9 +41,12 @@ #include "xls/ir/ir_parser.h" #include "xls/ir/nodes.h" #include "xls/ir/package.h" +#include "xls/ir/type.h" #include "xls/jit/aot_compiler.h" #include "xls/jit/aot_entrypoint.pb.h" +#include "xls/jit/function_base_jit.h" #include "xls/jit/function_jit.h" +#include "xls/jit/jit_proc_runtime.h" #include "xls/jit/llvm_type_converter.h" #include "xls/jit/type_layout.pb.h" @@ -54,10 +58,10 @@ ABSL_FLAG(std::string, top, "", ABSL_FLAG(std::string, output_object, "", "Path at which to write the output object file."); ABSL_FLAG(std::string, output_proto, "", - "Path at which to write the AotEntrypointProto describing the ABI of " - "the generated object file."); + "Path at which to write the AotPackageEntrypointsProto describing " + "the ABI of the generated object files."); ABSL_FLAG(bool, generate_textproto, false, - "Generate the AotEntrypointProto as a textproto"); + "Generate the AotPackageEntrypointsProto as a textproto"); #ifdef ABSL_HAVE_MEMORY_SANITIZER static constexpr bool kHasMsan = true; #else @@ -70,26 +74,8 @@ ABSL_FLAG(bool, include_msan, kHasMsan, namespace xls { namespace { -// Returns the TypeLayouts for the arguments of `f`. -TypeLayoutsProto ArgLayouts(Function* f, LlvmTypeConverter& type_converter) { - TypeLayoutsProto layouts_proto; - for (Param* param : f->params()) { - *layouts_proto.add_layouts() = - type_converter.CreateTypeLayout(param->GetType()).ToProto(); - } - return layouts_proto; -} - -// Returns the TypeLayout for the return value of `f`. -TypeLayoutsProto ResultLayouts(Function* f, LlvmTypeConverter& type_converter) { - TypeLayoutsProto layout_proto; - *layout_proto.add_layouts() = - type_converter.CreateTypeLayout(f->return_value()->GetType()).ToProto(); - return layout_proto; -} - absl::StatusOr GenerateEntrypointProto( - Package* package, Function* func, const JitObjectCode& object_code, + Package* package, FunctionBase* func, const JittedFunctionBase& object_code, bool include_msan) { AotEntrypointProto proto; XLS_ASSIGN_OR_RETURN( @@ -98,53 +84,66 @@ absl::StatusOr GenerateEntrypointProto( XLS_ASSIGN_OR_RETURN(llvm::DataLayout data_layout, aot_compiler->CreateDataLayout()); LlvmTypeConverter type_converter(aot_compiler->GetContext(), data_layout); - *proto.mutable_inputs_layout() = ArgLayouts(func, type_converter); - *proto.mutable_outputs_layout() = ResultLayouts(func, type_converter); - proto.add_outputs_names("result"); proto.set_has_msan(include_msan); - for (const Param* p : func->params()) { - proto.add_inputs_names(p->name()); + if (func->IsFunction()) { + proto.set_type(AotEntrypointProto::FUNCTION); + proto.add_outputs_names("result"); + for (const Param* p : func->params()) { + proto.add_inputs_names(p->name()); + *proto.mutable_inputs_layout()->add_layouts() = + type_converter.CreateTypeLayout(p->GetType()).ToProto(); + } + *proto.mutable_outputs_layout()->add_layouts() = + type_converter + .CreateTypeLayout(func->AsFunctionOrDie()->GetType()->return_type()) + .ToProto(); + } else if (func->IsProc()) { + proto.set_type(AotEntrypointProto::PROC); + for (const Param* p : func->params()) { + proto.add_inputs_names(p->name()); + proto.add_outputs_names(p->name()); + auto layout_proto = + type_converter.CreateTypeLayout(p->GetType()).ToProto(); + *proto.mutable_inputs_layout()->add_layouts() = layout_proto; + *proto.mutable_outputs_layout()->add_layouts() = layout_proto; + } + } else { + return absl::UnimplementedError("block aot dumping unsupported!"); } proto.set_xls_package_name(package->name()); proto.set_xls_function_identifier(func->name()); - proto.set_function_symbol(object_code.function_base.function_name()); - absl::c_for_each(object_code.function_base.input_buffer_sizes(), + proto.set_function_symbol(object_code.function_name()); + absl::c_for_each(object_code.input_buffer_sizes(), [&](int64_t i) { proto.add_input_buffer_sizes(i); }); - absl::c_for_each( - object_code.function_base.input_buffer_preferred_alignments(), - [&](int64_t i) { proto.add_input_buffer_alignments(i); }); - absl::c_for_each( - object_code.function_base.input_buffer_abi_alignments(), - [&](int64_t i) { proto.add_input_buffer_abi_alignments(i); }); - absl::c_for_each(object_code.function_base.output_buffer_sizes(), + absl::c_for_each(object_code.input_buffer_preferred_alignments(), + [&](int64_t i) { proto.add_input_buffer_alignments(i); }); + absl::c_for_each(object_code.input_buffer_abi_alignments(), [&](int64_t i) { + proto.add_input_buffer_abi_alignments(i); + }); + absl::c_for_each(object_code.output_buffer_sizes(), [&](int64_t i) { proto.add_output_buffer_sizes(i); }); - absl::c_for_each( - object_code.function_base.output_buffer_preferred_alignments(), - [&](int64_t i) { proto.add_output_buffer_alignments(i); }); - absl::c_for_each( - object_code.function_base.output_buffer_abi_alignments(), - [&](int64_t i) { proto.add_output_buffer_abi_alignments(i); }); - if (object_code.function_base.HasPackedFunction()) { - proto.set_packed_function_symbol( - *object_code.function_base.packed_function_name()); - absl::c_for_each( - object_code.function_base.packed_input_buffer_sizes(), - [&](int64_t i) { proto.add_packed_input_buffer_sizes(i); }); - absl::c_for_each( - object_code.function_base.packed_output_buffer_sizes(), - [&](int64_t i) { proto.add_packed_output_buffer_sizes(i); }); + absl::c_for_each(object_code.output_buffer_preferred_alignments(), + [&](int64_t i) { proto.add_output_buffer_alignments(i); }); + absl::c_for_each(object_code.output_buffer_abi_alignments(), [&](int64_t i) { + proto.add_output_buffer_abi_alignments(i); + }); + if (object_code.HasPackedFunction()) { + proto.set_packed_function_symbol(*object_code.packed_function_name()); + absl::c_for_each(object_code.packed_input_buffer_sizes(), [&](int64_t i) { + proto.add_packed_input_buffer_sizes(i); + }); + absl::c_for_each(object_code.packed_output_buffer_sizes(), [&](int64_t i) { + proto.add_packed_output_buffer_sizes(i); + }); } - proto.set_temp_buffer_size(object_code.function_base.temp_buffer_size()); - proto.set_temp_buffer_alignment( - object_code.function_base.temp_buffer_alignment()); - for (const auto& [cont, node] : - object_code.function_base.continuation_points()) { - proto.mutable_continuation_point_node_ids()->at(cont) = node->id(); + proto.set_temp_buffer_size(object_code.temp_buffer_size()); + proto.set_temp_buffer_alignment(object_code.temp_buffer_alignment()); + for (const auto& [cont, node] : object_code.continuation_points()) { + proto.mutable_continuation_point_node_ids()->insert({cont, node->id()}); } - for (const auto& [chan_name, idx] : - object_code.function_base.queue_indices()) { - proto.mutable_channel_queue_indices()->at(chan_name) = idx; + for (const auto& [chan_name, idx] : object_code.queue_indices()) { + proto.mutable_channel_queue_indices()->insert({chan_name, idx}); } return proto; } @@ -157,37 +156,55 @@ absl::Status RealMain(const std::string& input_ir_path, const std::string& top, XLS_ASSIGN_OR_RETURN(std::unique_ptr package, Parser::ParsePackage(input_ir, input_ir_path)); - Function* f; + FunctionBase* f; std::string package_prefix = absl::StrCat("__", package->name(), "__"); if (top.empty()) { - XLS_ASSIGN_OR_RETURN(f, package->GetTopAsFunction()); + XLS_RET_CHECK(package->HasTop()) << "No top given."; + f = *package->GetTop(); } else { - absl::StatusOr maybe_f = package->GetFunction(top); + absl::StatusOr maybe_f = package->GetFunctionBaseByName(top); if (maybe_f.ok()) { f = *maybe_f; } else { XLS_ASSIGN_OR_RETURN( - f, package->GetFunction(absl::StrCat(package_prefix, top))); + f, package->GetFunctionBaseByName(absl::StrCat(package_prefix, top))); } } - XLS_ASSIGN_OR_RETURN( - JitObjectCode object_code, - FunctionJit::CreateObjectCode(f, /*opt_level = */ 3, include_msan)); + JitObjectCode object_code; + if (f->IsFunction()) { + XLS_ASSIGN_OR_RETURN(object_code, FunctionJit::CreateObjectCode( + f->AsFunctionOrDie(), + /*opt_level = */ 3, include_msan)); + } else if (f->IsProc()) { + if (f->AsProcOrDie()->is_new_style_proc()) { + XLS_ASSIGN_OR_RETURN( + object_code, CreateProcAotObjectCode(f->AsProcOrDie(), include_msan)); + } else { + // all procs + XLS_ASSIGN_OR_RETURN( + object_code, CreateProcAotObjectCode(package.get(), include_msan)); + } + } else { + return absl::UnimplementedError( + "Dumping block jit code is not yet supported"); + } + AotPackageEntrypointsProto all_entrypoints; XLS_RETURN_IF_ERROR(SetFileContents( output_object_path, std::string(object_code.object_code.begin(), object_code.object_code.end()))); - - XLS_ASSIGN_OR_RETURN( - AotEntrypointProto entrypoint, - GenerateEntrypointProto(package.get(), f, object_code, include_msan)); + for (const FunctionEntrypoint& oc : object_code.entrypoints) { + XLS_ASSIGN_OR_RETURN(*all_entrypoints.add_entrypoint(), + GenerateEntrypointProto(package.get(), oc.function, + oc.jit_info, include_msan)); + } if (generate_textproto) { std::string text; - XLS_RET_CHECK(google::protobuf::TextFormat::PrintToString(entrypoint, &text)); + XLS_RET_CHECK(google::protobuf::TextFormat::PrintToString(all_entrypoints, &text)); XLS_RETURN_IF_ERROR(SetFileContents(output_proto_path, text)); } else { - XLS_RETURN_IF_ERROR( - SetFileContents(output_proto_path, entrypoint.SerializeAsString())); + XLS_RETURN_IF_ERROR(SetFileContents(output_proto_path, + all_entrypoints.SerializeAsString())); } return absl::OkStatus(); diff --git a/xls/jit/aot_entrypoint.proto b/xls/jit/aot_entrypoint.proto index 09f19cee55..8330260aed 100644 --- a/xls/jit/aot_entrypoint.proto +++ b/xls/jit/aot_entrypoint.proto @@ -21,9 +21,17 @@ import "xls/jit/type_layout.proto"; // Proto version of the jittedFunctionBase information needed to call a AOTd // function. message AotEntrypointProto { + enum XlsFunctionType { + INVALID = 0; + FUNCTION = 1; + PROC = 2; + BLOCK = 3; + } + // The identifier (package and function/proc/block name) in the XLS ir. optional string xls_package_name = 1; optional string xls_function_identifier = 2; + optional XlsFunctionType type = 22; // Information for unpacked function. @@ -67,3 +75,9 @@ message AotEntrypointProto { repeated string inputs_names = 20; repeated string outputs_names = 21; } + +// A single object file can have entrypoints for many different targets. This is +// a list of all of the targets contained. +message AotPackageEntrypointsProto { + repeated AotEntrypointProto entrypoint = 1; +} diff --git a/xls/jit/function_base_jit.cc b/xls/jit/function_base_jit.cc index ce11126515..241ee3a5ca 100644 --- a/xls/jit/function_base_jit.cc +++ b/xls/jit/function_base_jit.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -72,6 +73,20 @@ namespace xls { namespace { +// A fake function entrypoint we can use if we didn't actually load the compiled +// code. +int64_t InvalidJitFunctionUse(const uint8_t* const* inputs, + uint8_t* const* outputs, void* temp_buffer, + InterpreterEvents* events, + InstanceContext* instance_context, + JitRuntime* jit_runtime, + int64_t continuation_point) { + static_assert( + std::is_same_v); + LOG(FATAL) + << "Attempt to call invalid function pointer in JitObjectCode structure!"; +} + // Loads a pointer from the `index`-th slot in the array pointed to by // `pointer_array`. llvm::Value* LoadPointerFromPointerArray(int64_t index, @@ -1281,7 +1296,9 @@ absl::StatusOr JittedFunctionBase::BuildInternal( XLS_ASSIGN_OR_RETURN(auto fn_address, orc_jit->LoadSymbol(function_name)); jitted_function.function_ = absl::bit_cast(fn_address); } else { - jitted_function.function_ = nullptr; + // Give it a function that will give a sort of useful error message if you + // actually try to invoke it. + jitted_function.function_ = InvalidJitFunctionUse; } if (build_packed_wrapper) { @@ -1294,7 +1311,9 @@ absl::StatusOr JittedFunctionBase::BuildInternal( jitted_function.packed_function_ = absl::bit_cast(packed_fn_address); } else { - jitted_function.packed_function_ = nullptr; + // Give it a function that will give a sort of useful error message if you + // actually try to invoke it. + jitted_function.packed_function_ = InvalidJitFunctionUse; } } @@ -1394,9 +1413,10 @@ absl::StatusOr JittedFunctionBase::BuildFromAot( function->GetNodeById(node_id)); } for (auto* chan : function->package()->channels()) { - XLS_RET_CHECK(abi.channel_queue_indices().contains(chan->name())); - queue_indices[chan->name()] = - abi.channel_queue_indices().at(chan->name()); + if (abi.channel_queue_indices().contains(chan->name())) { + queue_indices[chan->name()] = + abi.channel_queue_indices().at(chan->name()); + } } } else { XLS_RET_CHECK_EQ(abi.continuation_point_node_ids_size(), 0); diff --git a/xls/jit/function_base_jit.h b/xls/jit/function_base_jit.h index e20ecd7dfe..a795b4a55c 100644 --- a/xls/jit/function_base_jit.h +++ b/xls/jit/function_base_jit.h @@ -36,7 +36,6 @@ #include "xls/jit/jit_buffer.h" #include "xls/jit/jit_callbacks.h" #include "xls/jit/jit_runtime.h" -#include "xls/jit/orc_jit.h" namespace xls { @@ -292,6 +291,18 @@ class JittedFunctionBase { absl::btree_map queue_indices_; }; +struct FunctionEntrypoint { + FunctionBase* function; + JittedFunctionBase jit_info; +}; + +// Data structure containing jitted object code and metadata about how to call +// it. +struct JitObjectCode { + std::vector object_code; + std::vector entrypoints; +}; + } // namespace xls #endif // XLS_JIT_FUNCTION_BASE_JIT_H_ diff --git a/xls/jit/function_jit.cc b/xls/jit/function_jit.cc index 4568caa28e..9a76efd154 100644 --- a/xls/jit/function_jit.cc +++ b/xls/jit/function_jit.cc @@ -17,7 +17,6 @@ #include #include #include -#include #include #include @@ -51,19 +50,6 @@ absl::StatusOr> FunctionJit::Create( return CreateInternal(xls_function, opt_level, observer); } -namespace { -int64_t JitObjectCodeFunctionUse(const uint8_t* const* inputs, - uint8_t* const* outputs, void* temp_buffer, - InterpreterEvents* events, - InstanceContext* instance_context, - JitRuntime* jit_runtime, - int64_t continuation_point) { - static_assert( - std::is_same_v); - LOG(FATAL) << "Attempt to call function point in JitObjectCode structure!"; -} -} // namespace - // Returns an object containing an AOT-compiled version of the specified XLS // function. /* static */ absl::StatusOr> @@ -93,10 +79,13 @@ absl::StatusOr FunctionJit::CreateObjectCode( XLS_ASSIGN_OR_RETURN(JittedFunctionBase jfb, JittedFunctionBase::Build(xls_function, *comp)); XLS_ASSIGN_OR_RETURN(auto obj_code, std::move(comp)->GetObjectCode()); - return JitObjectCode{ - .object_code = std::move(obj_code), - .function_base = std::move(jfb), - }; + return JitObjectCode{.object_code = std::move(obj_code), + .entrypoints = { + FunctionEntrypoint{ + .function = xls_function, + .jit_info = std::move(jfb), + }, + }}; } absl::StatusOr> FunctionJit::CreateInternal( diff --git a/xls/jit/function_jit.h b/xls/jit/function_jit.h index ab85129978..076d5f06c8 100644 --- a/xls/jit/function_jit.h +++ b/xls/jit/function_jit.h @@ -42,14 +42,6 @@ namespace xls { -// Data structure containing jitted object code and metadata about how to call -// it. -struct JitObjectCode { - std::vector object_code; - - JittedFunctionBase function_base; -}; - // This class provides a facility to execute XLS functions (on the host) by // converting it to LLVM IR, compiling it, and finally executing it. Not // thread-safe due to sharing of result and temporary buffers between diff --git a/xls/jit/function_jit_aot_test.cc b/xls/jit/function_jit_aot_test.cc index ab0612e619..dfe81a1337 100644 --- a/xls/jit/function_jit_aot_test.cc +++ b/xls/jit/function_jit_aot_test.cc @@ -80,12 +80,13 @@ static constexpr std::string_view kTestAotEntrypointsProto = "xls/jit/multi_function_aot.pb"; absl::StatusOr GetEntrypointsProto() { - AotEntrypointProto proto; + AotPackageEntrypointsProto proto; XLS_ASSIGN_OR_RETURN(std::filesystem::path path, GetXlsRunfilePath(kTestAotEntrypointsProto)); XLS_ASSIGN_OR_RETURN(std::string bin, GetFileContents(path)); XLS_RET_CHECK(proto.ParseFromString(bin)); - return proto; + XLS_RET_CHECK_EQ(proto.entrypoint_size(), 1); + return proto.entrypoint()[0]; } bool AreSymbolsAsExpected() { auto v = GetEntrypointsProto(); diff --git a/xls/jit/ir_builder_visitor.h b/xls/jit/ir_builder_visitor.h index 9e7ae04885..8688747ab9 100644 --- a/xls/jit/ir_builder_visitor.h +++ b/xls/jit/ir_builder_visitor.h @@ -16,7 +16,6 @@ #include #include -#include #include #include #include @@ -24,17 +23,13 @@ #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "llvm/include/llvm/IR/Function.h" #include "llvm/include/llvm/IR/IRBuilder.h" #include "llvm/include/llvm/IR/Value.h" -#include "xls/ir/elaboration.h" #include "xls/ir/node.h" -#include "xls/ir/nodes.h" -#include "xls/jit/jit_channel_queue.h" #include "xls/jit/llvm_compiler.h" -#include "xls/jit/jit_callbacks.h" #include "xls/jit/llvm_type_converter.h" namespace xls { @@ -54,7 +49,7 @@ class JitBuilderContext { type_converter_( llvm_compiler_.GetContext(), llvm_compiler_.CreateDataLayout().value()) { - module_->setTargetTriple(llvm_compiler_.target_triple()); + CHECK_EQ(module_->getTargetTriple(), llvm_compiler_.target_triple()); } llvm::Module* module() const { return module_.get(); } diff --git a/xls/jit/jit_function_wrapper_cc.tmpl b/xls/jit/jit_function_wrapper_cc.tmpl index 33fabd088a..6c6c709f3c 100644 --- a/xls/jit/jit_function_wrapper_cc.tmpl +++ b/xls/jit/jit_function_wrapper_cc.tmpl @@ -11,12 +11,12 @@ extern "C" { // The actual symbols the AOT generates. // Unpacked entrypoint -int64_t {{wrapped.aot_entrypoint.function_symbol}}( // NOLINT +int64_t {{wrapped.aot_entrypoint.entrypoint[0].function_symbol}}( // NOLINT const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, xls::InterpreterEvents* events, xls::InstanceContext* instance_context, xls::JitRuntime* jit_runtime, int64_t continuation_point); // Packed entrypoint -int64_t {{wrapped.aot_entrypoint.packed_function_symbol}}( // NOLINT +int64_t {{wrapped.aot_entrypoint.entrypoint[0].packed_function_symbol}}( // NOLINT const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, xls::InterpreterEvents* events, xls::InstanceContext* instance_context, xls::JitRuntime* jit_runtime, int64_t continuation_point); @@ -31,7 +31,7 @@ static constexpr bool kTargetHasSanitizer = true; #else static constexpr bool kTargetHasSanitizer = false; #endif -static constexpr bool kExternHasSanitizer = {{ "true" if wrapped.aot_entrypoint.has_msan else "false" }}; +static constexpr bool kExternHasSanitizer = {{ "true" if wrapped.aot_entrypoint.entrypoint[0].has_msan else "false" }}; static_assert(kTargetHasSanitizer == kExternHasSanitizer, "sanitizer states do not match!"); @@ -44,9 +44,9 @@ static constexpr char kIrText[] = R"original_ir({{wrapped.ir_text}})original_ir"; // Bytes of the AOT entrypoint message: -{{ str(wrapped.aot_entrypoint).split("\n") | prefix_each("// ") | join("\n") }} -static constexpr std::array kAotEntrypointProtoBin = { - {{wrapped.aot_entrypoint.SerializeToString() | list | join(", ")}} +{{ str(wrapped.aot_entrypoint.entrypoint[0]).split("\n") | prefix_each("// ") | join("\n") }} +static constexpr std::array kAotEntrypointProtoBin = { + {{wrapped.aot_entrypoint.entrypoint[0].SerializeToString() | list | join(", ")}} }; } // namespace @@ -56,8 +56,8 @@ absl::StatusOr> kIrText, kFunctionName, kAotEntrypointProtoBin, - {{wrapped.aot_entrypoint.function_symbol}}, - {{wrapped.aot_entrypoint.packed_function_symbol}}); + {{wrapped.aot_entrypoint.entrypoint[0].function_symbol}}, + {{wrapped.aot_entrypoint.entrypoint[0].packed_function_symbol}}); } absl::StatusOr {{ wrapped.class_name }}::Run( diff --git a/xls/jit/jit_proc_runtime.cc b/xls/jit/jit_proc_runtime.cc index d5ef3c3a01..14fbee4962 100644 --- a/xls/jit/jit_proc_runtime.cc +++ b/xls/jit/jit_proc_runtime.cc @@ -15,25 +15,192 @@ #include "xls/jit/jit_proc_runtime.h" #include +#include +#include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "llvm/include/llvm/IR/DataLayout.h" +#include "llvm/include/llvm/IR/LLVMContext.h" +#include "llvm/include/llvm/IR/Module.h" +#include "llvm/include/llvm/Target/TargetMachine.h" +#include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/interpreter/channel_queue.h" #include "xls/interpreter/proc_evaluator.h" #include "xls/interpreter/serial_proc_runtime.h" -#include "xls/ir/channel.h" #include "xls/ir/package.h" #include "xls/ir/proc.h" #include "xls/ir/proc_elaboration.h" #include "xls/ir/value.h" +#include "xls/jit/aot_compiler.h" +#include "xls/jit/aot_entrypoint.pb.h" +#include "xls/jit/function_base_jit.h" #include "xls/jit/jit_channel_queue.h" +#include "xls/jit/llvm_compiler.h" #include "xls/jit/proc_jit.h" namespace xls { namespace { +absl::Status InsertInitialChannelValues(const ProcElaboration& elaboration, + ChannelQueueManager& queue_mgr) { + // Inject initial values into channel queues. + for (ChannelInstance* channel_instance : elaboration.channel_instances()) { + Channel* channel = channel_instance->channel; + ChannelQueue& queue = queue_mgr.GetQueue(channel_instance); + for (const Value& value : channel->initial_values()) { + XLS_RETURN_IF_ERROR(queue.Write(value)); + } + } + return absl::OkStatus(); +} + +// Wrapper compiler which just shares a single llvm::Module with multiple +// targets. +class SharedCompiler final : public LlvmCompiler { + public: + explicit SharedCompiler(std::string_view name, AotCompiler* underlying, + std::unique_ptr target, + llvm::DataLayout&& data_layout) + : LlvmCompiler(std::move(target), std::move(data_layout), + underlying->opt_level(), underlying->include_msan()), + underlying_(underlying), + the_module_(underlying_->NewModule( + absl::StrFormat("__shared_module_for_%s", name))) {} + + // Share around the same module. + std::unique_ptr NewModule(std::string_view ignored) override { + CHECK(the_module_) << "no module to give out!"; + auto res = std::move(*the_module_); + the_module_.reset(); + return res; + } + + absl::Status CompileModule(std::unique_ptr&& module) override { + XLS_RET_CHECK(!the_module_) << "Already took back module."; + the_module_ = std::move(module); + return absl::OkStatus(); + } + + // Return the underlying LLVM context. + llvm::LLVMContext* GetContext() override { return underlying_->GetContext(); } + absl::StatusOr> CreateTargetMachine() + override { + return underlying_->CreateTargetMachine(); + } + + std::unique_ptr TakeModule() && { + CHECK(the_module_) << "no module to give out!"; + auto res = std::move(*the_module_); + the_module_.reset(); + return res; + } + + protected: + absl::Status InitInternal() override { + return absl::InternalError("Should not be called"); + } + + private: + AotCompiler* underlying_; + std::optional> the_module_; +}; + +absl::StatusOr GetAotObjectCode(ProcElaboration elaboration, + bool with_msan) { + XLS_ASSIGN_OR_RETURN(std::unique_ptr compiler, + AotCompiler::Create(with_msan)); + XLS_ASSIGN_OR_RETURN(std::unique_ptr target, + compiler->CreateTargetMachine()); + llvm::DataLayout layout = target->createDataLayout(); + SharedCompiler sc( + elaboration.top() ? elaboration.top()->GetName() + : elaboration.procs().front()->package()->name(), + compiler.get(), std::move(target), std::move(layout)); + JitObjectCode joc; + for (Proc* p : elaboration.procs()) { + joc.entrypoints.push_back({.function = p}); + XLS_ASSIGN_OR_RETURN(joc.entrypoints.back().jit_info, + JittedFunctionBase::Build(p, sc)); + } + XLS_RETURN_IF_ERROR(compiler->CompileModule(std::move(sc).TakeModule())); + XLS_ASSIGN_OR_RETURN(joc.object_code, std::move(compiler)->GetObjectCode()); + return joc; +} + +namespace { +struct AotProcJitArgs { + AotEntrypointProto entrypoint; + Proc* proc; + JitFunctionType unpacked; + std::optional packed; +}; +} // namespace + +absl::StatusOr> CreateAotRuntime( + ProcElaboration elaboration, const AotPackageEntrypointsProto& entrypoints, + absl::Span impls) { + XLS_RET_CHECK_EQ(elaboration.procs().size(), entrypoints.entrypoint_size()); + XLS_RET_CHECK_EQ(elaboration.procs().size(), impls.size()); + absl::flat_hash_map procs_by_name; + for (const auto& entrypoint : entrypoints.entrypoint()) { + XLS_RET_CHECK(!procs_by_name.contains(entrypoint.xls_function_identifier())) + << "Multiple definitions for " << entrypoint.xls_function_identifier(); + procs_by_name[entrypoint.xls_function_identifier()] = { + .entrypoint = entrypoint, + .proc = nullptr, + .unpacked = nullptr, + .packed = std::nullopt}; + } + for (const auto& impl : impls) { + XLS_RET_CHECK(procs_by_name.contains(impl.proc->name())) + << "Unknown implementation of " << impl.proc->name(); + AotProcJitArgs& args = procs_by_name[impl.proc->name()]; + XLS_RET_CHECK(args.proc == nullptr) + << "Multiple copies of impl for " << impl.proc->name(); + args.proc = impl.proc; + args.unpacked = impl.unpacked; + args.packed = impl.packed; + } + XLS_RET_CHECK(absl::c_all_of(elaboration.procs(), [&](Proc* p) { + return procs_by_name.contains(p->name()) && + procs_by_name[p->name()].proc == p; + })) << "Elaboration has unknown procs"; + // Create a queue manager for the queues. This factory verifies that there an + // receive only queue for every receive only channel. + XLS_ASSIGN_OR_RETURN( + std::unique_ptr queue_manager, + JitChannelQueueManager::CreateThreadSafe(std::move(elaboration))); + // Create a ProcJit for each Proc. + std::vector> proc_jits; + for (const auto& [_, jit_args] : procs_by_name) { + XLS_ASSIGN_OR_RETURN( + std::unique_ptr proc_jit, + ProcJit::CreateFromAot(jit_args.proc, &queue_manager->runtime(), + queue_manager.get(), jit_args.entrypoint, + jit_args.unpacked, jit_args.packed)); + proc_jits.push_back(std::move(proc_jit)); + } + + // Create a runtime. + XLS_ASSIGN_OR_RETURN(std::unique_ptr proc_runtime, + SerialProcRuntime::Create(std::move(proc_jits), + std::move(queue_manager))); + + XLS_RETURN_IF_ERROR(InsertInitialChannelValues( + proc_runtime->elaboration(), proc_runtime->queue_manager())); + return std::move(proc_runtime); +} + absl::StatusOr> CreateRuntime( ProcElaboration elaboration) { // Create a queue manager for the queues. This factory verifies that there an @@ -56,17 +223,8 @@ absl::StatusOr> CreateRuntime( SerialProcRuntime::Create(std::move(proc_jits), std::move(queue_manager))); - // Inject initial values into channel queues. - for (ChannelInstance* channel_instance : - proc_runtime->elaboration().channel_instances()) { - Channel* channel = channel_instance->channel; - ChannelQueue& queue = - proc_runtime->queue_manager().GetQueue(channel_instance); - for (const Value& value : channel->initial_values()) { - XLS_RETURN_IF_ERROR(queue.Write(value)); - } - } - + XLS_RETURN_IF_ERROR(InsertInitialChannelValues( + proc_runtime->elaboration(), proc_runtime->queue_manager())); return std::move(proc_runtime); } @@ -86,4 +244,39 @@ absl::StatusOr> CreateJitSerialProcRuntime( return CreateRuntime(std::move(elaboration)); } +absl::StatusOr CreateProcAotObjectCode(Package* package, + bool with_msan) { + XLS_ASSIGN_OR_RETURN(ProcElaboration elaboration, + ProcElaboration::ElaborateOldStylePackage(package)); + return GetAotObjectCode(std::move(elaboration), with_msan); +} +absl::StatusOr CreateProcAotObjectCode(Proc* top, + bool with_msan) { + XLS_ASSIGN_OR_RETURN(ProcElaboration elaboration, + ProcElaboration::Elaborate(top)); + return GetAotObjectCode(std::move(elaboration), with_msan); +} + +// Create a SerialProcRuntime composed of ProcJits. Constructed from the +// elaboration of the given proc using the given impls. All procs in the +// elaboration must have an associated entry in the entrypoints and impls lists. +absl::StatusOr> CreateAotSerialProcRuntime( + Proc* top, const AotPackageEntrypointsProto& entrypoints, + absl::Span impls) { + XLS_ASSIGN_OR_RETURN(ProcElaboration elaboration, + ProcElaboration::Elaborate(top)); + return CreateAotRuntime(std::move(elaboration), entrypoints, impls); +} + +// Create a SerialProcRuntime composed of ProcJits. Constructed from the +// elaboration of the given package using the given impls. All procs in the +// elaboration must have an associated entry in the entrypoints and impls lists. +absl::StatusOr> CreateAotSerialProcRuntime( + Package* package, const AotPackageEntrypointsProto& entrypoints, + absl::Span impls) { + XLS_ASSIGN_OR_RETURN(ProcElaboration elaboration, + ProcElaboration::ElaborateOldStylePackage(package)); + return CreateAotRuntime(std::move(elaboration), entrypoints, impls); +} + } // namespace xls diff --git a/xls/jit/jit_proc_runtime.h b/xls/jit/jit_proc_runtime.h index 4ed1fa5d81..a6ed905535 100644 --- a/xls/jit/jit_proc_runtime.h +++ b/xls/jit/jit_proc_runtime.h @@ -16,10 +16,14 @@ #define XLS_JIT_JIT_PROC_RUNTIME_H_ #include +#include #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xls/interpreter/serial_proc_runtime.h" #include "xls/ir/package.h" +#include "xls/jit/aot_entrypoint.pb.h" +#include "xls/jit/function_base_jit.h" namespace xls { @@ -33,6 +37,40 @@ absl::StatusOr> CreateJitSerialProcRuntime( absl::StatusOr> CreateJitSerialProcRuntime( Proc* top); +struct ProcAotEntrypoints { + // What proc these entrypoints are associated with. + Proc* proc; + // unpacked entrypoint + JitFunctionType unpacked; + // packed entrypoint + std::optional packed = std::nullopt; +}; + +// Create a SerialProcRuntime composed of ProcJits. Constructed from the +// elaboration of the given proc using the given impls. All procs in the +// elaboration must have an associated entry in the entrypoints and impls lists. +// TODO(allight): Requiring the whole package here makes a lot of things simpler +// but it would be nice to not need to parse the package in the aot case. +absl::StatusOr> CreateAotSerialProcRuntime( + Proc* top, const AotPackageEntrypointsProto& entrypoints, + absl::Span impls); + +// Create a SerialProcRuntime composed of ProcJits. Constructed from the +// elaboration of the given package using the given impls. All procs in the +// elaboration must have an associated entry in the entrypoints and impls lists. +// TODO(allight): Requiring the whole package here makes a lot of things simpler +// but it would be nice to not need to parse the package in the aot case. +absl::StatusOr> CreateAotSerialProcRuntime( + Package* package, const AotPackageEntrypointsProto& entrypoints, + absl::Span impls); + +// Generate AOT code for the given proc elaboration. +absl::StatusOr CreateProcAotObjectCode(Package* package, + bool with_msan); +// Generate AOT code for the given proc elaboration. +absl::StatusOr CreateProcAotObjectCode(Proc* top, + bool with_msan); + } // namespace xls #endif // XLS_JIT_JIT_PROC_RUNTIME_H_ diff --git a/xls/jit/jit_wrapper_generator_main.py b/xls/jit/jit_wrapper_generator_main.py index 07d3ebc9cf..803aca246e 100644 --- a/xls/jit/jit_wrapper_generator_main.py +++ b/xls/jit/jit_wrapper_generator_main.py @@ -104,7 +104,7 @@ default=None, help=( "Proto file describing the interface of the available AOT'd functions" - " as a AotEntrypointProto. Must be a binary proto." + " as a AotPackageEntrypointsProto. Must be a binary proto." ), ) @@ -160,7 +160,7 @@ class WrappedIr: header_guard: str header_filename: str namespace: str - aot_entrypoint: Optional[aot_entrypoint_pb2.AotEntrypointProto] + aot_entrypoint: Optional[aot_entrypoint_pb2.AotPackageEntrypointsProto] # Function params and result. params: Optional[Sequence[XlsNamedValue]] = None result: Optional[XlsNamedValue] = None @@ -317,9 +317,9 @@ def interpret_function_interface( class_name: str, header_guard: str, header_filename: str, - aot_info: aot_entrypoint_pb2.AotEntrypointProto, + aot_info: aot_entrypoint_pb2.AotPackageEntrypointsProto, ) -> WrappedIr: - """Fill in a WrappedIr for a function. + """Fill in a WrappedIr for a function from the interface. Args: ir: package IR @@ -331,7 +331,12 @@ def interpret_function_interface( Returns: A wrapped ir for the function. + + Raises: + UsageError: If the aot info is for a different function. """ + if func_ir.base.name != aot_info.entrypoint[0].xls_function_identifier: + raise app.UsageError("Aot info is for a different function.") params = [to_param(p) for p in func_ir.parameters] result = XlsNamedValue( name="result", @@ -401,7 +406,7 @@ def interpret_interface( output_name: str, class_name: str, function_name: str, - aot_info: aot_entrypoint_pb2.AotEntrypointProto, + aot_info: aot_entrypoint_pb2.AotPackageEntrypointsProto, ) -> WrappedIr: """Create a wrapped-ir representation of the IR to be rendered to source. @@ -495,7 +500,7 @@ def main(argv: Sequence[str]) -> None: "Unknown --function_type. Requires none or FUNCTION or PROC" ) with open(_AOT_INFO.value, "rb") as aot_info_file: - aot_info = aot_entrypoint_pb2.AotEntrypointProto.FromString( + aot_info = aot_entrypoint_pb2.AotPackageEntrypointsProto.FromString( aot_info_file.read() ) ir_interface = ir_interface_pb2.PackageInterfaceProto.FromString( @@ -521,12 +526,6 @@ def main(argv: Sequence[str]) -> None: aot_info, ) - if ( - aot_info.xls_function_identifier - and wrapped.function_name != aot_info.xls_function_identifier - ): - raise app.UsageError("Aot info is for a different function!") - # Create the JINJA env and add an append filter. env = jinja2.Environment(undefined=jinja2.StrictUndefined) env.filters["append_each"] = lambda vs, suffix: [v + suffix for v in vs] diff --git a/xls/jit/llvm_compiler.cc b/xls/jit/llvm_compiler.cc index 50846476fd..3edcefafb9 100644 --- a/xls/jit/llvm_compiler.cc +++ b/xls/jit/llvm_compiler.cc @@ -67,7 +67,8 @@ absl::StatusOr LlvmCompiler::CreateDataLayout() { std::unique_ptr LlvmCompiler::NewModule(std::string_view name) { CHECK(!module_created_) << "Only one module should be made."; auto module = std::make_unique(name, *GetContext()); - module->setDataLayout(this->data_layout_); + module->setDataLayout(data_layout_); + module->setTargetTriple(target_triple()); return module; } diff --git a/xls/jit/llvm_compiler.h b/xls/jit/llvm_compiler.h index 09b1fe3b9b..cf333eeecc 100644 --- a/xls/jit/llvm_compiler.h +++ b/xls/jit/llvm_compiler.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -64,7 +65,7 @@ class LlvmCompiler { // // TODO(allight): We should rethink the architecture of the AOT/JIT compiler // at some point. - std::unique_ptr NewModule(std::string_view name); + virtual std::unique_ptr NewModule(std::string_view name); std::string target_triple() const; @@ -77,10 +78,13 @@ class LlvmCompiler { // NewModule. virtual absl::Status CompileModule( std::unique_ptr&& module) = 0; + absl::StatusOr CreateDataLayout(); + virtual absl::StatusOr> CreateTargetMachine() = 0; + int64_t opt_level() const { return opt_level_; } bool include_msan() const { return include_msan_; } protected: @@ -95,6 +99,14 @@ class LlvmCompiler { LlvmCompiler(int64_t opt_level, bool include_msan) : data_layout_(""), opt_level_(opt_level), include_msan_(include_msan) {} + // Constructor to manually setup the compiler without Init. + LlvmCompiler(std::unique_ptr target, + llvm::DataLayout&& layout, int64_t opt_level, bool include_msan) + : target_machine_(std::move(target)), + data_layout_(layout), + opt_level_(opt_level), + include_msan_(include_msan) {} + // Setup by Init std::unique_ptr target_machine_; // Setup by Init diff --git a/xls/jit/proc_jit.cc b/xls/jit/proc_jit.cc index 0e137a0aff..a96d7e5dee 100644 --- a/xls/jit/proc_jit.cc +++ b/xls/jit/proc_jit.cc @@ -199,9 +199,7 @@ absl::Status ProcJitContinuation::NextTick() { return absl::OkStatus(); } -} // namespace - -static absl::StatusOr GetChannelInstance( +absl::StatusOr GetChannelInstance( ProcInstance* proc_instance, std::string_view channel_name, JitChannelQueueManager* queue_mgr) { if (proc_instance->path().has_value()) { @@ -216,6 +214,44 @@ static absl::StatusOr GetChannelInstance( return queue_mgr->elaboration().GetUniqueInstance(channel); } +absl::Status InitializeChannelQueues( + Proc* proc, JitChannelQueueManager* queue_mgr, + const JittedFunctionBase& jitted_function_base, + absl::flat_hash_map>& + channel_queues) { + for (ProcInstance* proc_instance : + queue_mgr->elaboration().GetInstances(proc)) { + channel_queues[proc_instance].resize( + jitted_function_base.queue_indices().size()); + for (const auto& [channel_name, index] : + jitted_function_base.queue_indices()) { + XLS_ASSIGN_OR_RETURN( + ChannelInstance * channel_instance, + GetChannelInstance(proc_instance, channel_name, queue_mgr)); + channel_queues[proc_instance][index] = + &queue_mgr->GetJitQueue(channel_instance); + } + } + return absl::OkStatus(); +} + +} // namespace + +/* static */ absl::StatusOr> ProcJit::CreateFromAot( + Proc* proc, JitRuntime* jit_runtime, JitChannelQueueManager* queue_mgr, + const AotEntrypointProto& entrypoint, JitFunctionType unpacked, + std::optional packed) { + auto jit = std::unique_ptr( + new ProcJit(proc, jit_runtime, queue_mgr, /*orc_jit=*/nullptr)); + XLS_ASSIGN_OR_RETURN( + jit->jitted_function_base_, + JittedFunctionBase::BuildFromAot(proc, entrypoint, unpacked, packed)); + XLS_RET_CHECK(jit->jitted_function_base_.InputsAndOutputsAreEquivalent()); + XLS_RETURN_IF_ERROR(InitializeChannelQueues( + proc, queue_mgr, jit->jitted_function_base_, jit->channel_queues_)); + return jit; +} + absl::StatusOr> ProcJit::Create( Proc* proc, JitRuntime* jit_runtime, JitChannelQueueManager* queue_mgr, JitObserver* observer) { @@ -227,19 +263,8 @@ absl::StatusOr> ProcJit::Create( JittedFunctionBase::Build(proc, jit->GetOrcJit())); XLS_RET_CHECK(jit->jitted_function_base_.InputsAndOutputsAreEquivalent()); - for (ProcInstance* proc_instance : - queue_mgr->elaboration().GetInstances(proc)) { - jit->channel_queues_[proc_instance].resize( - jit->jitted_function_base_.queue_indices().size()); - for (const auto& [channel_name, index] : - jit->jitted_function_base_.queue_indices()) { - XLS_ASSIGN_OR_RETURN( - ChannelInstance * channel_instance, - GetChannelInstance(proc_instance, channel_name, queue_mgr)); - jit->channel_queues_[proc_instance][index] = - &queue_mgr->GetJitQueue(channel_instance); - } - } + XLS_RETURN_IF_ERROR(InitializeChannelQueues( + proc, queue_mgr, jit->jitted_function_base_, jit->channel_queues_)); return jit; } diff --git a/xls/jit/proc_jit.h b/xls/jit/proc_jit.h index 7f3624c99c..1796f7a7a9 100644 --- a/xls/jit/proc_jit.h +++ b/xls/jit/proc_jit.h @@ -15,20 +15,18 @@ #ifndef XLS_JIT_PROC_JIT_H_ #define XLS_JIT_PROC_JIT_H_ -#include #include +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "xls/interpreter/proc_evaluator.h" -#include "xls/ir/elaboration.h" -#include "xls/ir/events.h" #include "xls/ir/proc.h" -#include "xls/ir/value.h" +#include "xls/ir/proc_elaboration.h" +#include "xls/jit/aot_entrypoint.pb.h" #include "xls/jit/function_base_jit.h" -#include "xls/jit/ir_builder_visitor.h" #include "xls/jit/jit_buffer.h" #include "xls/jit/jit_channel_queue.h" #include "xls/jit/jit_runtime.h" @@ -47,6 +45,11 @@ class ProcJit : public ProcEvaluator { Proc* proc, JitRuntime* jit_runtime, JitChannelQueueManager* queue_mgr, JitObserver* observer = nullptr); + static absl::StatusOr> CreateFromAot( + Proc* proc, JitRuntime* jit_runtime, JitChannelQueueManager* queue_mgr, + const AotEntrypointProto& entrypoint, JitFunctionType unpacked, + std::optional packed = std::nullopt); + ~ProcJit() override = default; std::unique_ptr NewContinuation( diff --git a/xls/jit/proc_jit_aot_test.cc b/xls/jit/proc_jit_aot_test.cc new file mode 100644 index 0000000000..a4881ea356 --- /dev/null +++ b/xls/jit/proc_jit_aot_test.cc @@ -0,0 +1,196 @@ +// Copyright 2024 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // NOLINT +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "xls/common/file/filesystem.h" +#include "xls/common/file/get_runfile_path.h" +#include "xls/common/status/matchers.h" +#include "xls/common/status/ret_check.h" +#include "xls/common/status/status_macros.h" +#include "xls/interpreter/channel_queue.h" +#include "xls/ir/events.h" +#include "xls/ir/proc.h" +#include "xls/ir/type_manager.h" +#include "xls/ir/value.h" +#include "xls/ir/value_builder.h" +#include "xls/ir/value_utils.h" +#include "xls/jit/aot_entrypoint.pb.h" +#include "xls/jit/function_base_jit.h" +#include "xls/jit/jit_callbacks.h" +#include "xls/jit/jit_channel_queue.h" +#include "xls/jit/jit_proc_runtime.h" +#include "xls/jit/jit_runtime.h" +#include "xls/public/ir_parser.h" + +extern "C" { +// Top proc entrypoint +int64_t proc_0( // NOLINT + const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, + xls::InterpreterEvents* events, xls::InstanceContext* instance_context, + xls::JitRuntime* jit_runtime, int64_t continuation_point); +int64_t proc_1( // NOLINT + const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, + xls::InterpreterEvents* events, xls::InstanceContext* instance_context, + xls::JitRuntime* jit_runtime, int64_t continuation_point); +} + +namespace xls { +namespace { +using testing::Optional; + +static_assert(std::is_same_v, + "Jit function ABI updated. This test needs to be tweaked."); +static_assert(std::is_same_v, + "Jit function ABI updated. This test needs to be tweaked."); + +static constexpr std::string_view kTestAotEntrypointsProto = + "xls/jit/specialized_caps_aot.pb"; +static constexpr std::string_view kGoldIr = "xls/jit/some_caps_no_idents.ir"; + +absl::StatusOr GetEntrypointsProto() { + AotPackageEntrypointsProto proto; + XLS_ASSIGN_OR_RETURN(std::filesystem::path path, + GetXlsRunfilePath(kTestAotEntrypointsProto)); + XLS_ASSIGN_OR_RETURN(std::string bin, GetFileContents(path)); + XLS_RET_CHECK(proto.ParseFromString(bin)); + return proto; +} +bool AreSymbolsAsExpected() { + auto v = GetEntrypointsProto(); + if (!v.ok()) { + return false; + } + return absl::c_any_of(v->entrypoint(), + [](const AotEntrypointProto& p) { + return p.has_function_symbol() && + p.function_symbol() == "proc_0"; + }) && + absl::c_any_of(v->entrypoint(), [](const AotEntrypointProto& p) { + return p.has_function_symbol() && p.function_symbol() == "proc_1"; + }); +} + +// Not really a test just to make sure that if all other tests are disabled due +// to linking failure we have *something* that fails. +TEST(SymbolNames, AreAsExpected) { + ASSERT_TRUE(AreSymbolsAsExpected()) + << "Symbols are not what we expected. This test needs to be updated to " + "match new jit-compiler symbol naming scheme. Symbols are: " + << GetEntrypointsProto(); +} + +class ProcJitAotTest : public testing::Test { + void SetUp() override { + if (!AreSymbolsAsExpected()) { + GTEST_SKIP() << "Linking probably failed. AOTEntrypoints lists " + "unexpected symbol names"; + } + } +}; + +Value StrValue(const char sv[8]) { + auto add_failure_and_return_zero = [&](auto reason) { + ADD_FAILURE() << "Unable to make value with " << sv + << " because: " << reason; + TypeManager tm; + return ZeroOfType(tm.GetArrayType(8, tm.GetBitsType(8))); + }; + std::array ret; + absl::c_copy_n(std::string_view(sv, 8), ret.size(), ret.begin()); + auto value = ValueBuilder::UBitsArray(ret, 8).Build(); + if (value.ok()) { + return *value; + } + return add_failure_and_return_zero(value); +} + +TEST_F(ProcJitAotTest, Tick) { + XLS_ASSERT_OK_AND_ASSIGN(AotPackageEntrypointsProto proto, + GetEntrypointsProto()); + XLS_ASSERT_OK_AND_ASSIGN(auto gold_file, GetXlsRunfilePath(kGoldIr)); + XLS_ASSERT_OK_AND_ASSIGN(std::string pkg_text, GetFileContents(gold_file)); + XLS_ASSERT_OK_AND_ASSIGN(auto p, ParsePackage(pkg_text, kGoldIr)); + XLS_ASSERT_OK_AND_ASSIGN(Proc * p0, p->GetProc("proc_0")); + XLS_ASSERT_OK_AND_ASSIGN(Proc * p1, p->GetProc("proc_1")); + XLS_ASSERT_OK_AND_ASSIGN( + auto aot_runtime, + CreateAotSerialProcRuntime( + p.get(), proto, + {ProcAotEntrypoints{.proc = p0, .unpacked = proc_0}, + ProcAotEntrypoints{.proc = p1, .unpacked = proc_1}})); + XLS_ASSERT_OK_AND_ASSIGN(JitChannelQueueManager * chan_man, + aot_runtime->GetJitChannelQueueManager()); + XLS_ASSERT_OK_AND_ASSIGN(ChannelQueue * chan_input, + chan_man->GetQueueByName("chan_0")); + XLS_ASSERT_OK_AND_ASSIGN(ChannelQueue * chan_output, + chan_man->GetQueueByName("chan_1")); + XLS_EXPECT_OK(chan_input->Write(StrValue("abcdefgh"))); + XLS_EXPECT_OK(chan_input->Write(StrValue("ijklmnop"))); + XLS_EXPECT_OK(chan_input->Write(StrValue("qrstuvwx"))); + XLS_EXPECT_OK(chan_input->Write(StrValue("yz012345"))); + XLS_EXPECT_OK(aot_runtime->Tick()); + XLS_EXPECT_OK(aot_runtime->Tick()); + XLS_EXPECT_OK(aot_runtime->Tick()); + XLS_EXPECT_OK(aot_runtime->Tick()); + EXPECT_THAT(chan_output->Read(), Optional(StrValue("ABCDEFGH"))); + EXPECT_THAT(chan_output->Read(), Optional(StrValue("ijklmnop"))); + EXPECT_THAT(chan_output->Read(), Optional(StrValue("QrStUvWx"))); + EXPECT_THAT(chan_output->Read(), Optional(StrValue("YZ012345"))); +} + +TEST_F(ProcJitAotTest, TickUntilBlocked) { + XLS_ASSERT_OK_AND_ASSIGN(AotPackageEntrypointsProto proto, + GetEntrypointsProto()); + XLS_ASSERT_OK_AND_ASSIGN(auto gold_file, GetXlsRunfilePath(kGoldIr)); + XLS_ASSERT_OK_AND_ASSIGN(std::string pkg_text, GetFileContents(gold_file)); + XLS_ASSERT_OK_AND_ASSIGN(auto p, ParsePackage(pkg_text, kGoldIr)); + XLS_ASSERT_OK_AND_ASSIGN(Proc * p0, p->GetProc("proc_0")); + XLS_ASSERT_OK_AND_ASSIGN(Proc * p1, p->GetProc("proc_1")); + XLS_ASSERT_OK_AND_ASSIGN( + auto aot_runtime, + CreateAotSerialProcRuntime( + p.get(), proto, + {ProcAotEntrypoints{.proc = p0, .unpacked = proc_0}, + ProcAotEntrypoints{.proc = p1, .unpacked = proc_1}})); + XLS_ASSERT_OK_AND_ASSIGN(JitChannelQueueManager * chan_man, + aot_runtime->GetJitChannelQueueManager()); + XLS_ASSERT_OK_AND_ASSIGN(ChannelQueue * chan_input, + chan_man->GetQueueByName("chan_0")); + XLS_ASSERT_OK_AND_ASSIGN(ChannelQueue * chan_output, + chan_man->GetQueueByName("chan_1")); + XLS_EXPECT_OK(chan_input->Write(StrValue("abcdefgh"))); + XLS_EXPECT_OK(chan_input->Write(StrValue("ijklmnop"))); + XLS_EXPECT_OK(chan_input->Write(StrValue("qrstuvwx"))); + XLS_EXPECT_OK(chan_input->Write(StrValue("yz012345"))); + XLS_ASSERT_OK_AND_ASSIGN(int64_t count, aot_runtime->TickUntilBlocked()); + // 1 for each of the inputs and then another tick succeeds by making progress + // until it tries to recv from the empty channel so 5 count. + EXPECT_EQ(count, 5); + EXPECT_THAT(chan_output->Read(), Optional(StrValue("ABCDEFGH"))); + EXPECT_THAT(chan_output->Read(), Optional(StrValue("ijklmnop"))); + EXPECT_THAT(chan_output->Read(), Optional(StrValue("QrStUvWx"))); + EXPECT_THAT(chan_output->Read(), Optional(StrValue("YZ012345"))); +} +} // namespace +} // namespace xls