Build Fail with "fatal error LNK1169: one or more multiply defined symbols found" #14165

Closed
adam-hartshorne opened this issue Jan 26, 2023 · 7 comments
@adam-hartshorne

Description

Using the following build command:

python .\build\build.py --enable_cuda --cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" --cudnn_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" --cuda_compute_capabilities="7.5" --cuda_version="11.7" --cudnn_version="8.4.0" --noenable_rocm --noenable_tpu

The build fails at the linking stage with the following error message:

b'LINK : warning LNK4044: unrecognized option \'/lm\'; ignored\r\nffi.lib(ffi.obj) : error LNK2005: "struct XLA_FFI_Stream * __cdecl xla::runtime::ffi::GetXlaFfiStream(class xla::runtime::PtrMapByType<class xla::runtime::CustomCall,16> const *,class xla::runtime::DiagnosticEngine const *)" (?GetXlaFfiStream@ffi@runtime@xla@@YAPEAUXLA_FFI_Stream@@PEBV?$PtrMapByType@VCustomCall@runtime@xla@@$0BA@@23@PEBVDiagnosticEngine@23@@Z) already defined in executable.lib(executable.obj)\r\n Creating library bazel-out/x64_windows-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so.if.lib and object bazel-out/x64_windows-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so.if.exp\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'gpu_executable.lib(gpu_executable.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'send_recv.lib(send_recv.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'llvm_gpu_backend.lib(gpu_backend_lib.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'transpose.lib(transpose.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> 
tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'bfc_allocator.lib(bfc_allocator.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'nvptx_compiler_impl.lib(nvptx_compiler.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'gpu_compiler.lib(gpu_compiler.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'cpu_runtime.lib(cpu_runtime.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'gpu_helpers.lib(gpu_helpers.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'pjrt_stream_executor_client.lib(pjrt_stream_executor_client.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'local_device_state.lib(local_device_state.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct 
std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'profiler.lib(profiler.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'outfeed_receiver.lib(outfeed_receiver.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'py_client.lib(py_values.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'tfrt_cpu_pjrt_client.lib(tfrt_cpu_pjrt_client.obj)\'\r\nLINK : warning LNK4217: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'allocator_registry_impl.lo.lib(cpu_allocator_impl.obj)\' in function \'"public: static void __cdecl tsl::profiler::TraceMe::InstantActivity<class <lambda_869dba8525ee1b44cc1026465438dec2>,1>(class <lambda_869dba8525ee1b44cc1026465438dec2> &&,int)" (??$InstantActivity@V<lambda_869dba8525ee1b44cc1026465438dec2>@@$00@TraceMe@profiler@tsl@@SAX$$QEAV<lambda_869dba8525ee1b44cc1026465438dec2>@@H@Z)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'pmap_lib.lib(pmap_lib.obj)\'\r\nLINK : warning LNK4286: symbol 
\'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'pjit.lib(pjit.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'jax_jit.lib(jax_jit.obj)\'\r\nLINK : warning LNK4217: symbol \'?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)\' defined in \'annotation_stack_impl.lo.lib(annotation_stack.obj)\' is imported by \'gpu_executable.lib(gpu_executable.obj)\' in function \'"public: __cdecl tsl::profiler::ScopedAnnotation::ScopedAnnotation<class <lambda_abcddd76c82971d27d2956bbd819e43a> >(class <lambda_abcddd76c82971d27d2956bbd819e43a>)" (??$?0V<lambda_abcddd76c82971d27d2956bbd819e43a>@@@ScopedAnnotation@profiler@tsl@@QEAA@V<lambda_abcddd76c82971d27d2956bbd819e43a>@@@Z)\'\r\nLINK : warning LNK4286: symbol \'?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)\' defined in \'annotation_stack_impl.lo.lib(annotation_stack.obj)\' is imported by \'gpu_executable.lib(sequential_thunk.obj)\'\r\nLINK : warning LNK4286: symbol \'?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)\' defined in \'annotation_stack_impl.lo.lib(annotation_stack.obj)\' is imported by \'tracing.lib(tracing.obj)\'\r\nbazel-out\\x64_windows-opt\\bin\\external\\org_tensorflow\\tensorflow\\compiler\\xla\\python\\xla_extension.so : fatal error LNK1169: one or more multiply defined symbols found\r\n'

What jax/jaxlib version are you using?

jaxlib v0.4.2, jax 0.4.2

Which accelerator(s) are you using?

GPU

Additional system info

Windows 10, Python 3.9, CUDA 11.7, cuDNN 8.4.0

NVIDIA GPU info

No response

@adam-hartshorne adam-hartshorne added the bug Something isn't working label Jan 26, 2023
@zhangqiaorjc
Collaborator

@ezhulenev could this be due to the new XLA runtime for CPU or GPU? I see the "xla::runtime::ffi::GetXlaFfiStream" symbol.

Is there some C++ language feature we used that broke the Windows build?

@zhangqiaorjc
Collaborator

@adam-hartshorne could you post a more readable error report, with proper newlines?

@adam-hartshorne
Author

adam-hartshorne commented Jan 27, 2023

LINK : warning LNK4044: unrecognized option \'/lm\'; ignored
ffi.lib(ffi.obj) : error LNK2005: "struct XLA_FFI_Stream * __cdecl xla::runtime::ffi::GetXlaFfiStream(class xla::runtime::PtrMapByType<class xla::runtime::CustomCall,16> const *,class xla::runtime::DiagnosticEngine const *)" (?GetXlaFfiStream@ffi@runtime@xla@@YAPEAUXLA_FFI_Stream@@PEBV?$PtrMapByType@VCustomCall@runtime@xla@@$0BA@@23@PEBVDiagnosticEngine@23@@Z) already defined in executable.lib(executable.obj)   
Creating library bazel-out/x64_windows-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so.if.lib and object bazel-out/x64_windows-opt/bin/external/org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so.if.exp
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'gpu_executable.lib(gpu_executable.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'send_recv.lib(send_recv.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'llvm_gpu_backend.lib(gpu_backend_lib.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'transpose.lib(transpose.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'bfc_allocator.lib(bfc_allocator.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'nvptx_compiler_impl.lib(nvptx_compiler.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'gpu_compiler.lib(gpu_compiler.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'cpu_runtime.lib(cpu_runtime.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'gpu_helpers.lib(gpu_helpers.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'pjrt_stream_executor_client.lib(pjrt_stream_executor_client.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'local_device_state.lib(local_device_state.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'profiler.lib(profiler.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'outfeed_receiver.lib(outfeed_receiver.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'py_client.lib(py_values.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'tfrt_cpu_pjrt_client.lib(tfrt_cpu_pjrt_client.obj)\'
LINK : warning LNK4217: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'allocator_registry_impl.lo.lib(cpu_allocator_impl.obj)\' in function \'"public: static void __cdecl tsl::profiler::TraceMe::InstantActivity<class <lambda_869dba8525ee1b44cc1026465438dec2>,1>(class <lambda_869dba8525ee1b44cc1026465438dec2> &&,int)" (??$InstantActivity@V<lambda_869dba8525ee1b44cc1026465438dec2>@@$00@TraceMe@profiler@tsl@@SAX$$QEAV<lambda_869dba8525ee1b44cc1026465438dec2>@@H@Z)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'pmap_lib.lib(pmap_lib.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'pjit.lib(pjit.obj)\'
LINK : warning LNK4286: symbol \'?g_trace_level@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_trace_level)\' defined in \'traceme_recorder_impl.lo.lib(traceme_recorder.obj)\' is imported by \'jax_jit.lib(jax_jit.obj)\'
LINK : warning LNK4217: symbol \'?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)\' defined in \'annotation_stack_impl.lo.lib(annotation_stack.obj)\' is imported by \'gpu_executable.lib(gpu_executable.obj)\' in function \'"public: __cdecl tsl::profiler::ScopedAnnotation::ScopedAnnotation<class <lambda_abcddd76c82971d27d2956bbd819e43a> >(class <lambda_abcddd76c82971d27d2956bbd819e43a>)" (??$?0V<lambda_abcddd76c82971d27d2956bbd819e43a>@@@ScopedAnnotation@profiler@tsl@@QEAA@V<lambda_abcddd76c82971d27d2956bbd819e43a>@@@Z)\'
LINK : warning LNK4286: symbol \'?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)\' defined in \'annotation_stack_impl.lo.lib(annotation_stack.obj)\' is imported by \'gpu_executable.lib(sequential_thunk.obj)\'
LINK : warning LNK4286: symbol \'?g_annotation_enabled@internal@profiler@tsl@@3U?$atomic@H@std@@A (struct std::atomic<int> tsl::profiler::internal::g_annotation_enabled)\' defined in \'annotation_stack_impl.lo.lib(annotation_stack.obj)\' is imported by \'tracing.lib(tracing.obj)\'
bazel-out\\x64_windows-opt\\bin\\external\\org_tensorflow\\tensorflow\\compiler\\xla\\python\\xla_extension.so : fatal error LNK1169: one or more multiply defined symbols found

@ezhulenev
Collaborator

We intentionally define this symbol twice. The default implementation uses weak linking here: https://github.com/openxla/xla/blob/f95fde0a6bfe3e7f92f26294586be978434f3e79/xla/runtime/ffi.cc#L123-L128. The "real" implementation in executable.cc is supposed to override it at link time. But it looks like MSVC doesn't like that? Perhaps ABSL_ATTRIBUTE_WEAK is incorrectly defined for MSVC?
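To make the weak-linking pattern described above concrete, here is a minimal sketch. It is not the XLA code: `EXAMPLE_ATTRIBUTE_WEAK`, `ExampleStream`, and `GetExampleStream` are hypothetical names, and the macro only assumes the common shape of a portable "weak" attribute that expands to `__attribute__((weak))` on GCC/Clang and to nothing elsewhere. Under that assumption, a non-Clang MSVC build would see two ordinary (strong) definitions once a second translation unit defines the same function, which is consistent with the LNK2005 error in the log.

```cpp
// Hypothetical macro for illustration only; real projects may define theirs differently.
#if defined(__GNUC__) || defined(__clang__)
#define EXAMPLE_ATTRIBUTE_WEAK __attribute__((weak))
#else
// Assumption: with no weak attribute available, the definition below is a
// normal strong definition and cannot be overridden at link time.
#define EXAMPLE_ATTRIBUTE_WEAK
#endif

// Hypothetical stand-in for an opaque stream handle.
struct ExampleStream {
  int id;
};

// Default (fallback) definition. With GCC/Clang-style weak linking, another
// translation unit may provide a strong definition with the same signature,
// and the linker picks that one instead of reporting a duplicate symbol.
EXAMPLE_ATTRIBUTE_WEAK ExampleStream* GetExampleStream() {
  return nullptr;  // No "real" implementation is linked in this sketch.
}
```

When only this file is linked, calls to `GetExampleStream()` resolve to the fallback and return `nullptr`.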

@ezhulenev
Collaborator

Skimming through https://stackoverflow.com/questions/2290587/gcc-style-weak-linking-in-visual-studio/11529277#11529277, it looks like GCC/LLVM-style weak linking doesn't work in MSVC. If there are no flags/attributes to make it work with MSVC, I'll take a look at changing the registration mechanism to something that does not rely on weak symbols.
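One possible shape for a registration mechanism that avoids weak symbols is an explicit provider slot that starts with a default and can be replaced at startup. The sketch below is only an illustration of that general idea, not the change that was made in XLA; `Stream`, `StreamProvider`, `RegisterStreamProvider`, `GetStream`, and `RealProvider` are all hypothetical names.

```cpp
#include <atomic>

// Hypothetical stand-in for an opaque stream handle.
struct Stream {
  int id;
};

// A provider is a plain function pointer, so the slot can be swapped atomically.
using StreamProvider = Stream* (*)();

namespace internal {

// Default provider, used until a "real" implementation registers itself.
inline Stream* DefaultProvider() { return nullptr; }

// Function-local static avoids static-initialization-order problems.
inline std::atomic<StreamProvider>& ProviderSlot() {
  static std::atomic<StreamProvider> slot{&DefaultProvider};
  return slot;
}

}  // namespace internal

// Called by the "real" implementation (for example from an initializer) to
// replace the default. Only one definition of each function exists, so the
// linker never sees duplicate symbols.
inline void RegisterStreamProvider(StreamProvider provider) {
  internal::ProviderSlot().store(provider);
}

// Callers always go through the currently registered provider.
inline Stream* GetStream() { return internal::ProviderSlot().load()(); }

// Hypothetical "real" implementation used to demonstrate registration.
inline Stream* RealProvider() {
  static Stream stream{7};
  return &stream;
}
```

Before `RegisterStreamProvider(&RealProvider)` is called, `GetStream()` returns `nullptr`; afterwards it returns the stream supplied by `RealProvider`. The trade-off compared with weak symbols is that the override happens at run time rather than at link time, so registration has to run before the first use.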

@adam-hartshorne
Author

It seems quite a few build issues come from the fact that MSVC has its own (often nonstandard) quirks and that the various teams involved with JAX / TF / XLA never use MSVC. Is there any reason not to make the Windows build of JAX compile with LLVM by default, given that it is available for Windows?

@hawkinsp
Collaborator

I'm guessing this issue is stale, given we have a Windows CPU CI build that is not showing this problem at head.

And yes, I think we would like to build our Windows wheels with clang, because it would be one fewer difference between platforms. We will probably wait for TensorFlow to make that switch first, though (I gather it is in progress), simply to share the work.
