From 894603f1551e2c4ca36d621766055dde94cda2b1 Mon Sep 17 00:00:00 2001 From: cloudhan Date: Sat, 7 Oct 2023 00:47:15 +0800 Subject: [PATCH] Improve RDC (#167) - Remove rdc output from cuda_library It is incorrect in the case of whole archive linking and prevent us from creating a shared library. - Better handling of rdc objects archiving logic - Make transitive rdc cuda_library correct --- cuda/private/cuda_helper.bzl | 19 ++++++- cuda/private/providers.bzl | 23 ++++++-- cuda/private/rules/cuda_library.bzl | 87 ++++++++++++++++++++--------- cuda/private/rules/cuda_objects.bzl | 35 ++++++++---- 4 files changed, 122 insertions(+), 42 deletions(-) diff --git a/cuda/private/cuda_helper.bzl b/cuda/private/cuda_helper.bzl index 13cd9456..b835d806 100644 --- a/cuda/private/cuda_helper.bzl +++ b/cuda/private/cuda_helper.bzl @@ -308,7 +308,18 @@ def _create_common(ctx): transitive_linking_contexts = transitive_linking_contexts, ) -def _create_cuda_info(defines = None, objects = None, rdc_objects = None, pic_objects = None, rdc_pic_objects = None): +def _create_cuda_info( + defines = None, + objects = None, + rdc_objects = None, + pic_objects = None, + rdc_pic_objects = None, + archive_objects = None, + archive_rdc_objects = None, + archive_pic_objects = None, + archive_rdc_pic_objects = None, + dlink_rdc_objects = None, + dlink_rdc_pic_objects = None): """Constructor for `CudaInfo`. See the providers documentation for detail.""" ret = CudaInfo( defines = defines if defines != None else depset([]), @@ -316,6 +327,12 @@ def _create_cuda_info(defines = None, objects = None, rdc_objects = None, pic_ob rdc_objects = rdc_objects if rdc_objects != None else depset([]), pic_objects = pic_objects if pic_objects != None else depset([]), rdc_pic_objects = rdc_pic_objects if rdc_pic_objects != None else depset([]), + archive_objects = archive_objects if archive_objects != None else depset([]), + archive_rdc_objects = archive_rdc_objects if archive_rdc_objects != None else depset([]), + archive_pic_objects = archive_pic_objects if archive_pic_objects != None else depset([]), + archive_rdc_pic_objects = archive_rdc_pic_objects if archive_rdc_pic_objects != None else depset([]), + dlink_rdc_objects = dlink_rdc_objects if dlink_rdc_objects != None else depset([]), + dlink_rdc_pic_objects = dlink_rdc_pic_objects if dlink_rdc_pic_objects != None else depset([]), ) return ret diff --git a/cuda/private/providers.bzl b/cuda/private/providers.bzl index 1c7baad4..cdd49ef2 100644 --- a/cuda/private/providers.bzl +++ b/cuda/private/providers.bzl @@ -59,13 +59,26 @@ CudaInfo = provider( """Provides cuda build artifacts that can be consumed by device linking or linking process. This provider is analog to [CcInfo](https://bazel.build/rules/lib/CcInfo) but only contains necessary information for -linking in a flat structure.""", +linking in a flat structure. Objects are grouped by direct and transitive, because we have no way to split them again +if merged a single depset. +""", fields = { "defines": "A depset of strings. It is used for the compilation during device linking.", - "objects": "A depset of objects.", # but not rdc and pic - "rdc_objects": "A depset of relocatable device code objects.", # but not pic - "pic_objects": "A depset of position indepentent code objects.", # but not rdc - "rdc_pic_objects": "A depset of relocatable device code and position indepentent code objects.", + # direct only: + "objects": "A depset of objects. Direct artifacts of the rule.", # but not rdc and pic + "pic_objects": "A depset of position indepentent code objects. Direct artifacts of the rule.", # but not rdc + "rdc_objects": "A depset of relocatable device code objects. Direct artifacts of the rule.", # but not pic + "rdc_pic_objects": "A depset of relocatable device code and position indepentent code objects. Direct artifacts of the rule.", + # transitive archive only (cuda_objects): + "archive_objects": "A depset of rdc objects. cuda_objects only. Gathered from the transitive dependencies for archiving.", + "archive_pic_objects": "A depset of rdc pic objects. cuda_objects only. Gathered from the transitive dependencies for archiving.", + "archive_rdc_objects": "A depset of rdc objects. cuda_objects only. Gathered from the transitive dependencies for archiving or device linking.", + "archive_rdc_pic_objects": "A depset of rdc pic objects. cuda_objects only. Gathered from the transitive dependencies for archiving or device linking.", + + # transitive dlink only (cuda_library): + # NOTE: ideally, we can use the archived library to do the device linking, but the nvlink is not happy with library with *_dlink.o included + "dlink_rdc_objects": "A depset of rdc objects. cuda_library only. Gathered from the transitive dependencies for device linking.", + "dlink_rdc_pic_objects": "A depset of rdc pic objects. cuda_library only. Gathered from the transitive dependencies for device linking.", }, ) diff --git a/cuda/private/rules/cuda_library.bzl b/cuda/private/rules/cuda_library.bzl index 5b764075..433fea26 100644 --- a/cuda/private/rules/cuda_library.bzl +++ b/cuda/private/rules/cuda_library.bzl @@ -28,24 +28,59 @@ def _cuda_library_impl(ctx): for src in ctx.attr.srcs: src_files.extend(src[DefaultInfo].files.to_list()) - # outputs - objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = use_rdc)) - pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = use_rdc)) - rdc_objects = depset([]) - rdc_pic_objects = depset([]) + # merge deps' direct objects and archive objects as our archive objects + archive_objects = depset(transitive = [dep[CudaInfo].objects for dep in attr.deps if CudaInfo in dep] + + [dep[CudaInfo].archive_objects for dep in attr.deps if CudaInfo in dep]) + archive_pic_objects = depset(transitive = [dep[CudaInfo].pic_objects for dep in attr.deps if CudaInfo in dep] + + [dep[CudaInfo].archive_pic_objects for dep in attr.deps if CudaInfo in dep]) + archive_rdc_objects = depset(transitive = [dep[CudaInfo].rdc_objects for dep in attr.deps if CudaInfo in dep] + + [dep[CudaInfo].archive_rdc_objects for dep in attr.deps if CudaInfo in dep]) + archive_rdc_pic_objects = depset(transitive = [dep[CudaInfo].rdc_pic_objects for dep in attr.deps if CudaInfo in dep] + + [dep[CudaInfo].archive_rdc_pic_objects for dep in attr.deps if CudaInfo in dep]) - # if rdc is enabled for this cuda_library, then we need futher do a pass of device link + # Gather transitive dlink objects that may come from other `cuda_library`s + dlink_rdc_objects = depset(transitive = [dep[CudaInfo].dlink_rdc_objects for dep in attr.deps if CudaInfo in dep]) + dlink_rdc_pic_objects = depset(transitive = [dep[CudaInfo].dlink_rdc_pic_objects for dep in attr.deps if CudaInfo in dep]) + + # direct outputs + objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = False)) if not use_rdc else depset([]) + pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = False)) if not use_rdc else depset([]) + rdc_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = True)) if use_rdc else depset([]) + rdc_pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = True)) if use_rdc else depset([]) + + # if rdc is enabled for this `cuda_library`, then we need to do a pass of device link further. + rdc_dlink_inputs = None + rdc_pic_dlink_inputs = None if use_rdc: - transitive_objects = depset(transitive = [dep[CudaInfo].rdc_objects for dep in attr.deps if CudaInfo in dep]) - transitive_pic_objects = depset(transitive = [dep[CudaInfo].rdc_pic_objects for dep in attr.deps if CudaInfo in dep]) - objects = depset(transitive = [objects, transitive_objects]) - rdc_objects = objects - pic_objects = depset(transitive = [pic_objects, transitive_pic_objects]) - rdc_pic_objects = pic_objects - dlink_object = depset([device_link(ctx, cuda_toolchain, cc_toolchain, objects, common, pic = False, rdc = use_rdc)]) - dlink_pic_object = depset([device_link(ctx, cuda_toolchain, cc_toolchain, pic_objects, common, pic = True, rdc = use_rdc)]) - objects = depset(transitive = [objects, dlink_object]) - pic_objects = depset(transitive = [pic_objects, dlink_pic_object]) + # TODO: Switch to explicit dlink with attr `dlink=True`, then add support dlink with libraries. At the moment, + # all libraries produced by this rule with `rdc=True` will have an _dlink..o archived, and nvlink + # refuses to consume such libraries and ignores them silently. + + # prepare inputs for device_link, take use_rdc=True and non-pic as an example: + # rdc_objects: produce with this rule + # archive_rdc_objects: propagate from other `cuda_objects` + # dlink_rdc_objects: propagate from other `cuda_library`s + rdc_dlink_inputs = depset(transitive = [rdc_objects, archive_rdc_objects, dlink_rdc_objects]) + rdc_pic_dlink_inputs = depset(transitive = [rdc_pic_objects, archive_rdc_pic_objects, dlink_rdc_pic_objects]) + + rdc_dlink_output = depset([device_link(ctx, cuda_toolchain, cc_toolchain, rdc_dlink_inputs, common, pic = False, rdc = True)]) + rdc_pic_dlink_output = depset([device_link(ctx, cuda_toolchain, cc_toolchain, rdc_pic_dlink_inputs, common, pic = True, rdc = True)]) + + # update the **direct** outputs + rdc_objects = depset(transitive = [rdc_objects, rdc_dlink_output]) + rdc_pic_objects = depset(transitive = [rdc_pic_objects, rdc_pic_dlink_output]) + + # objects to archive: objects directly outputed by this rule and all objects transitively from deps, + # take use_rdc=True and non-pic as an example: + # rdc_objects: produce with this rule, thus it must be archived in the library produced by this rule + # archive_rdc_objects: propagate from other `cuda_objects`, so this rule is in charge of archiving them + # dlink_rdc_objects is NOT included! + if not use_rdc: + archive_content = depset(transitive = [objects, archive_objects]) + pic_archive_content = depset(transitive = [pic_objects, archive_pic_objects]) + else: + archive_content = depset(transitive = [rdc_objects, archive_rdc_objects]) + pic_archive_content = depset(transitive = [rdc_pic_objects, archive_rdc_pic_objects]) compilation_ctx = cc_common.create_compilation_context( headers = common.headers, @@ -67,7 +102,7 @@ def _cuda_library_impl(ctx): actions = ctx.actions, feature_configuration = cc_feature_config, cc_toolchain = cc_toolchain, - compilation_outputs = cc_common.create_compilation_outputs(objects = objects, pic_objects = pic_objects), + compilation_outputs = cc_common.create_compilation_outputs(objects = archive_content, pic_objects = pic_archive_content), user_link_flags = common.host_link_flags, alwayslink = attr.alwayslink, linking_contexts = common.transitive_linking_contexts, @@ -82,7 +117,10 @@ def _cuda_library_impl(ctx): libs = [] if lib == None else [lib] pic_libs = [] if pic_lib == None else [pic_lib] - cc_info = cc_common.merge_cc_infos(direct_cc_infos = [CcInfo(compilation_context = compilation_ctx, linking_context = linking_ctx)], cc_infos = [common.transitive_cc_info]) + cc_info = cc_common.merge_cc_infos( + direct_cc_infos = [CcInfo(compilation_context = compilation_ctx, linking_context = linking_ctx)], + cc_infos = [common.transitive_cc_info], + ) return [ DefaultInfo(files = depset(libs + pic_libs)), @@ -100,10 +138,9 @@ def _cuda_library_impl(ctx): ), cuda_helper.create_cuda_info( defines = depset(common.defines), - objects = objects, - pic_objects = pic_objects, - rdc_objects = rdc_objects, - rdc_pic_objects = rdc_pic_objects, + # all objects from cuda_objects should be properly archived, thus, the transitivity of archive is cut off here. + dlink_rdc_objects = rdc_dlink_inputs, + dlink_rdc_pic_objects = rdc_pic_dlink_inputs, ), ] @@ -118,10 +155,8 @@ cuda_library = rule( "alwayslink": attr.bool(default = False), "rdc": attr.bool( default = False, - doc = ("Whether to produce and consume relocateable device code. " + - "Transitive deps that contain device code must all either be cuda_objects or cuda_library(rdc = True). " + - "If False, all device code must be in the same translation unit. May have performance implications. " + - "See https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#using-separate-compilation-in-cuda."), + doc = ("Whether to perform device linking for relocateable device code. " + + "Transitive deps that contain device code must all either be cuda_objects or cuda_library(rdc = True)."), ), "includes": attr.string_list(doc = "List of include dirs to be added to the compile line."), "host_copts": attr.string_list(doc = "Add these options to the CUDA host compilation command."), diff --git a/cuda/private/rules/cuda_objects.bzl b/cuda/private/rules/cuda_objects.bzl index c84b48b3..adc2fa6f 100644 --- a/cuda/private/rules/cuda_objects.bzl +++ b/cuda/private/rules/cuda_objects.bzl @@ -19,16 +19,21 @@ def _cuda_objects_impl(ctx): for src in ctx.attr.srcs: src_files.extend(src[DefaultInfo].files.to_list()) - transitive_objects = depset(transitive = [dep[CudaInfo].objects for dep in attr.deps if CudaInfo in dep]) - transitive_rdc_objects = depset(transitive = [dep[CudaInfo].rdc_objects for dep in attr.deps if CudaInfo in dep]) - transitive_pic_objects = depset(transitive = [dep[CudaInfo].pic_objects for dep in attr.deps if CudaInfo in dep]) - transitive_rdc_pic_objects = depset(transitive = [dep[CudaInfo].rdc_pic_objects for dep in attr.deps if CudaInfo in dep]) + # merge deps' direct objects and archive objects as our archive objects + archive_objects = depset(transitive = [dep[CudaInfo].objects for dep in attr.deps if CudaInfo in dep] + + [dep[CudaInfo].archive_objects for dep in attr.deps if CudaInfo in dep]) + archive_pic_objects = depset(transitive = [dep[CudaInfo].pic_objects for dep in attr.deps if CudaInfo in dep] + + [dep[CudaInfo].archive_pic_objects for dep in attr.deps if CudaInfo in dep]) + archive_rdc_objects = depset(transitive = [dep[CudaInfo].rdc_objects for dep in attr.deps if CudaInfo in dep] + + [dep[CudaInfo].archive_rdc_objects for dep in attr.deps if CudaInfo in dep]) + archive_rdc_pic_objects = depset(transitive = [dep[CudaInfo].rdc_pic_objects for dep in attr.deps if CudaInfo in dep] + + [dep[CudaInfo].archive_rdc_pic_objects for dep in attr.deps if CudaInfo in dep]) - # outputs - objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = False), transitive = [transitive_objects]) - rdc_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = True), transitive = [transitive_rdc_objects]) - pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = False), transitive = [transitive_pic_objects]) - rdc_pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = True), transitive = [transitive_rdc_pic_objects]) + # direct outputs + objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = False)) + pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = False)) + rdc_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = True)) + rdc_pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = True)) compilation_ctx = cc_common.create_compilation_context( headers = common.headers, @@ -39,6 +44,11 @@ def _cuda_objects_impl(ctx): local_defines = depset(common.host_local_defines), ) + cc_info = cc_common.merge_cc_infos( + direct_cc_infos = [CcInfo(compilation_context = compilation_ctx)], + cc_infos = [common.transitive_cc_info], + ) + return [ # default output is only enabled for rdc_objects, otherwise, when you build with # @@ -60,7 +70,8 @@ def _cuda_objects_impl(ctx): rdc_pic_objects = rdc_pic_objects, ), CcInfo( - compilation_context = compilation_ctx, + compilation_context = cc_info.compilation_context, + linking_context = cc_info.linking_context, ), cuda_helper.create_cuda_info( defines = depset(common.defines), @@ -68,6 +79,10 @@ def _cuda_objects_impl(ctx): pic_objects = pic_objects, rdc_objects = rdc_objects, rdc_pic_objects = rdc_pic_objects, + archive_objects = archive_objects, + archive_pic_objects = archive_pic_objects, + archive_rdc_objects = archive_rdc_objects, + archive_rdc_pic_objects = archive_rdc_pic_objects, ), ]