From b8e5b2eb46a3fd220a8928a1080e3b6b03062db5 Mon Sep 17 00:00:00 2001 From: Alex Light Date: Tue, 21 May 2024 14:18:51 -0700 Subject: [PATCH] Use AOT code for Function type jit-wrapper This speeds up startup for jit-wrapper using tools substantially (at the cost of making the name a bit of a lie). Bug: https://github.com/google/xls/issues/1403 Bug: https://github.com/google/xls/issues/1422 PiperOrigin-RevId: 635926105 --- xls/build_rules/BUILD | 2 + xls/build_rules/tests/BUILD | 17 ++++++ xls/build_rules/xls_internal_aot_rules.bzl | 2 +- xls/build_rules/xls_jit_wrapper_rules.bzl | 68 +++++++++++++++++++++- xls/jit/BUILD | 5 +- xls/jit/function_base_jit_wrapper.h | 18 ++++-- xls/jit/jit_function_wrapper_cc.tmpl | 39 ++++++++++++- xls/jit/jit_wrapper_generator_main.py | 33 ++++++++++- 8 files changed, 171 insertions(+), 13 deletions(-) diff --git a/xls/build_rules/BUILD b/xls/build_rules/BUILD index 9e18d09518..daae208f50 100644 --- a/xls/build_rules/BUILD +++ b/xls/build_rules/BUILD @@ -166,6 +166,8 @@ bzl_library( deps = [ ":xls_common_rules_bzl", ":xls_config_rules_bzl", + ":xls_internal_aot_rules_bzl", + ":xls_internal_build_defs_bzl", ":xls_ir_rules_bzl", ":xls_providers_bzl", ":xls_toolchains_bzl", diff --git a/xls/build_rules/tests/BUILD b/xls/build_rules/tests/BUILD index 1f3f9a3612..acd31d2737 100644 --- a/xls/build_rules/tests/BUILD +++ b/xls/build_rules/tests/BUILD @@ -32,6 +32,7 @@ load( "xls_ir_opt_ir", "xls_ir_verilog", ) +load("//xls/build_rules:xls_internal_aot_rules.bzl", "xls_aot_generate") load( "//xls/build_rules:xls_rules_build_defs.bzl", "FUNCTION_WRAPPER_TYPE", @@ -582,12 +583,19 @@ xls_dslx_verilog( deps = [":simple_example_4_dslx"], ) +xls_aot_generate( + name = "simple_example_5_one_stage_jit_wrapper_aot_gen", + src = ":simple_example_5_one_stage.opt.ir", + with_msan = False, +) + # The xls_ir_jit_wrapper rule using an optimized IR file as input. # The header and source files are outputs of the rule, they can be referenced # by other rules. xls_ir_jit_wrapper_macro( name = "simple_example_5_one_stage_jit_wrapper", src = ":simple_example_5_one_stage.opt.ir", + aot_info = ":simple_example_5_one_stage_jit_wrapper_aot_gen", header_file = "simple_example_5_one_stage_jit_wrapper.h", jit_wrapper_args = { "namespace": "not_xls::test", @@ -698,6 +706,7 @@ xls_dslx_verilog_native_rule( xls_ir_jit_wrapper_native_rule( name = "simple_example_5_one_stage_jit_wrapper_native_rule", src = ":simple_example_5_one_stage.opt.ir", + aot_info = ":simple_example_5_one_stage_jit_wrapper_aot_gen", header_file = "simple_example_5_one_stage_jit_wrapper_native_rule.h", jit_wrapper_args = { "namespace": "not_xls::test", @@ -779,9 +788,16 @@ xls_dslx_verilog_native_rule( verilog_file = "xls_dslx_verilog_native_rule.sv", ) +xls_aot_generate( + name = "user_defined_output_filename_jit_wrapper_aot_gen", + src = ":xls_ir_opt_ir.opt.ir", + with_msan = False, +) + xls_ir_jit_wrapper_macro( name = "user_defined_output_filename_jit_wrapper", src = ":xls_ir_opt_ir.opt.ir", + aot_info = ":user_defined_output_filename_jit_wrapper_aot_gen", header_file = "xls_ir_jit_wrapper.h", jit_wrapper_args = { "namespace": "not_xls::test", @@ -793,6 +809,7 @@ xls_ir_jit_wrapper_macro( xls_ir_jit_wrapper_native_rule( name = "user_defined_output_filename_native_rule_jit_wrapper", src = ":xls_ir_opt_ir.opt.ir", + aot_info = ":user_defined_output_filename_jit_wrapper_aot_gen", header_file = "xls_ir_jit_wrapper_native_rule.h", jit_wrapper_args = { "namespace": "not_xls::test", diff --git a/xls/build_rules/xls_internal_aot_rules.bzl b/xls/build_rules/xls_internal_aot_rules.bzl index 578b0e6389..ba38661c5c 100644 --- a/xls/build_rules/xls_internal_aot_rules.bzl +++ b/xls/build_rules/xls_internal_aot_rules.bzl @@ -41,7 +41,7 @@ _PROTO_FILE_EXTENSION = ".pb" _OBJ_FILE_EXTENSION = ".o" -visibility(["//xls/build_rules", "//xls/jit"]) +visibility(["//xls/build_rules/...", "//xls/jit"]) _xls_aot_files_attrs = { "with_msan": attr.bool( diff --git a/xls/build_rules/xls_jit_wrapper_rules.bzl b/xls/build_rules/xls_jit_wrapper_rules.bzl index d6a13398a6..cb124cbdcd 100644 --- a/xls/build_rules/xls_jit_wrapper_rules.bzl +++ b/xls/build_rules/xls_jit_wrapper_rules.bzl @@ -28,8 +28,16 @@ load( "CONFIG", "enable_generated_file_wrapper", ) +load( + "//xls/build_rules:xls_internal_aot_rules.bzl", + "xls_aot_generate", +) +load( + "//xls/build_rules:xls_internal_build_defs.bzl", + "XLS_IS_MSAN_BUILD", +) load("//xls/build_rules:xls_ir_rules.bzl", "xls_ir_common_attrs") -load("//xls/build_rules:xls_providers.bzl", "JitWrapperInfo") +load("//xls/build_rules:xls_providers.bzl", "AotCompileInfo", "JitWrapperInfo") load( "//xls/build_rules:xls_toolchains.bzl", "xls_toolchain_attrs", @@ -63,6 +71,10 @@ _xls_ir_jit_wrapper_attrs = { doc = "type of function_base we are wrapping.", mandatory = True, ), + "aot_info": attr.label( + doc = "The target which contains information about available AOT code.", + mandatory = True, + ), } def _xls_ir_jit_wrapper_impl(ctx): @@ -147,11 +159,15 @@ def _xls_ir_jit_wrapper_impl(ctx): # function_type jit_wrapper_flags.add("--function_type", ctx.attr.wrapper_type) + # Aot information + aot_info_file = ctx.attr.aot_info[AotCompileInfo].proto_file + jit_wrapper_flags.add("--aot_info", aot_info_file.path) + my_generated_files = [cc_file, h_file] # Get runfiles jit_wrapper_tool_runfiles = ctx.attr._xls_jit_wrapper_tool[DefaultInfo].default_runfiles - runfiles = get_runfiles_for_xls(ctx, [jit_wrapper_tool_runfiles], [src]) + runfiles = get_runfiles_for_xls(ctx, [jit_wrapper_tool_runfiles], [src, aot_info_file]) ctx.actions.run( outputs = my_generated_files, @@ -211,12 +227,31 @@ Examples: ), ) +def _no_aot_info_impl(ctx): + """Helper rule to create an empty AotInfo proto.""" + file = ctx.actions.declare_file(ctx.attr.name + ".pb") + ctx.actions.write(file, "", is_executable = False) + return [ + DefaultInfo(files = depset([file])), + AotCompileInfo(object_file = None, proto_file = file), + ] + +_no_aot_info = rule( + doc = """Internal only utility rule to generate an empty AotCompileInfo proto file. + + This can be used with function types that don't yet support AOT. + """, + implementation = _no_aot_info_impl, + attrs = {}, +) + def xls_ir_jit_wrapper_macro( name, src, source_file, header_file, wrapper_type, + aot_info, jit_wrapper_args = {}, enable_generated_file = True, enable_presubmit_generated_file = False, @@ -235,6 +270,8 @@ def xls_ir_jit_wrapper_macro( header_file: The generated header file. See 'header_file' attribute from the 'xls_ir_jit_wrapper' rule. wrapper_type: What sort of function base are we wrapping. + aot_info: AotCompileInfo generating label with information about the AOT + code that is available. jit_wrapper_args: Arguments of the JIT tool. See 'jit_wrapper_args' attribute from the 'xls_ir_jit_wrapper' rule. enable_generated_file: See 'enable_generated_file' from @@ -250,6 +287,7 @@ def xls_ir_jit_wrapper_macro( string_type_check("source_file", source_file) string_type_check("header_file", header_file) string_type_check("wrapper_type", wrapper_type) + string_type_check("aot_info", aot_info) dictionary_type_check("jit_wrapper_args", jit_wrapper_args) bool_type_check("enable_generated_file", enable_generated_file) bool_type_check("enable_presubmit_generated_file", enable_presubmit_generated_file) @@ -259,6 +297,7 @@ def xls_ir_jit_wrapper_macro( src = src, source_file = source_file, header_file = header_file, + aot_info = aot_info, jit_wrapper_args = jit_wrapper_args, wrapper_type = wrapper_type, outs = [source_file, header_file], @@ -281,6 +320,8 @@ _BASE_JIT_WRAPPER_DEPS = { PROC_WRAPPER_TYPE: "//xls/jit:proc_base_jit_wrapper", } +_AOT_SUPPORTED_WRAPPERS = [FUNCTION_WRAPPER_TYPE] + def cc_xls_ir_jit_wrapper( name, src, @@ -305,6 +346,9 @@ def cc_xls_ir_jit_wrapper( for compatibility. **kwargs: Keyword arguments. Named arguments. """ + + # TODO(allight): We should add top as an argument here. With the new + # jit-wrapper architecture it would be simple to support. dictionary_type_check("jit_wrapper_args", jit_wrapper_args) string_type_check("src", src) @@ -324,10 +368,28 @@ def cc_xls_ir_jit_wrapper( source_filename = name + _CC_FILE_EXTENSION header_filename = name + _H_FILE_EXTENSION + + extra_lib_deps = [] + if wrapper_type in _AOT_SUPPORTED_WRAPPERS: + xls_aot_generate( + name = name + "_aot_code_for_wrapper", + src = src, + with_msan = XLS_IS_MSAN_BUILD, + # The XLS AOT compiler does not currently support cross-compilation. + ) + aot_info_target = ":" + name + "_aot_code_for_wrapper" + extra_lib_deps.append(aot_info_target) + else: + # Simplify the xls_ir_jit_wrapper_macro by making sure it always gets an AotCompileInfo + _no_aot_info(name = name + "_empty_aot_info") + aot_info_target = ":" + name + "_empty_aot_info" + # Since this doesn't define any actual AOT code we don't need to add anything to the deps. + xls_ir_jit_wrapper_macro( name = "__" + name + "_xls_ir_jit_wrapper", src = src, jit_wrapper_args = jit_wrapper_args, + aot_info = aot_info_target, wrapper_type = wrapper_type, source_file = source_filename, header_file = header_filename, @@ -338,7 +400,7 @@ def cc_xls_ir_jit_wrapper( name = name, srcs = [":" + source_filename], hdrs = [":" + header_filename], - deps = [ + deps = extra_lib_deps + [ _BASE_JIT_WRAPPER_DEPS[wrapper_type], "@com_google_absl//absl/status", "//xls/common/status:status_macros", diff --git a/xls/jit/BUILD b/xls/jit/BUILD index 42d6f18138..934645b0b9 100644 --- a/xls/jit/BUILD +++ b/xls/jit/BUILD @@ -205,8 +205,9 @@ cc_library( # Allow jit-wrapper users to see this. visibility = ["//xls:xls_users"], deps = [ + ":aot_entrypoint_cc_proto", + ":function_base_jit", ":function_jit", - ":jit_runtime", "//xls/common/status:ret_check", "//xls/common/status:status_macros", "//xls/ir", @@ -215,7 +216,6 @@ cc_library( "//xls/ir:value_view", "//xls/public:ir_parser", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -255,6 +255,7 @@ py_binary( ], visibility = ["//xls:xls_users"], deps = [ + ":aot_entrypoint_py_pb2", requirement("Jinja2"), requirement("MarkupSafe"), "//xls/common:runfiles", diff --git a/xls/jit/function_base_jit_wrapper.h b/xls/jit/function_base_jit_wrapper.h index 7687b51d86..074a229ced 100644 --- a/xls/jit/function_base_jit_wrapper.h +++ b/xls/jit/function_base_jit_wrapper.h @@ -19,14 +19,12 @@ #include #include #include -#include #include #include #include #include #include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -37,8 +35,9 @@ #include "xls/ir/package.h" #include "xls/ir/value.h" #include "xls/ir/value_view.h" +#include "xls/jit/aot_entrypoint.pb.h" +#include "xls/jit/function_base_jit.h" #include "xls/jit/function_jit.h" -#include "xls/jit/jit_runtime.h" #include "xls/public/ir_parser.h" namespace xls { @@ -62,13 +61,22 @@ class BaseFunctionJitWrapper { template static absl::StatusOr> Create( - std::string_view ir_text, std::string_view function_name) + std::string_view ir_text, std::string_view function_name, + absl::Span aot_entrypoint_proto_bin, + JitFunctionType unpacked_entrypoint, JitFunctionType packed_entrypoint) requires(std::is_base_of_v) { XLS_ASSIGN_OR_RETURN(auto package, ParsePackage(ir_text, /*filename=*/std::nullopt)); XLS_ASSIGN_OR_RETURN(auto function, package->GetFunction(function_name)); - XLS_ASSIGN_OR_RETURN(auto jit, FunctionJit::Create(function)); + AotEntrypointProto proto; + // NB We could fallback to real jit here maybe? + XLS_RET_CHECK(proto.ParseFromArray(aot_entrypoint_proto_bin.data(), + aot_entrypoint_proto_bin.size())) + << "Unable to parse aot information."; + XLS_ASSIGN_OR_RETURN( + auto jit, FunctionJit::CreateFromAot( + function, proto, unpacked_entrypoint, packed_entrypoint)); return std::unique_ptr( new RealType(std::move(package), std::move(jit), MatchesImplicitToken(function->GetType()->parameters()))); diff --git a/xls/jit/jit_function_wrapper_cc.tmpl b/xls/jit/jit_function_wrapper_cc.tmpl index 5cfb252710..33fabd088a 100644 --- a/xls/jit/jit_function_wrapper_cc.tmpl +++ b/xls/jit/jit_function_wrapper_cc.tmpl @@ -1,26 +1,63 @@ #include "{{ wrapped.header_filename }}" +#include #include #include #include "xls/common/status/status_macros.h" #include "xls/jit/function_base_jit_wrapper.h" +extern "C" { + +// The actual symbols the AOT generates. +// Unpacked entrypoint +int64_t {{wrapped.aot_entrypoint.function_symbol}}( // NOLINT + const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, + xls::InterpreterEvents* events, xls::InstanceContext* instance_context, + xls::JitRuntime* jit_runtime, int64_t continuation_point); +// Packed entrypoint +int64_t {{wrapped.aot_entrypoint.packed_function_symbol}}( // NOLINT + const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, + xls::InterpreterEvents* events, xls::InstanceContext* instance_context, + xls::JitRuntime* jit_runtime, int64_t continuation_point); +} + namespace {{ wrapped.namespace }} { namespace { + +#ifdef ABSL_HAVE_MEMORY_SANITIZER +static constexpr bool kTargetHasSanitizer = true; +#else +static constexpr bool kTargetHasSanitizer = false; +#endif +static constexpr bool kExternHasSanitizer = {{ "true" if wrapped.aot_entrypoint.has_msan else "false" }}; + +static_assert(kTargetHasSanitizer == kExternHasSanitizer, + "sanitizer states do not match!"); + static constexpr std::string_view kFunctionName = "{{ wrapped.function_name }}"; // Note: This is a plain array as the content can be so large that it exceeds // compiler constexpr limits if attempting to assign to std::string_view static constexpr char kIrText[] = R"original_ir({{wrapped.ir_text}})original_ir"; + +// Bytes of the AOT entrypoint message: +{{ str(wrapped.aot_entrypoint).split("\n") | prefix_each("// ") | join("\n") }} +static constexpr std::array kAotEntrypointProtoBin = { + {{wrapped.aot_entrypoint.SerializeToString() | list | join(", ")}} +}; } // namespace absl::StatusOr> {{ wrapped.class_name }}::Create() { return xls::BaseFunctionJitWrapper::Create<{{wrapped.class_name}}>( - kIrText, kFunctionName); + kIrText, + kFunctionName, + kAotEntrypointProtoBin, + {{wrapped.aot_entrypoint.function_symbol}}, + {{wrapped.aot_entrypoint.packed_function_symbol}}); } absl::StatusOr {{ wrapped.class_name }}::Run( diff --git a/xls/jit/jit_wrapper_generator_main.py b/xls/jit/jit_wrapper_generator_main.py index 6cbd8c08a0..07d3ebc9cf 100644 --- a/xls/jit/jit_wrapper_generator_main.py +++ b/xls/jit/jit_wrapper_generator_main.py @@ -31,6 +31,7 @@ from xls.common import runfiles from xls.ir import xls_ir_interface_pb2 as ir_interface_pb2 from xls.ir import xls_type_pb2 as type_pb2 +from xls.jit import aot_entrypoint_pb2 _FUNCTION_TYPE = flags.DEFINE_string( @@ -97,6 +98,15 @@ default="xls", help="C++ namespace to put the wrapper in.", ) +_AOT_INFO = flags.DEFINE_string( + "aot_info", + required=True, + default=None, + help=( + "Proto file describing the interface of the available AOT'd functions" + " as a AotEntrypointProto. Must be a binary proto." + ), +) @dataclasses.dataclass(frozen=True) @@ -150,6 +160,7 @@ class WrappedIr: header_guard: str header_filename: str namespace: str + aot_entrypoint: Optional[aot_entrypoint_pb2.AotEntrypointProto] # Function params and result. params: Optional[Sequence[XlsNamedValue]] = None result: Optional[XlsNamedValue] = None @@ -306,6 +317,7 @@ def interpret_function_interface( class_name: str, header_guard: str, header_filename: str, + aot_info: aot_entrypoint_pb2.AotEntrypointProto, ) -> WrappedIr: """Fill in a WrappedIr for a function. @@ -315,6 +327,7 @@ def interpret_function_interface( class_name: The class name header_guard: The header-guard string header_filename: The header file name. + aot_info: The aot info for the function. Returns: A wrapped ir for the function. @@ -337,6 +350,7 @@ def interpret_function_interface( namespace=namespace, params=params, result=result, + aot_entrypoint=aot_info, ) @@ -362,6 +376,7 @@ def interpret_proc_interface( incoming_channels=input_channels, outgoing_channels=output_channels, state=state, + aot_entrypoint=None, ) @@ -386,6 +401,7 @@ def interpret_interface( output_name: str, class_name: str, function_name: str, + aot_info: aot_entrypoint_pb2.AotEntrypointProto, ) -> WrappedIr: """Create a wrapped-ir representation of the IR to be rendered to source. @@ -395,6 +411,7 @@ def interpret_interface( output_name: what the file basename we are writing to is. class_name: what the class we are creating is called. function_name: what the IR function we are actually calling is. + aot_info: The aot info for the function. Returns: A WrappedIr ready for rendering. @@ -422,6 +439,7 @@ def interpret_interface( class_name, header_guard, header_filename, + aot_info, ) # Try to find a proc if _FUNCTION_TYPE.value in (None, "PROC"): @@ -476,6 +494,10 @@ def main(argv: Sequence[str]) -> None: raise app.UsageError( "Unknown --function_type. Requires none or FUNCTION or PROC" ) + with open(_AOT_INFO.value, "rb") as aot_info_file: + aot_info = aot_entrypoint_pb2.AotEntrypointProto.FromString( + aot_info_file.read() + ) ir_interface = ir_interface_pb2.PackageInterfaceProto.FromString( subprocess.check_output([ runfiles.get_path("xls/tools/extract_interface_main"), @@ -496,12 +518,21 @@ def main(argv: Sequence[str]) -> None: output_name, class_name, function_name, + aot_info, ) + if ( + aot_info.xls_function_identifier + and wrapped.function_name != aot_info.xls_function_identifier + ): + raise app.UsageError("Aot info is for a different function!") + # Create the JINJA env and add an append filter. env = jinja2.Environment(undefined=jinja2.StrictUndefined) env.filters["append_each"] = lambda vs, suffix: [v + suffix for v in vs] - bindings = {"wrapped": wrapped, "len": len} + env.filters["prefix_each"] = lambda vs, prefix: [prefix + v for v in vs] + env.filters["to_char_ints"] = lambda v: [x for x in v] + bindings = {"wrapped": wrapped, "len": len, "str": str} with open(f"{_OUTPUT_DIR.value}/{output_name}.cc", "wt") as cc_file: cc_template = env.from_string(_CC_TEMPLATES[wrapped.jit_type])