diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl index 08207249455ce6..056e4ee61951ae 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl @@ -350,6 +350,34 @@ def _declare_generated_files( return outputs +def _find_toolchain(ctx, legacy_attr, toolchain_type): + if _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(): + toolchain = ctx.toolchains[toolchain_type] + if not toolchain: + fail("No toolchains registered for '%s'." % toolchain_type) + return toolchain.proto + else: + return getattr(ctx.attr, legacy_attr)[ProtoLangToolchainInfo] + +def _use_toolchain(toolchain_type): + if _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(): + return [_builtins.toplevel.config_common.toolchain_type(toolchain_type, mandatory = False)] + else: + return [] + +def _if_legacy_toolchain(legacy_attr_dict): + if _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(): + return {} + else: + return legacy_attr_dict + +toolchains = struct( + use_toolchain = _use_toolchain, + find_toolchain = _find_toolchain, + if_legacy_toolchain = _if_legacy_toolchain, + INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(), +) + proto_common_do_not_use = struct( compile = _compile, declare_generated_files = _declare_generated_files, diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl index 72eaafda28e0af..3e609efc20be39 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl @@ -14,8 +14,8 @@ """A Starlark implementation of the proto_lang_toolchain rule.""" +load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", "toolchains") load(":common/proto/proto_info.bzl", "ProtoInfo") -load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo") load(":common/proto/proto_semantics.bzl", "semantics") PackageSpecificationInfo = _builtins.toplevel.PackageSpecificationInfo @@ -32,9 +32,9 @@ def _rule_impl(ctx): if ctx.attr.plugin != None: plugin = ctx.attr.plugin[DefaultInfo].files_to_run - if semantics.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: - proto_compiler = ctx.toolchains[semantics.PROTO_TOOLCHAIN_TYPE].proto.proto_compiler - protoc_opts = ctx.toolchains[semantics.PROTO_TOOLCHAIN_TYPE].proto.protoc_opts + if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: + proto_compiler = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.proto_compiler + protoc_opts = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.protoc_opts else: proto_compiler = ctx.attr._proto_compiler.files_to_run protoc_opts = ctx.fragments.proto.experimental_protoc_opts @@ -81,7 +81,7 @@ proto_lang_toolchain = rule( cfg = "exec", providers = [PackageSpecificationInfo], ), - } | ({} if semantics.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else { + } | ({} if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else { "_proto_compiler": attr.label( cfg = "exec", executable = True, @@ -91,5 +91,5 @@ proto_lang_toolchain = rule( }), provides = [ProtoLangToolchainInfo], fragments = ["proto"], - toolchains = semantics.PROTO_TOOLCHAIN, # Used to obtain protoc + toolchains = toolchains.use_toolchain(semantics.PROTO_TOOLCHAIN), # Used to obtain protoc ) diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl index 2d8c805afb6ce2..f14fdee4c36db4 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl @@ -17,7 +17,7 @@ Definition of proto_library rule. """ load(":common/paths.bzl", "paths") -load(":common/proto/proto_common.bzl", proto_common = "proto_common_do_not_use") +load(":common/proto/proto_common.bzl", "toolchains", proto_common = "proto_common_do_not_use") load(":common/proto/proto_info.bzl", "ProtoInfo") load(":common/proto/proto_semantics.bzl", "semantics") @@ -208,8 +208,8 @@ def _write_descriptor_set(ctx, proto_info, deps, exports, descriptor_set): map_each = proto_common.get_import_path, join_with = ":", ) - if semantics.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: - toolchain = ctx.toolchains[semantics.PROTO_TOOLCHAIN_TYPE] + if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: + toolchain = ctx.toolchains[semantics.PROTO_TOOLCHAIN] if not toolchain: fail("Protocol compiler toolchain could not be resolved.") proto_lang_toolchain_info = toolchain.proto @@ -256,7 +256,7 @@ proto_library = rule( flags = ["SKIP_CONSTRAINTS_OVERRIDE"], ), "licenses": attr.license() if hasattr(attr, "license") else attr.string_list(), - } | ({} if semantics.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else { + } | toolchains.if_legacy_toolchain({ "_proto_compiler": attr.label( cfg = "exec", executable = True, @@ -267,5 +267,5 @@ proto_library = rule( fragments = ["proto"] + semantics.EXTRA_FRAGMENTS, provides = [ProtoInfo], exec_groups = semantics.EXEC_GROUPS, - toolchains = semantics.PROTO_TOOLCHAIN, + toolchains = toolchains.use_toolchain(semantics.PROTO_TOOLCHAIN), ) diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_semantics.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_semantics.bzl index a9c931f22a279d..d34f5412f99354 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_semantics.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_semantics.bzl @@ -19,18 +19,8 @@ Proto Semantics def _preprocess(ctx): pass -_PROTO_TOOLCHAIN_TYPE = "@rules_proto//proto:toolchain_type" - -def _get_proto_toolchain(): - if _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(): - return [_builtins.toplevel.config_common.toolchain_type(_PROTO_TOOLCHAIN_TYPE, mandatory = False)] - else: - return [] - semantics = struct( - PROTO_TOOLCHAIN_TYPE = _PROTO_TOOLCHAIN_TYPE, - PROTO_TOOLCHAIN = _get_proto_toolchain(), - INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(), + PROTO_TOOLCHAIN = "@rules_proto//proto:toolchain_type", PROTO_COMPILER_LABEL = "@bazel_tools//tools/proto:protoc", EXTRA_ATTRIBUTES = { "import_prefix": attr.string(),