Skip to content

Commit

Permalink
Add custom tag to go_sdk and generate host_compatible_sdk repo
Browse files Browse the repository at this point in the history
  • Loading branch information
ylecornec committed Apr 17, 2023
1 parent 4660427 commit 958b568
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 5 deletions.
60 changes: 59 additions & 1 deletion go/private/extensions.bzl
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
load("//go/private:sdk.bzl", "go_download_sdk_rule", "go_host_sdk_rule", "go_multiple_toolchains")
load("//go/private:sdk.bzl", "detect_host_platform", "go_download_sdk_rule", "go_host_sdk_rule", "go_multiple_toolchains")
load("//go/private:repositories.bzl", "go_rules_dependencies")

# Repository rule used to generate the `host_compatible_sdk` repository.
# Which can be used to bootstrap gazelle instead of relying on go_default_sdk
def host_compatible_toolchain_impl(rctx):
rctx.file("BUILD.bazel", content = "")
loads = []
host_compatible_labels = []
for i, l in enumerate(rctx.attr.toolchains):
if l.endswith(".bzl"):
# bzl file with a variable containing the toolchain label
# This indirection enables the user to generate the file differently depending on the environment.
# In particular the toolchain variable can be equal to None in which case it will be ignored
label = "toolchain_{}".format(i)
loads.append(
"""load("{custom_toolchain}", {label} = "toolchain")""".format(custom_toolchain = l, label = label),
)
else:
label = """Label("{}")""".format(l)
host_compatible_labels.append(label)

rctx.file("defs.bzl", content = """
{loads}
host_compatible_label = {host_compatible_label}
""".format(loads = "\n".join(loads), host_compatible_label = " or ".join(host_compatible_labels)))

host_compatible_toolchain = repository_rule(
implementation = host_compatible_toolchain_impl,
attrs = {"toolchains": attr.string_list(mandatory = True)},
)

_download_tag = tag_class(
attrs = {
"name": attr.string(),
Expand All @@ -20,6 +49,18 @@ _host_tag = tag_class(
},
)

_custom_tag = tag_class(
doc = "Declare a custom toolchain to rules_go. It will then be considered when choosing the toolchain exposed by the `host_compatible_sdk` repository",
attrs = {
"custom_toolchain_bzl_file": attr.label(
doc = """bzl file containing a `toolchain` variable.
This indirection enables the user to generate the file differently depending on the environment.
In particular the toolchain variable can be equal to None in which case this custom toolchain will be ignored.
""",
),
},
)

# This limit can be increased essentially arbitrarily, but doing so will cause a rebuild of all
# targets using any of these toolchains due to the changed repository name.
_MAX_NUM_TOOLCHAINS = 9999
Expand All @@ -33,8 +74,18 @@ def _go_sdk_impl(ctx):
else:
multi_version_module[module.name] = False

host_detected_goos, host_detected_goarch = detect_host_platform(ctx)
toolchains = []

# label of toolchains compatible with the current host
# declared with either download, host or custom tags
host_compatible_toolchains = []

for module in ctx.modules:
for index, custom_tag in enumerate(module.tags.custom):
# Custom toolchains should contain a `toolchain` variable equal to None unless compatible with the host.
host_compatible_toolchains.append(str(custom_tag.custom_toolchain_bzl_file))

for index, download_tag in enumerate(module.tags.download):
# SDKs without an explicit version are fetched even when not selected by toolchain
# resolution. This is acceptable if brought in by the root module, but transitive
Expand Down Expand Up @@ -66,6 +117,10 @@ def _go_sdk_impl(ctx):
version = download_tag.version,
)

if not download_tag.goos or (download_tag.goos == host_detected_goos and download_tag.goarch == host_detected_goarch):
# Among toolchains defined with the download tag, we only consider those consistent with `detect_host_platform`
host_compatible_toolchains.append("@{}//:ROOT".format(name))

toolchains.append(struct(
goos = download_tag.goos,
goarch = download_tag.goarch,
Expand Down Expand Up @@ -99,7 +154,9 @@ def _go_sdk_impl(ctx):
sdk_type = "host",
sdk_version = host_tag.version,
))
host_compatible_toolchains.append("@{}//:ROOT".format(name))

host_compatible_toolchain(name = "host_compatible_sdk", toolchains = host_compatible_toolchains)
if len(toolchains) > _MAX_NUM_TOOLCHAINS:
fail("more than {} go_sdk tags are not supported".format(_MAX_NUM_TOOLCHAINS))

Expand Down Expand Up @@ -150,6 +207,7 @@ go_sdk = module_extension(
tag_classes = {
"download": _download_tag,
"host": _host_tag,
"custom": _custom_tag,
},
)

Expand Down
8 changes: 4 additions & 4 deletions go/private/sdk.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def go_host_sdk(name, register_toolchains = True, **kwargs):

def _go_download_sdk_impl(ctx):
if not ctx.attr.goos and not ctx.attr.goarch:
goos, goarch = _detect_host_platform(ctx)
goos, goarch = detect_host_platform(ctx)
else:
if not ctx.attr.goos:
fail("goarch set but goos not set")
Expand Down Expand Up @@ -173,7 +173,7 @@ def _to_constant_name(s):

def go_toolchains_single_definition(ctx, *, prefix, goos, goarch, sdk_repo, sdk_type, sdk_version):
if not goos and not goarch:
goos, goarch = _detect_host_platform(ctx)
goos, goarch = detect_host_platform(ctx)
else:
if not goos:
fail("goarch set but goos not set")
Expand Down Expand Up @@ -354,7 +354,7 @@ def _go_wrap_sdk_impl(ctx):
if ctx.attr.root_file:
root_file = ctx.attr.root_file
else:
goos, goarch = _detect_host_platform(ctx)
goos, goarch = detect_host_platform(ctx)
platform = goos + "_" + goarch
if platform not in ctx.attr.root_files:
fail("unsupported platform {}".format(platform))
Expand Down Expand Up @@ -466,7 +466,7 @@ def _sdk_build_file(ctx, platform, version, experiments):
content = _define_version_constants(version),
)

def _detect_host_platform(ctx):
def detect_host_platform(ctx):
goos = ctx.os.name
if goos == "mac os x":
goos = "darwin"
Expand Down

0 comments on commit 958b568

Please sign in to comment.