Skip to content

Commit

Permalink
PR #2843: Switch to use compiler param file for msvc
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#2843

Merging this change closes #2843

PiperOrigin-RevId: 536435983
  • Loading branch information
cloudhan authored and copybara-github committed May 30, 2023
1 parent e25e8d9 commit 29e0ab8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 7 deletions.
7 changes: 1 addition & 6 deletions third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,7 @@ def _features(cpu, compiler, ctx):
]
elif cpu == "x64_windows":
return [
feature(name = "compiler_param_file"),
feature(name = "no_legacy_features"),
feature(
name = "common_flags",
Expand All @@ -623,12 +624,6 @@ def _features(cpu, compiler, ctx):
flag_set(
actions = all_compile_actions(),
flag_groups = [
flag_group(
flags = [
"-B",
"external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py",
],
),
_nologo(),
flag_group(
flags = [
Expand Down
33 changes: 33 additions & 0 deletions third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,29 @@ def InvokeNvcc(argv, log=False):
proc.wait()
return proc.returncode

def ExpandParamsFileForArgv():
new_argv = []
for arg in sys.argv:
if arg.startswith("@"):
with open(arg.strip("@")) as f:
new_argv.extend([l.strip() for l in f.readlines()])
else:
new_argv.append(arg)

sys.argv = new_argv

def ProcessFlagForCommandFile(flag):
if flag.startswith("/D") or flag.startswith("-D"):
# We need to re-escape /DFOO="BAR" as /DFOO=\"BAR\", so that we get
# `#define FOO "BAR"` after expansion as a string literal define
if flag.endswith('"') and not flag.endswith('\\"'):
flag = '\\"'.join(flag.split('"', 1))
flag = '\\"'.join(flag.rsplit('"', 1))
return flag
return flag

def main():
ExpandParamsFileForArgv()
parser = ArgumentParser()
parser.add_argument('-x', nargs=1)
parser.add_argument('--cuda_log', action='store_true')
Expand All @@ -212,7 +234,18 @@ def main():
cpu_compiler_flags = [flag for flag in sys.argv[1:]
if not flag.startswith(('--cuda_log'))
and not flag.startswith(('-nvcc_options'))]
output = [flag for flag in cpu_compiler_flags if flag.startswith("/Fo")]

# Store command line options in a file to avoid hitting the character limit.
if len(output) == 1:
commandfile_path = output[0][3:] + ".msvc_params"
commandfile = open(commandfile_path, "w")
cpu_compiler_flags = [ProcessFlagForCommandFile(flag) for flag in cpu_compiler_flags]
commandfile.write("\n".join(cpu_compiler_flags))
commandfile.close()
return subprocess.call([CPU_COMPILER, "@" + commandfile_path])
else:
return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)

if __name__ == '__main__':
Expand Down
8 changes: 7 additions & 1 deletion third_party/gpus/cuda_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _get_win_cuda_defines(repository_ctx):
),
)

msvc_cl_path = get_python_bin(repository_ctx)
msvc_cl_path = "windows/msvc_wrapper_for_nvcc.bat"
msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace(
"\\",
"/",
Expand Down Expand Up @@ -1301,6 +1301,12 @@ def _create_local_cuda_repository(repository_ctx):
tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"],
wrapper_defines,
)
repository_ctx.file(
"crosstool/windows/msvc_wrapper_for_nvcc.bat",
content = "@echo OFF\n{} -B external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py %*".format(
get_python_bin(repository_ctx),
),
)
repository_ctx.template(
"crosstool/windows/msvc_wrapper_for_nvcc.py",
tpl_paths["crosstool:windows/msvc_wrapper_for_nvcc.py"],
Expand Down

0 comments on commit 29e0ab8

Please sign in to comment.