diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl index 01ea4f69f932a2..03345ac26902c9 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl @@ -24,73 +24,146 @@ ProtoInfo = _builtins.toplevel.ProtoInfo native_proto_common = _builtins.toplevel.proto_common def _check_srcs_package(target_package, srcs): - """Makes sure the given srcs live in the given package.""" + """Check that .proto files in sources are from the same package. + + This is done to avoid clashes with the generated sources.""" + + #TODO(bazel-team): this does not work with filegroups that contain files that are not in the package for src in srcs: if target_package != src.label.package: fail("Proto source with label '%s' must be in same package as consuming rule." % src.label) -def _join(*path): - return "/".join([p for p in path if p != ""]) - -def _create_proto_info(ctx): - srcs = ctx.files.srcs - deps = [dep[ProtoInfo] for dep in ctx.attr.deps] - exports = [dep[ProtoInfo] for dep in ctx.attr.exports] +def _get_import_prefix(ctx): + """Gets and verifies import_prefix attribute if it is declared.""" import_prefix = ctx.attr.import_prefix if hasattr(ctx.attr, "import_prefix") else "" + if not paths.is_normalized(import_prefix): fail("should be normalized (without uplevel references or '.' path segments)", attr = "import_prefix") + if paths.is_absolute(import_prefix): + fail("should be a relative path", attr = "import_prefix") + + return import_prefix + +def _get_strip_import_prefix(ctx): + """Gets and verifies strip_import_prefix.""" strip_import_prefix = ctx.attr.strip_import_prefix + if not paths.is_normalized(strip_import_prefix): fail("should be normalized (without uplevel references or '.' path segments)", attr = "strip_import_prefix") - if strip_import_prefix.startswith("/"): + + if paths.is_absolute(strip_import_prefix): strip_import_prefix = strip_import_prefix[1:] elif strip_import_prefix != "DO_NOT_STRIP": # Relative to current package strip_import_prefix = _join(ctx.label.package, strip_import_prefix) else: strip_import_prefix = "" - has_generated_sources = False - if ctx.fragments.proto.generated_protos_in_virtual_imports(): - has_generated_sources = any([not src.is_source for src in srcs]) + return strip_import_prefix - direct_sources = [] - if import_prefix != "" or strip_import_prefix != "" or has_generated_sources: - # Use virtual source roots - if paths.is_absolute(import_prefix): - fail("should be a relative path", attr = "import_prefix") +def _proto_library_impl(ctx): + semantics.preprocess(ctx) - virtual_imports = _join("_virtual_imports", ctx.label.name) - if ctx.label.workspace_name == "" or ctx.label.workspace_root.startswith(".."): # siblingRepositoryLayout - proto_path = _join(ctx.genfiles_dir.path, ctx.label.package, virtual_imports) - else: - proto_path = _join(ctx.genfiles_dir.path, ctx.label.workspace_root, ctx.label.package, virtual_imports) + # Verifies attributes. + _check_srcs_package(ctx.label.package, ctx.attr.srcs) + srcs = ctx.files.srcs + deps = [dep[ProtoInfo] for dep in ctx.attr.deps] + exports = [dep[ProtoInfo] for dep in ctx.attr.exports] + import_prefix = _get_import_prefix(ctx) + strip_import_prefix = _get_strip_import_prefix(ctx) - for src in srcs: - if ctx.label.workspace_name == "": - repository_relative_path = src.short_path - else: - repository_relative_path = paths.relativize(src.short_path, "../" + ctx.label.workspace_name) + proto_path, direct_sources = _create_proto_sources(ctx, srcs, import_prefix, strip_import_prefix) + descriptor_set = ctx.actions.declare_file(ctx.label.name + "-descriptor-set.proto.bin") + proto_info = _create_proto_info(ctx, direct_sources, deps, exports, proto_path, descriptor_set) + _write_descriptor_set(ctx, deps, proto_info, descriptor_set) + + # We assume that the proto sources will not have conflicting artifacts + # with the same root relative path + data_runfiles = ctx.runfiles( + files = [proto_info.direct_descriptor_set], + transitive_files = depset(transitive = [proto_info.transitive_sources]), + ) + return [ + proto_info, + DefaultInfo( + files = depset([proto_info.direct_descriptor_set]), + default_runfiles = ctx.runfiles(), # empty + data_runfiles = data_runfiles, + ), + ] - if not repository_relative_path.startswith(strip_import_prefix): - fail(".proto file '%s' is not under the specified strip prefix '%s'" % - (src.short_path, strip_import_prefix)) - import_path = repository_relative_path[len(strip_import_prefix):] +def _create_proto_sources(ctx, srcs, import_prefix, strip_import_prefix): + """Transforms Files in srcs to ProtoSources, optionally symlinking them to _virtual_imports. - virtual_src = ctx.actions.declare_file(_join(virtual_imports, import_prefix, import_path)) - ctx.actions.symlink( - output = virtual_src, - target_file = src, - progress_message = "Symlinking virtual .proto sources for %{label}", - ) - direct_sources.append(native_proto_common.ProtoSource(virtual_src, src, proto_path)) + Returns: + A pair proto_path, directs_sources. + """ + generate_protos_in_virtual_imports = False + if ctx.fragments.proto.generated_protos_in_virtual_imports(): + generate_protos_in_virtual_imports = any([not src.is_source for src in srcs]) + if import_prefix != "" or strip_import_prefix != "" or generate_protos_in_virtual_imports: + # Use virtual source roots + return _symlink_to_virtual_imports(ctx, srcs, import_prefix, strip_import_prefix) else: # No virtual source roots - proto_path = "." + direct_sources = [] for src in srcs: - direct_sources.append(native_proto_common.ProtoSource(src, src, ctx.label.workspace_root + src.root.path)) + if ctx.label.workspace_name == "" or ctx.label.workspace_root.startswith(".."): + # source_root == ''|'bazel-out/foo/k8-fastbuild/bin' + source_root = src.root.path + else: + # source_root == ''|'bazel-out/foo/k8-fastbuild/bin' / 'external/repo' + source_root = _join(src.root.path, ctx.label.workspace_root) + direct_sources.append(native_proto_common.ProtoSource(src, src, source_root)) + + return ctx.label.workspace_root if ctx.label.workspace_root else ".", direct_sources + +def _join(*path): + return "/".join([p for p in path if p != ""]) + +def _symlink_to_virtual_imports(ctx, srcs, import_prefix, strip_import_prefix): + """Symlinks srcs to _virtual_imports. + + Returns: + A pair proto_path, directs_sources. + """ + virtual_imports = _join("_virtual_imports", ctx.label.name) + if ctx.label.workspace_name == "" or ctx.label.workspace_root.startswith(".."): # siblingRepositoryLayout + # Example: `bazel-out/[repo/]target/bin / pkg / _virtual_imports/name` + proto_path = _join(ctx.genfiles_dir.path, ctx.label.package, virtual_imports) + else: + # Example: `bazel-out/target/bin / repo / pkg / _virtual_imports/name` + proto_path = _join(ctx.genfiles_dir.path, ctx.label.workspace_root, ctx.label.package, virtual_imports) + + direct_sources = [] + for src in srcs: + if ctx.label.workspace_name == "": + repository_relative_path = src.short_path + else: + # src.short_path = ../repo/pkg/a.proto + repository_relative_path = paths.relativize(src.short_path, "../" + ctx.label.workspace_name) + + # Remove strip_import_prefix + if not repository_relative_path.startswith(strip_import_prefix): + fail(".proto file '%s' is not under the specified strip prefix '%s'" % + (src.short_path, strip_import_prefix)) + import_path = repository_relative_path[len(strip_import_prefix):] + + # Add import_prefix + virtual_src = ctx.actions.declare_file(_join(virtual_imports, import_prefix, import_path)) + + ctx.actions.symlink( + output = virtual_src, + target_file = src, + progress_message = "Symlinking virtual .proto sources for %{label}", + ) + direct_sources.append(native_proto_common.ProtoSource(virtual_src, src, proto_path)) + return proto_path, direct_sources + +def _create_proto_info(ctx, direct_sources, deps, exports, proto_path, descriptor_set): + """Constructs ProtoInfo.""" # Construct ProtoInfo transitive_proto_sources = depset( @@ -112,9 +185,8 @@ def _create_proto_info(ctx): else: check_deps_sources = depset(transitive = [dep.check_deps_sources for dep in deps]) - direct_descriptor_set = ctx.actions.declare_file(ctx.label.name + "-descriptor-set.proto.bin") transitive_descriptor_sets = depset( - direct = [direct_descriptor_set], + direct = [descriptor_set], transitive = [dep.transitive_descriptor_sets for dep in deps], ) @@ -137,21 +209,19 @@ def _create_proto_info(ctx): transitive_proto_sources, transitive_proto_path, check_deps_sources, - direct_descriptor_set, + descriptor_set, transitive_descriptor_sets, exported_sources, strict_importable_sources, public_import_protos, ) -def _write_descriptor_set(ctx, proto_info): - descriptor_set = proto_info.direct_descriptor_set - +def _write_descriptor_set(ctx, deps, proto_info, descriptor_set): + """Writes descriptor set.""" if proto_info.direct_sources == []: ctx.actions.write(descriptor_set, "") return - deps = [dep[ProtoInfo] for dep in ctx.attr.deps] dependencies_descriptor_sets = depset(transitive = [dep.transitive_descriptor_sets for dep in deps]) args = [] @@ -170,29 +240,6 @@ def _write_descriptor_set(ctx, proto_info): additional_args = args, ) -def _proto_library_impl(ctx): - semantics.preprocess(ctx) - - _check_srcs_package(ctx.label.package, ctx.attr.srcs) - - proto_info = _create_proto_info(ctx) - - _write_descriptor_set(ctx, proto_info) - - data_runfiles = ctx.runfiles( - files = [proto_info.direct_descriptor_set], - transitive_files = depset(transitive = [proto_info.transitive_sources]), - ) - - return [ - proto_info, - DefaultInfo( - files = depset([proto_info.direct_descriptor_set]), - default_runfiles = ctx.runfiles(), # empty - data_runfiles = data_runfiles, - ), - ] - proto_library = rule( _proto_library_impl, attrs = dict({ diff --git a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoLibraryTest.java b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoLibraryTest.java index a293f0be0b0596..a55418ae8ae4ee 100644 --- a/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoLibraryTest.java +++ b/src/test/java/com/google/devtools/build/lib/rules/proto/BazelProtoLibraryTest.java @@ -351,7 +351,8 @@ public void testStripImportPrefixWithDeps() throws Exception { "."); } - private void testExternalRepoWithGeneratedProto(boolean siblingRepoLayout) throws Exception { + private void testExternalRepoWithGeneratedProto( + boolean siblingRepoLayout, boolean useVirtualImports) throws Exception { if (!isThisBazel()) { return; } @@ -361,6 +362,9 @@ private void testExternalRepoWithGeneratedProto(boolean siblingRepoLayout) throw if (siblingRepoLayout) { setBuildLanguageOptions("--experimental_sibling_repository_layout"); } + if (!useVirtualImports) { + useConfiguration("--noincompatible_generated_protos_in_virtual_imports"); + } invalidatePackages(); scratch.file("/foo/WORKSPACE"); @@ -369,7 +373,6 @@ private void testExternalRepoWithGeneratedProto(boolean siblingRepoLayout) throw TestConstants.LOAD_PROTO_LIBRARY, "proto_library(name='x', srcs=['generated.proto'])", "genrule(name='g', srcs=[], outs=['generated.proto'], cmd='')"); - scratch.file( "a/BUILD", TestConstants.LOAD_PROTO_LIBRARY, @@ -380,27 +383,42 @@ private void testExternalRepoWithGeneratedProto(boolean siblingRepoLayout) throw .getGenfilesFragment( siblingRepoLayout ? RepositoryName.create("@foo") : RepositoryName.MAIN) .toString(); + String fooProtoRoot; + if (useVirtualImports) { + fooProtoRoot = + genfiles + (siblingRepoLayout ? "" : "/external/foo") + "/x/_virtual_imports/x"; + } else { + fooProtoRoot = (siblingRepoLayout ? "../foo" : "external/foo"); + } ConfiguredTarget a = getConfiguredTarget("//a:a"); ProtoInfo aInfo = a.get(ProtoInfo.PROVIDER); - assertThat(aInfo.getTransitiveProtoSourceRoots().toList()) - .containsExactly( - ".", genfiles + (siblingRepoLayout ? "" : "/external/foo") + "/x/_virtual_imports/x"); + assertThat(aInfo.getTransitiveProtoSourceRoots().toList()).containsExactly(".", fooProtoRoot); ConfiguredTarget x = getConfiguredTarget("@foo//x:x"); ProtoInfo xInfo = x.get(ProtoInfo.PROVIDER); - assertThat(xInfo.getTransitiveProtoSourceRoots().toList()) - .containsExactly( - genfiles + (siblingRepoLayout ? "" : "/external/foo") + "/x/_virtual_imports/x"); + assertThat(xInfo.getTransitiveProtoSourceRoots().toList()).containsExactly(fooProtoRoot); } @Test public void testExternalRepoWithGeneratedProto_withSubdirRepoLayout() throws Exception { - testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ false); + testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ false, true); } @Test public void test_siblingRepoLayout_externalRepoWithGeneratedProto() throws Exception { - testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ true); + testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ true, true); + } + + @Test + public void testExternalRepoWithGeneratedProto_withSubdirRepoLayoutAndNoVritualImports() + throws Exception { + testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ false, false); + } + + @Test + public void test_siblingRepoLayout_externalRepoWithGeneratedProtoAndNoVritualImports() + throws Exception { + testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ true, false); } @Test