Skip to content

Commit

Permalink
Use AOT code for Function type jit-wrapper
Browse files Browse the repository at this point in the history
This speeds up startup for jit-wrapper using tools substantially (at the cost of making the name a bit of a lie).

Bug: #1403
Bug: #1422
PiperOrigin-RevId: 635926105
  • Loading branch information
allight authored and copybara-github committed May 21, 2024
1 parent aa905d9 commit b8e5b2e
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 13 deletions.
2 changes: 2 additions & 0 deletions xls/build_rules/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ bzl_library(
deps = [
":xls_common_rules_bzl",
":xls_config_rules_bzl",
":xls_internal_aot_rules_bzl",
":xls_internal_build_defs_bzl",
":xls_ir_rules_bzl",
":xls_providers_bzl",
":xls_toolchains_bzl",
Expand Down
17 changes: 17 additions & 0 deletions xls/build_rules/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ load(
"xls_ir_opt_ir",
"xls_ir_verilog",
)
load("//xls/build_rules:xls_internal_aot_rules.bzl", "xls_aot_generate")
load(
"//xls/build_rules:xls_rules_build_defs.bzl",
"FUNCTION_WRAPPER_TYPE",
Expand Down Expand Up @@ -582,12 +583,19 @@ xls_dslx_verilog(
deps = [":simple_example_4_dslx"],
)

xls_aot_generate(
name = "simple_example_5_one_stage_jit_wrapper_aot_gen",
src = ":simple_example_5_one_stage.opt.ir",
with_msan = False,
)

# The xls_ir_jit_wrapper rule using an optimized IR file as input.
# The header and source files are outputs of the rule, they can be referenced
# by other rules.
xls_ir_jit_wrapper_macro(
name = "simple_example_5_one_stage_jit_wrapper",
src = ":simple_example_5_one_stage.opt.ir",
aot_info = ":simple_example_5_one_stage_jit_wrapper_aot_gen",
header_file = "simple_example_5_one_stage_jit_wrapper.h",
jit_wrapper_args = {
"namespace": "not_xls::test",
Expand Down Expand Up @@ -698,6 +706,7 @@ xls_dslx_verilog_native_rule(
xls_ir_jit_wrapper_native_rule(
name = "simple_example_5_one_stage_jit_wrapper_native_rule",
src = ":simple_example_5_one_stage.opt.ir",
aot_info = ":simple_example_5_one_stage_jit_wrapper_aot_gen",
header_file = "simple_example_5_one_stage_jit_wrapper_native_rule.h",
jit_wrapper_args = {
"namespace": "not_xls::test",
Expand Down Expand Up @@ -779,9 +788,16 @@ xls_dslx_verilog_native_rule(
verilog_file = "xls_dslx_verilog_native_rule.sv",
)

xls_aot_generate(
name = "user_defined_output_filename_jit_wrapper_aot_gen",
src = ":xls_ir_opt_ir.opt.ir",
with_msan = False,
)

xls_ir_jit_wrapper_macro(
name = "user_defined_output_filename_jit_wrapper",
src = ":xls_ir_opt_ir.opt.ir",
aot_info = ":user_defined_output_filename_jit_wrapper_aot_gen",
header_file = "xls_ir_jit_wrapper.h",
jit_wrapper_args = {
"namespace": "not_xls::test",
Expand All @@ -793,6 +809,7 @@ xls_ir_jit_wrapper_macro(
xls_ir_jit_wrapper_native_rule(
name = "user_defined_output_filename_native_rule_jit_wrapper",
src = ":xls_ir_opt_ir.opt.ir",
aot_info = ":user_defined_output_filename_jit_wrapper_aot_gen",
header_file = "xls_ir_jit_wrapper_native_rule.h",
jit_wrapper_args = {
"namespace": "not_xls::test",
Expand Down
2 changes: 1 addition & 1 deletion xls/build_rules/xls_internal_aot_rules.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ _PROTO_FILE_EXTENSION = ".pb"

_OBJ_FILE_EXTENSION = ".o"

visibility(["//xls/build_rules", "//xls/jit"])
visibility(["//xls/build_rules/...", "//xls/jit"])

_xls_aot_files_attrs = {
"with_msan": attr.bool(
Expand Down
68 changes: 65 additions & 3 deletions xls/build_rules/xls_jit_wrapper_rules.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,16 @@ load(
"CONFIG",
"enable_generated_file_wrapper",
)
load(
"//xls/build_rules:xls_internal_aot_rules.bzl",
"xls_aot_generate",
)
load(
"//xls/build_rules:xls_internal_build_defs.bzl",
"XLS_IS_MSAN_BUILD",
)
load("//xls/build_rules:xls_ir_rules.bzl", "xls_ir_common_attrs")
load("//xls/build_rules:xls_providers.bzl", "JitWrapperInfo")
load("//xls/build_rules:xls_providers.bzl", "AotCompileInfo", "JitWrapperInfo")
load(
"//xls/build_rules:xls_toolchains.bzl",
"xls_toolchain_attrs",
Expand Down Expand Up @@ -63,6 +71,10 @@ _xls_ir_jit_wrapper_attrs = {
doc = "type of function_base we are wrapping.",
mandatory = True,
),
"aot_info": attr.label(
doc = "The target which contains information about available AOT code.",
mandatory = True,
),
}

def _xls_ir_jit_wrapper_impl(ctx):
Expand Down Expand Up @@ -147,11 +159,15 @@ def _xls_ir_jit_wrapper_impl(ctx):
# function_type
jit_wrapper_flags.add("--function_type", ctx.attr.wrapper_type)

# Aot information
aot_info_file = ctx.attr.aot_info[AotCompileInfo].proto_file
jit_wrapper_flags.add("--aot_info", aot_info_file.path)

my_generated_files = [cc_file, h_file]

# Get runfiles
jit_wrapper_tool_runfiles = ctx.attr._xls_jit_wrapper_tool[DefaultInfo].default_runfiles
runfiles = get_runfiles_for_xls(ctx, [jit_wrapper_tool_runfiles], [src])
runfiles = get_runfiles_for_xls(ctx, [jit_wrapper_tool_runfiles], [src, aot_info_file])

ctx.actions.run(
outputs = my_generated_files,
Expand Down Expand Up @@ -211,12 +227,31 @@ Examples:
),
)

def _no_aot_info_impl(ctx):
"""Helper rule to create an empty AotInfo proto."""
file = ctx.actions.declare_file(ctx.attr.name + ".pb")
ctx.actions.write(file, "", is_executable = False)
return [
DefaultInfo(files = depset([file])),
AotCompileInfo(object_file = None, proto_file = file),
]

_no_aot_info = rule(
doc = """Internal only utility rule to generate an empty AotCompileInfo proto file.
This can be used with function types that don't yet support AOT.
""",
implementation = _no_aot_info_impl,
attrs = {},
)

def xls_ir_jit_wrapper_macro(
name,
src,
source_file,
header_file,
wrapper_type,
aot_info,
jit_wrapper_args = {},
enable_generated_file = True,
enable_presubmit_generated_file = False,
Expand All @@ -235,6 +270,8 @@ def xls_ir_jit_wrapper_macro(
header_file: The generated header file. See 'header_file' attribute from
the 'xls_ir_jit_wrapper' rule.
wrapper_type: What sort of function base are we wrapping.
aot_info: AotCompileInfo generating label with information about the AOT
code that is available.
jit_wrapper_args: Arguments of the JIT tool. See 'jit_wrapper_args'
attribute from the 'xls_ir_jit_wrapper' rule.
enable_generated_file: See 'enable_generated_file' from
Expand All @@ -250,6 +287,7 @@ def xls_ir_jit_wrapper_macro(
string_type_check("source_file", source_file)
string_type_check("header_file", header_file)
string_type_check("wrapper_type", wrapper_type)
string_type_check("aot_info", aot_info)
dictionary_type_check("jit_wrapper_args", jit_wrapper_args)
bool_type_check("enable_generated_file", enable_generated_file)
bool_type_check("enable_presubmit_generated_file", enable_presubmit_generated_file)
Expand All @@ -259,6 +297,7 @@ def xls_ir_jit_wrapper_macro(
src = src,
source_file = source_file,
header_file = header_file,
aot_info = aot_info,
jit_wrapper_args = jit_wrapper_args,
wrapper_type = wrapper_type,
outs = [source_file, header_file],
Expand All @@ -281,6 +320,8 @@ _BASE_JIT_WRAPPER_DEPS = {
PROC_WRAPPER_TYPE: "//xls/jit:proc_base_jit_wrapper",
}

_AOT_SUPPORTED_WRAPPERS = [FUNCTION_WRAPPER_TYPE]

def cc_xls_ir_jit_wrapper(
name,
src,
Expand All @@ -305,6 +346,9 @@ def cc_xls_ir_jit_wrapper(
for compatibility.
**kwargs: Keyword arguments. Named arguments.
"""

# TODO(allight): We should add top as an argument here. With the new
# jit-wrapper architecture it would be simple to support.
dictionary_type_check("jit_wrapper_args", jit_wrapper_args)
string_type_check("src", src)

Expand All @@ -324,10 +368,28 @@ def cc_xls_ir_jit_wrapper(

source_filename = name + _CC_FILE_EXTENSION
header_filename = name + _H_FILE_EXTENSION

extra_lib_deps = []
if wrapper_type in _AOT_SUPPORTED_WRAPPERS:
xls_aot_generate(
name = name + "_aot_code_for_wrapper",
src = src,
with_msan = XLS_IS_MSAN_BUILD,
# The XLS AOT compiler does not currently support cross-compilation.
)
aot_info_target = ":" + name + "_aot_code_for_wrapper"
extra_lib_deps.append(aot_info_target)
else:
# Simplify the xls_ir_jit_wrapper_macro by making sure it always gets an AotCompileInfo
_no_aot_info(name = name + "_empty_aot_info")
aot_info_target = ":" + name + "_empty_aot_info"
# Since this doesn't define any actual AOT code we don't need to add anything to the deps.

xls_ir_jit_wrapper_macro(
name = "__" + name + "_xls_ir_jit_wrapper",
src = src,
jit_wrapper_args = jit_wrapper_args,
aot_info = aot_info_target,
wrapper_type = wrapper_type,
source_file = source_filename,
header_file = header_filename,
Expand All @@ -338,7 +400,7 @@ def cc_xls_ir_jit_wrapper(
name = name,
srcs = [":" + source_filename],
hdrs = [":" + header_filename],
deps = [
deps = extra_lib_deps + [
_BASE_JIT_WRAPPER_DEPS[wrapper_type],
"@com_google_absl//absl/status",
"//xls/common/status:status_macros",
Expand Down
5 changes: 3 additions & 2 deletions xls/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ cc_library(
# Allow jit-wrapper users to see this.
visibility = ["//xls:xls_users"],
deps = [
":aot_entrypoint_cc_proto",
":function_base_jit",
":function_jit",
":jit_runtime",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/ir",
Expand All @@ -215,7 +216,6 @@ cc_library(
"//xls/ir:value_view",
"//xls/public:ir_parser",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -255,6 +255,7 @@ py_binary(
],
visibility = ["//xls:xls_users"],
deps = [
":aot_entrypoint_py_pb2",
requirement("Jinja2"),
requirement("MarkupSafe"),
"//xls/common:runfiles",
Expand Down
18 changes: 13 additions & 5 deletions xls/jit/function_base_jit_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
Expand All @@ -37,8 +35,9 @@
#include "xls/ir/package.h"
#include "xls/ir/value.h"
#include "xls/ir/value_view.h"
#include "xls/jit/aot_entrypoint.pb.h"
#include "xls/jit/function_base_jit.h"
#include "xls/jit/function_jit.h"
#include "xls/jit/jit_runtime.h"
#include "xls/public/ir_parser.h"

namespace xls {
Expand All @@ -62,13 +61,22 @@ class BaseFunctionJitWrapper {

template <typename RealType>
static absl::StatusOr<std::unique_ptr<RealType>> Create(
std::string_view ir_text, std::string_view function_name)
std::string_view ir_text, std::string_view function_name,
absl::Span<uint8_t const> aot_entrypoint_proto_bin,
JitFunctionType unpacked_entrypoint, JitFunctionType packed_entrypoint)
requires(std::is_base_of_v<BaseFunctionJitWrapper, RealType>)
{
XLS_ASSIGN_OR_RETURN(auto package,
ParsePackage(ir_text, /*filename=*/std::nullopt));
XLS_ASSIGN_OR_RETURN(auto function, package->GetFunction(function_name));
XLS_ASSIGN_OR_RETURN(auto jit, FunctionJit::Create(function));
AotEntrypointProto proto;
// NB We could fallback to real jit here maybe?
XLS_RET_CHECK(proto.ParseFromArray(aot_entrypoint_proto_bin.data(),
aot_entrypoint_proto_bin.size()))
<< "Unable to parse aot information.";
XLS_ASSIGN_OR_RETURN(
auto jit, FunctionJit::CreateFromAot(
function, proto, unpacked_entrypoint, packed_entrypoint));
return std::unique_ptr<RealType>(
new RealType(std::move(package), std::move(jit),
MatchesImplicitToken(function->GetType()->parameters())));
Expand Down
39 changes: 38 additions & 1 deletion xls/jit/jit_function_wrapper_cc.tmpl
Original file line number Diff line number Diff line change
@@ -1,26 +1,63 @@
#include "{{ wrapped.header_filename }}"

#include <cstdint>
#include <array>
#include <string_view>

#include "xls/common/status/status_macros.h"
#include "xls/jit/function_base_jit_wrapper.h"

extern "C" {

// The actual symbols the AOT generates.
// Unpacked entrypoint
int64_t {{wrapped.aot_entrypoint.function_symbol}}( // NOLINT
const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer,
xls::InterpreterEvents* events, xls::InstanceContext* instance_context,
xls::JitRuntime* jit_runtime, int64_t continuation_point);
// Packed entrypoint
int64_t {{wrapped.aot_entrypoint.packed_function_symbol}}( // NOLINT
const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer,
xls::InterpreterEvents* events, xls::InstanceContext* instance_context,
xls::JitRuntime* jit_runtime, int64_t continuation_point);
}

namespace {{ wrapped.namespace }} {

namespace {

#ifdef ABSL_HAVE_MEMORY_SANITIZER
static constexpr bool kTargetHasSanitizer = true;
#else
static constexpr bool kTargetHasSanitizer = false;
#endif
static constexpr bool kExternHasSanitizer = {{ "true" if wrapped.aot_entrypoint.has_msan else "false" }};

static_assert(kTargetHasSanitizer == kExternHasSanitizer,
"sanitizer states do not match!");

static constexpr std::string_view kFunctionName = "{{ wrapped.function_name }}";

// Note: This is a plain array as the content can be so large that it exceeds
// compiler constexpr limits if attempting to assign to std::string_view
static constexpr char kIrText[] =
R"original_ir({{wrapped.ir_text}})original_ir";

// Bytes of the AOT entrypoint message:
{{ str(wrapped.aot_entrypoint).split("\n") | prefix_each("// ") | join("\n") }}
static constexpr std::array<uint8_t, {{len(wrapped.aot_entrypoint.SerializeToString())}}> kAotEntrypointProtoBin = {
{{wrapped.aot_entrypoint.SerializeToString() | list | join(", ")}}
};
} // namespace

absl::StatusOr<std::unique_ptr<{{ wrapped.class_name }}>>
{{ wrapped.class_name }}::Create() {
return xls::BaseFunctionJitWrapper::Create<{{wrapped.class_name}}>(
kIrText, kFunctionName);
kIrText,
kFunctionName,
kAotEntrypointProtoBin,
{{wrapped.aot_entrypoint.function_symbol}},
{{wrapped.aot_entrypoint.packed_function_symbol}});
}

absl::StatusOr<xls::Value> {{ wrapped.class_name }}::Run(
Expand Down
Loading

0 comments on commit b8e5b2e

Please sign in to comment.