Skip to content
This repository has been archived by the owner on Aug 5, 2024. It is now read-only.

Commit

Permalink
[feature] Resolve Scala SDK independently for each target (#552)
Browse files Browse the repository at this point in the history
* Pre-change refactoring

* Resolve Scala SDK independently for each target

This prepares us for changes in `rules_scala` that will allow customizing the Scala version for each target.
In order to achieve that, we can no longer resolve the Scala SDK globally as the maximal version used. We need to use per-target info.
Scala SDK will be still discovered based on the compiler class path.
Now though we will look into an implicit dependency of each target, the `_scalac`.

This change is backward-compatible, as the mentioned data is already available.
It is also forward-compatible with the anticipated cross-build feature of `rules_scala`.
(see: bazelbuild/rules_scala#1290)

The aspect will produce additional data – namely few compiler classpath jars per Scala target.
Perhaps we could review this change in the future to normalize the data back. Now it's not possible, as the data about alternative configurations (here: of scalac) is discarded in the server.

* Cleanup after

* Output "external-deps-resolve" again

* Code style fix

* Remove `_scalac` again from compile deps
  • Loading branch information
aszady authored Apr 18, 2024
1 parent 45d83ef commit dc332df
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 87 deletions.
2 changes: 0 additions & 2 deletions aspects/core.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ COMPILE_DEPS = [

PRIVATE_COMPILE_DEPS = [
"_java_toolchain",
"_scala_toolchain",
"_scalac",
"_jvm",
"runtime_jdk",
]
Expand Down
33 changes: 11 additions & 22 deletions aspects/rules/scala/scala_info.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,27 @@ def find_scalac_classpath(runfiles):
result.append(file)
return result if found_scala_compiler_jar and len(result) >= 2 else []

def extract_scala_toolchain_info(target, ctx, output_groups, **kwargs):
runfiles = target.default_runfiles.files.to_list()

classpath = find_scalac_classpath(runfiles)

if not classpath:
return None, None

resolve_files = classpath
compiler_classpath = map(file_location, classpath)

if (is_external(target)):
update_sync_output_groups(output_groups, "external-deps-resolve", depset(resolve_files))

scala_toolchain_info = struct(compiler_classpath = compiler_classpath)

return create_proto(target, ctx, scala_toolchain_info, "scala_toolchain_info"), None

def extract_scala_info(target, ctx, output_groups, **kwargs):
kind = ctx.rule.kind
if not kind.startswith("scala_") and not kind.startswith("thrift_"):
return None, None

SCALA_TOOLCHAIN = "@io_bazel_rules_scala//scala:toolchain_type"

scala_info = {}

# check of _scala_toolchain is necessary, because SCALA_TOOLCHAIN will always be present
if hasattr(ctx.rule.attr, "_scala_toolchain"):
common_scalac_opts = ctx.toolchains[SCALA_TOOLCHAIN].scalacopts
if hasattr(ctx.rule.attr, "_scalac"):
scalac = ctx.rule.attr._scalac
compiler_classpath = find_scalac_classpath(scalac.default_runfiles.files.to_list())
if compiler_classpath:
scala_info["compiler_classpath"] = map(file_location, compiler_classpath)
if is_external(scalac):
update_sync_output_groups(output_groups, "external-deps-resolve", depset(compiler_classpath))
else:
common_scalac_opts = []
scalac_opts = common_scalac_opts + getattr(ctx.rule.attr, "scalacopts", [])

scala_info = struct(scalac_opts = scalac_opts)
scala_info["scalac_opts"] = common_scalac_opts + getattr(ctx.rule.attr, "scalacopts", [])

return create_proto(target, ctx, scala_info, "scala_target_info"), None
return create_proto(target, ctx, struct(**scala_info), "scala_target_info"), None
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ enum class Language(private val fileName: String, val ruleNames: List<String>, v
Java("//aspects:rules/java/java_info.bzl", listOf(), listOf("extract_java_toolchain", "extract_java_runtime"), false),
Jvm("//aspects:rules/jvm/jvm_info.bzl", listOf(), listOf("extract_jvm_info"), false),
Python("//aspects:rules/python/python_info.bzl", listOf(), listOf("extract_python_info"), false),
Scala("//aspects:rules/scala/scala_info.bzl", listOf("io_bazel_rules_scala"), listOf("extract_scala_info", "extract_scala_toolchain_info"), false),
Scala("//aspects:rules/scala/scala_info.bzl", listOf("io_bazel_rules_scala"), listOf("extract_scala_info"), false),
Cpp("//aspects:rules/cpp/cpp_info.bzl", listOf("rules_cc"), listOf("extract_cpp_info"), false),
Kotlin("//aspects:rules/kt/kt_info.bzl", listOf("io_bazel_rules_kotlin", "rules_kotlin"), listOf("extract_kotlin_info"), true),
Rust("//aspects:rules/rust/rust_info.bzl", listOf("rules_rust"), listOf("extract_rust_crate_info"), false),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,29 +179,27 @@ class BazelProjectMapper(
private fun calculateScalaLibrariesMapper(targetsToImport: Sequence<TargetInfo>): Map<String, List<Library>> {
val projectLevelScalaSdkLibraries = calculateProjectLevelScalaLibraries()
val scalaTargets = targetsToImport.filter { it.hasScalaTargetInfo() }.map { it.id }
return projectLevelScalaSdkLibraries
?.let { libraries -> scalaTargets.associateWith { libraries } }
.orEmpty()
return scalaTargets.associateWith {
languagePluginsService.scalaLanguagePlugin.scalaSdks[it]?.compilerJars?.mapNotNull {
projectLevelScalaSdkLibraries[it]
}.orEmpty()
}
}

private fun calculateProjectLevelScalaLibraries(): List<Library>? {
val scalaSdkLibrariesJars = getProjectLevelScalaSdkLibrariesJars()
return if (scalaSdkLibrariesJars.isNotEmpty()) {
scalaSdkLibrariesJars.map {
Library(
label = Paths.get(it).name,
outputs = setOf(it),
sources = emptySet(),
dependencies = emptyList()
)
}
} else null
}
private fun calculateProjectLevelScalaLibraries(): Map<URI, Library> =
getProjectLevelScalaSdkLibrariesJars().associateWith {
Library(
label = Paths.get(it).name,
outputs = setOf(it),
sources = emptySet(),
dependencies = emptyList()
)
}

private fun getProjectLevelScalaSdkLibrariesJars(): Set<URI> =
languagePluginsService.scalaLanguagePlugin.scalaSdk
?.compilerJars
?.toSet().orEmpty()
languagePluginsService.scalaLanguagePlugin.scalaSdks.values.toSet().flatMap {
it.compilerJars
}.toSet()

private fun calculateAndroidLibrariesMapper(targetsToImport: Sequence<TargetInfo>): Map<String, List<Library>> =
targetsToImport.mapNotNull { target ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import org.jetbrains.bsp.bazel.info.BspTargetInfo.KotlinTargetInfo
import org.jetbrains.bsp.bazel.info.BspTargetInfo.PythonTargetInfo
import org.jetbrains.bsp.bazel.info.BspTargetInfo.RustCrateInfo
import org.jetbrains.bsp.bazel.info.BspTargetInfo.ScalaTargetInfo
import org.jetbrains.bsp.bazel.info.BspTargetInfo.ScalaToolchainInfo
import org.jetbrains.bsp.bazel.info.BspTargetInfo.TargetInfo
import java.net.URI
import java.nio.charset.StandardCharsets
Expand Down Expand Up @@ -78,12 +77,6 @@ class TargetInfoReader {
targetInfoBuilder.setScalaTargetInfo(info)
}

"scala_toolchain_info" -> {
val builder = readFromFile(path, ScalaToolchainInfo.newBuilder())
val info = builder.build()
targetInfoBuilder.setScalaToolchainInfo(info)
}

"kotlin_target_info" -> {
val builder = readFromFile(path, KotlinTargetInfo.newBuilder())
val info = builder.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,28 @@ class ScalaLanguagePlugin(
private val bazelPathsResolver: BazelPathsResolver
) : LanguagePlugin<ScalaModule>() {

var scalaSdk: ScalaSdk? = null
var scalaSdks: Map<String, ScalaSdk> = emptyMap()

override fun prepareSync(targets: Sequence<BspTargetInfo.TargetInfo>) {
scalaSdk = ScalaSdkResolver(bazelPathsResolver).resolve(targets)
scalaSdks = targets.associateBy(
{ it.id },
ScalaSdkResolver(bazelPathsResolver)::resolveSdk
).filterValuesNotNull()
}

private fun <K, V> Map<K, V?>.filterValuesNotNull(): Map<K, V> =
filterValues { it != null }.mapValues { it.value!! }

override fun resolveModule(targetInfo: BspTargetInfo.TargetInfo): ScalaModule? {
if (!targetInfo.hasScalaTargetInfo()) {
return null
}
val scalaTargetInfo = targetInfo.scalaTargetInfo
val sdk = getScalaSdkOrThrow()
val sdk = scalaSdks[targetInfo.id] ?: return null
val scalacOpts = scalaTargetInfo.scalacOptsList
return ScalaModule(sdk, scalacOpts, javaLanguagePlugin.resolveModule(targetInfo))
}

private fun getScalaSdkOrThrow(): ScalaSdk =
scalaSdk ?: throw RuntimeException("Failed to resolve Scala SDK for project")

override fun dependencySources(
targetInfo: BspTargetInfo.TargetInfo,
dependencyGraph: DependencyGraph
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,16 @@ import org.jetbrains.bsp.bazel.info.BspTargetInfo
import org.jetbrains.bsp.bazel.server.paths.BazelPathsResolver
import java.nio.file.Path
import java.util.regex.Pattern
import kotlin.math.min

class ScalaSdkResolver(private val bazelPathsResolver: BazelPathsResolver) {

fun resolve(targets: Sequence<BspTargetInfo.TargetInfo>): ScalaSdk? =
targets
.mapNotNull(::resolveSdk)
.distinct()
.sortedWith(SCALA_VERSION_COMPARATOR)
.lastOrNull()

private fun resolveSdk(targetInfo: BspTargetInfo.TargetInfo): ScalaSdk? {
if (!targetInfo.hasScalaToolchainInfo()) {
fun resolveSdk(targetInfo: BspTargetInfo.TargetInfo): ScalaSdk? {
if (!targetInfo.hasScalaTargetInfo()) {
return null
}
val scalaToolchain = targetInfo.scalaToolchainInfo
val scalaTarget = targetInfo.scalaTargetInfo
val compilerJars =
bazelPathsResolver.resolvePaths(scalaToolchain.compilerClasspathList).sorted()
bazelPathsResolver.resolvePaths(scalaTarget.compilerClasspathList).sorted()
val maybeVersions = compilerJars.mapNotNull(::extractVersion)
if (maybeVersions.none()) {
return null
Expand All @@ -46,17 +38,6 @@ class ScalaSdkResolver(private val bazelPathsResolver: BazelPathsResolver) {
version.split("\\.".toRegex()).toTypedArray().take(2).joinToString(".")

companion object {
private val SCALA_VERSION_COMPARATOR = Comparator { a: ScalaSdk, b: ScalaSdk ->
val aParts = a.version.split("\\.".toRegex()).toTypedArray()
val bParts = b.version.split("\\.".toRegex()).toTypedArray()
var i = 0
while (i < min(aParts.size, bParts.size)) {
val result = aParts[i].toInt().compareTo(bParts[i].toInt())
if (result != 0) return@Comparator result
i++
}
0
}
private val VERSION_PATTERN =
Pattern.compile("(?:processed_)?scala3?-(?:library|compiler|reflect)(?:_3)?-([.\\d]+)\\.jar")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ message JavaRuntimeInfo {

message ScalaTargetInfo {
repeated string scalac_opts = 1;
}

message ScalaToolchainInfo {
repeated FileLocation compiler_classpath = 1;
repeated FileLocation compiler_classpath = 2;
}

message CppTargetInfo {
Expand Down Expand Up @@ -125,7 +122,6 @@ message TargetInfo {
JavaToolchainInfo java_toolchain_info = 2000;
JavaRuntimeInfo java_runtime_info = 3000;
ScalaTargetInfo scala_target_info = 4000;
ScalaToolchainInfo scala_toolchain_info = 5000;
CppTargetInfo cpp_target_info = 6000;
KotlinTargetInfo kotlin_target_info = 7000;
PythonTargetInfo python_target_info = 8000;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class BazelBspLanguageExtensionsGeneratorTest {
load("//aspects:rules/python/python_info.bzl","extract_python_info")
load("//aspects:rules/cpp/cpp_info.bzl","extract_cpp_info")
load("//aspects:rules/kt/kt_info.bzl","extract_kotlin_info")
load("//aspects:rules/scala/scala_info.bzl","extract_scala_info","extract_scala_toolchain_info")
EXTENSIONS=[extract_java_toolchain,extract_java_runtime,extract_jvm_info,extract_python_info,extract_cpp_info,extract_kotlin_info,extract_scala_info,extract_scala_toolchain_info]
load("//aspects:rules/scala/scala_info.bzl","extract_scala_info")
EXTENSIONS=[extract_java_toolchain,extract_java_runtime,extract_jvm_info,extract_python_info,extract_cpp_info,extract_kotlin_info,extract_scala_info]
TOOLCHAINS=["@bazel_tools//tools/jdk:runtime_toolchain_type","@io_bazel_rules_kotlin//kotlin/internal:kt_toolchain_type","@io_bazel_rules_scala//scala:toolchain_type"]
""".replace(" ", "").replace("\n", "")
private val defaultRuleLanguages =
Expand Down

0 comments on commit dc332df

Please sign in to comment.