diff --git a/jmh/jmh.bzl b/jmh/jmh.bzl index 58ac7fc3e..ac8fe70f5 100644 --- a/jmh/jmh.bzl +++ b/jmh/jmh.bzl @@ -75,18 +75,22 @@ def jmh_repositories(maven_servers = ["http://central.maven.org/maven2"]): actual = "@io_bazel_rules_scala_org_apache_commons_commons_math3//jar", ) -def _scala_construct_runtime_classpath(deps): - files = [] - [files.append(target[JavaInfo].transitive_runtime_deps) for target in deps if JavaInfo in target] - - return depset(transitive = files) - def _scala_generate_benchmark(ctx): - class_jar = ctx.attr.src.scala.outputs.class_jar - classpath = _scala_construct_runtime_classpath([ctx.attr.src]) + # we use required providers to ensure JavaInfo exists + info = ctx.attr.src[JavaInfo] + # TODO, if we emit more than one jar, which scala_library does not, + # this might fail. We could possibly extend the BenchmarkGenerator + # to accept more than one jar to scan, and then allow multiple labels + # in ctx.attr.src + outs = info.outputs.jars + if len(outs) != 1: + print("expected exactly 1 output jar in: " + ctx.label) + # just try to take the first one and see if that works + class_jar = outs[0].class_jar + classpath = info.transitive_runtime_deps ctx.actions.run( outputs = [ctx.outputs.src_jar, ctx.outputs.resource_jar], - inputs = depset([class_jar], transitive = [classpath]), + inputs = classpath, executable = ctx.executable._generator, arguments = [ctx.attr.generator_type] + [ f.path @@ -99,7 +103,7 @@ def _scala_generate_benchmark(ctx): scala_generate_benchmark = rule( implementation = _scala_generate_benchmark, attrs = { - "src": attr.label(allow_single_file = True, mandatory = True), + "src": attr.label(mandatory = True, providers = [[JavaInfo]]), "generator_type": attr.string( default = "reflection", mandatory = False,