
[PJRT] Release the GIL during TransferFromServer #4504

Merged · 8 commits · Jan 27, 2023

Conversation

@will-cromar (Collaborator)

Because execution in PJRT is asynchronous, if we call TransferFromServer on a buffer produced by a computation that is still in progress, TransferFromServer will block until the underlying buffer is ready. However, the thread calling TransferFromServer is also holding the GIL. That means that in a multithreaded context, we can end up in a deadlock: one thread holds the GIL indefinitely while it waits for another replica to participate in a computation, and the thread that must run that computation is blocked because the first thread is holding the GIL.
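
As a hypothetical sketch (not a guaranteed repro), the hazardous pattern looks roughly like this; it assumes the thread-per-device setup used on TPU v3, where all local replicas share one process and one GIL:

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    def _mp_fn(index):
        device = xm.xla_device()
        t = torch.ones(4, device=device)
        # The all_reduce needs every replica to participate.
        reduced = xm.all_reduce(xm.REDUCE_SUM, t)
        xm.mark_step()  # under PJRT, execution is dispatched asynchronously
        # .cpu() goes through TransferFromServer. Before this fix, it waited
        # on the buffer while still holding the GIL, so a sibling replica
        # thread that had not yet dispatched its all_reduce could never
        # acquire the GIL to do so.
        print(reduced.cpu())

    if __name__ == '__main__':
        xmp.spawn(_mp_fn)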

This started happening consistently with PJRT on TPU v3 some time in mid-December, triggered by an unrelated commit. I confirmed the GIL deadlock by inspecting hanging threads with gdb:

`gdb` output
Thread 456 (Thread 0x7f3aa37fe700 (LWP 856191)):
#0  syscall () at ../sysdeps/unix/sysv/linux/x86_64/syscall.S:38
#1  0x00007f405a13278e in absl::lts_20220623::synchronization_internal::Waiter::Wait(absl::lts_20220623::synchronization_internal::KernelTimeout) () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#2  0x00007f405a132652 in AbslInternalPerThreadSemWait_lts_20220623 () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#3  0x00007f405a133a43 in absl::lts_20220623::Mutex::Block(absl::lts_20220623::base_internal::PerThreadSynch*) () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#4  0x00007f404d4d7c9c in absl::lts_20220623::Mutex::LockSlowWithDeadline(absl::lts_20220623::MuHowS const*, absl::lts_20220623::Condition const*, absl::lts_20220623::synchronization_internal::KernelTimeout, int) [clone .cold.39] () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#5  0x00007f404d4d7cb0 in absl::lts_20220623::Mutex::LockSlow(absl::lts_20220623::MuHowS const*, absl::lts_20220623::Condition const*, int) () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#6  0x00007f405a1351d3 in absl::lts_20220623::Notification::WaitForNotification() const () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#7  0x00007f404d7d5275 in xla::PjRtComputationClient::TransferFromServer(absl::lts_20220623::Span<std::shared_ptr<xla::ComputationClient::Data> const>) () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
Thread 452 (Thread 0x7fa11bfff700 (LWP 1271425)):
Traceback (most recent call first):
  Waiting for the GIL
  <built-in method write of _io.TextIOWrapper object at remote 0x7fa81e9d0040>
  <built-in method print of module object at remote 0x7fa81ea2c310>
  (frame information optimized out)
  File "/usr/local/lib/python3.8/site-packages/torch/_tensor.py", line 935, in __hash__
    return id(self)

This problem did not affect XRT for two reasons:

  1. Each replica runs in a distinct process, so replicas don't share a GIL. (This is also why the problem does not manifest on TPU v4 with PJRT.)
  2. Execution is synchronous in XRT, so TransferFromServer never blocks waiting for a result.

To prevent this deadlock, we have to do one of the following:

  1. Don't release the device lock until execution is finished. This blocks other work, e.g. device transfers; in my experiments with different ways of doing this, I saw a consistent 5-10% performance degradation.
  2. Release the GIL while calling TransferFromServer so other threads can move forward.

This PR implements (2) to preserve performance, as sketched below.
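
The change is essentially the following pattern around the blocking wait (a simplified sketch of the approach, not the exact patch; WaitForBuffer is a stand-in name). The guards matter because this code also runs from pure C++ contexts, such as the C++ test suite, where there is no Python thread state to save:

    #include <Python.h>

    void WaitForBuffer(/* buffer handle */) {
      PyThreadState* save = nullptr;
      // Only release the GIL if Python is running and this thread actually
      // holds it; PyEval_SaveThread() fatals on a NULL thread state.
      if (Py_IsInitialized() && PyGILState_Check()) {
        save = PyEval_SaveThread();
      }
      // ... block on the buffer's ready notification ...
      if (save != nullptr) {
        PyEval_RestoreThread(save);  // re-acquire the GIL
      }
    }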

@will-cromar added the DO_NOT_MERGE_YET (for PRs which cannot be merged, despite tests passing) and runtime labels on Jan 25, 2023
@JackCaoG (Collaborator) left a comment

Thanks Will for the investigation. @ronghanghu FYI

@will-cromar (Collaborator Author)

This looks like a real error from the CPU test: test_view_copy_out_xla (__main__.TestViewOpsXLA) ... Fatal Python error: PyEval_SaveThread: NULL tstate

The linter is just getting confused. I'll expand those Py_BEGIN_ALLOW_THREADS macros since clang-format doesn't know what to do with them.
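
For reference, those CPython macros expand to roughly the following, which is why expanding them by hand gives clang-format balanced braces to work with:

    {
      PyThreadState* _save;
      _save = PyEval_SaveThread();  // Py_BEGIN_ALLOW_THREADS
      // ... blocking work without the GIL ...
      PyEval_RestoreThread(_save);  // Py_END_ALLOW_THREADS
    }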

@will-cromar (Collaborator Author)

I'm not sure where this test actually comes from: run_dynamic python3 /tmp/pytorch/xla/test/../../test/test_view_ops.py -v TestViewOpsXLA

I get this error:

$ python ../test/test_view_ops.py -v TestViewOpsXLA
Fail to import hypothesis in common_utils, tests are not derandomized
TestViewOpsXLA (unittest.loader._FailedTest) ... ERROR

======================================================================
ERROR: TestViewOpsXLA (unittest.loader._FailedTest)
----------------------------------------------------------------------
AttributeError: module '__main__' has no attribute 'TestViewOpsXLA'

----------------------------------------------------------------------
Ran 1 test in 0.000s

FAILED (errors=1)

Searching pytorch/pytorch for this test name also turns up nothing: https://github.com/pytorch/pytorch/search?q=TestViewOpsXLA

@will-cromar (Collaborator Author)

I can reproduce the error locally like this:

$ PJRT_DEVICE=CPU python test/test_operations.py -v TestAtenXlaTensor.test_masked_fill_with_tensor
/usr/local/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libc10_cuda.so: cannot open shared object file: No such file or directory
  warn(f"Failed to load image Python extension: {e}")
test_masked_fill_with_tensor (__main__.TestAtenXlaTensor) ... Fatal Python error: PyEval_SaveThread: NULL tstate
Python runtime state: initialized

@will-cromar (Collaborator Author)

The next step is to rebase onto #4503, confirm performance on v4, and double-check this on XRT. I don't expect either to be impacted, since both cases run replicas in separate processes.

@will-cromar (Collaborator Author) commented on Jan 25, 2023

Tested this with PJRT and XRT on both v3 and v4. I don't see any noticeable impact on performance in any case.*

* For some reason, the performance is lower with XRT than it is in our automated test, but that's true with or without this change. I'll keep an eye on the benchmarks after this is merged.

@will-cromar (Collaborator Author)

The CI test failure looks real. I can reproduce it with XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" XRT_WORKERS="localservice:0;grpc://localhost:51011" test/cpp/run_tests.sh -F AtenXlaTensorTest.TestSymSizes

@will-cromar marked this pull request as ready for review on Jan 25, 2023, 22:29
@will-cromar (Collaborator Author)

I have the CPP tests passing locally again. 🤞

@alanwaketan (Collaborator) left a comment

Can we do NoGilSection in the Python binding layer? It just doesn't feel right for code below the Python binding layer to deal with Python concepts. wdyt?

@will-cromar (Collaborator Author)

I agree, and this was my first thought as well. But I traced the offending call all the way back, and it didn't pass through our Python binding layer (init_python_bindings.cpp) at all. I'm definitely open to moving this to a more logical place if we can find one, though.

Attaching the full gdb output for reference.

Full `gdb` trace
Thread 456 (Thread 0x7f3aa37fe700 (LWP 856191)):
#0  syscall () at ../sysdeps/unix/sysv/linux/x86_64/syscall.S:38
#1  0x00007f405a13278e in absl::lts_20220623::synchronization_internal::Waiter::Wait(absl::lts_20220623::synchronization_internal::KernelTimeout) () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#2  0x00007f405a132652 in AbslInternalPerThreadSemWait_lts_20220623 () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#3  0x00007f405a133a43 in absl::lts_20220623::Mutex::Block(absl::lts_20220623::base_internal::PerThreadSynch*) () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#4  0x00007f404d4d7c9c in absl::lts_20220623::Mutex::LockSlowWithDeadline(absl::lts_20220623::MuHowS const*, absl::lts_20220623::Condition const*, absl::lts_20220623::synchronization_internal::KernelTimeout, int) [clone .cold.39] () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#5  0x00007f404d4d7cb0 in absl::lts_20220623::Mutex::LockSlow(absl::lts_20220623::MuHowS const*, absl::lts_20220623::Condition const*, int) () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#6  0x00007f405a1351d3 in absl::lts_20220623::Notification::WaitForNotification() const () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#7  0x00007f404d7d5275 in xla::PjRtComputationClient::TransferFromServer(absl::lts_20220623::Span<std::shared_ptr<xla::ComputationClient::Data> const>) () from /pytorch/xla/torch_xla/lib/libxla_computation_client.so
#8  0x00007f405e0191f9 in torch_xla::XLAGraphExecutor::GetTensorsFused (this=<optimized out>, tensors=0x7f3aa37fb860) at torch_xla/csrc/xla_graph_executor.cpp:630
#9  0x00007f405e01872d in torch_xla::XLAGraphExecutor::GetTensors (this=0x7f405e402548 <torch_xla::XLAGraphExecutor::Get()::arena>, tensors=0x7f3aa37fb860) at torch_xla/csrc/xla_graph_executor.cpp:441
#10 0x00007f405e00a783 in torch_xla::bridge::XlaCreateTensorList (tensors=...) at torch_xla/csrc/aten_xla_bridge.cpp:154
#11 0x00007f405def2c01 in torch_xla::XLANativeFunctions::_to_cpu (tensors=...) at torch_xla/csrc/aten_xla_type.cpp:507
#12 0x00007f405e17895f in at::(anonymous namespace)::(anonymous namespace)::wrapper___to_cpu (tensors=...) at torch_xla/csrc/generated/RegisterXLA.cpp:360
#13 c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::ArrayRef<at::Tensor>), &at::(anonymous namespace)::(anonymous namespace)::wrapper___to_cpu>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<c10::ArrayRef<at::Tensor> > >::operator()(c10::ArrayRef<at::Tensor>) (this=<optimized out>, args=...) at /usr/local/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13
#14 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::ArrayRef<at::Tensor>), &at::(anonymous namespace)::(anonymous namespace)::wrapper___to_cpu>, std::vector<at::Tensor, std::allocator<at::Tensor> >, c10::guts::typelist::typelist<c10::ArrayRef<at::Tensor> > >, std::vector<at::Tensor, std::allocator<at::Tensor> > (c10::ArrayRef<at::Tensor>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, c10::ArrayRef<at::Tensor>) (functor=<optimized out>, args=...) at /usr/local/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:461
#15 0x00007f41a3bc1fd8 in at::_ops::_to_cpu::call(c10::ArrayRef<at::Tensor>) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#16 0x00007f41a323c51f in at::native::to_cpu(c10::ArrayRef<at::Tensor> const&) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#17 0x00007f41a323d1b4 in at::native::cpu_fallback(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#18 0x00007f405dffb599 in torch_xla::xla_cpu_fallback (op=..., stack=0x7f3aa37fc068) at torch_xla/csrc/aten_cpu_fallback.cpp:45
#19 0x00007f405df59ce9 in c10::BoxedKernel::callBoxed (this=0x7f3aa37fc0e8, opHandle=..., dispatchKeySet=..., stack=0x7f3a91322e80) at /usr/local/lib/python3.8/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h:41
#20 c10::impl::BoxedKernelWrapper<c10::Scalar (at::Tensor const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&) (boxed_kernel_func=..., opHandle=..., dispatchKeySet=..., args=...) at /usr/local/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/boxing.h:227
#21 0x00007f405df3fb7c in at::native::_call_fallback_fn<&torch_xla::xla_cpu_fallback, at::_ops::_local_scalar_dense, false, c10::Scalar (at::Tensor const&)>::call(at::Tensor const&) (args=...) at /usr/local/lib/python3.8/site-packages/torch/include/ATen/native/CPUFallback.h:29
#22 0x00007f405df342a4 in torch_xla::XLANativeFunctions::_local_scalar_dense (self=...) at torch_xla/csrc/aten_xla_type.cpp:3047
#23 0x00007f405e17649c in at::(anonymous namespace)::(anonymous namespace)::wrapper___local_scalar_dense (self=...) at torch_xla/csrc/generated/RegisterXLA.cpp:271
#24 c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<c10::Scalar (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper___local_scalar_dense>, c10::Scalar, c10::guts::typelist::typelist<at::Tensor const&> >::operator()(at::Tensor const&) (this=<optimized out>, args=...) at /usr/local/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h:13
#25 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<c10::Scalar (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper___local_scalar_dense>, c10::Scalar, c10::guts::typelist::typelist<at::Tensor const&> >, c10::Scalar (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) (functor=<optimized out>, args=...) at /usr/local/lib/python3.8/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:461
#26 0x00007f41a3f4a092 in at::_ops::_local_scalar_dense::redispatch(c10::DispatchKeySet, at::Tensor const&) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#27 0x00007f41a5cb4a1f in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<c10::Scalar (c10::DispatchKeySet, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::_local_scalar_dense>, c10::Scalar, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, c10::Scalar (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#28 0x00007f41a3f49e41 in at::_ops::_local_scalar_dense::call(at::Tensor const&) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#29 0x00007f41a354acfc in at::native::item(at::Tensor const&) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#30 0x00007f41a479a74c in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<c10::Scalar (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper__item>, c10::Scalar, c10::guts::typelist::typelist<at::Tensor const&> >, c10::Scalar (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#31 0x00007f41a3da5bd1 in at::_ops::item::call(at::Tensor const&) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#32 0x00007f41ae103aae in torch::autograd::THPVariable_item(_object*, _object*) () from /usr/local/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#33 0x00007f41afd19019 in cfunction_vectorcall_NOARGS (func=<built-in method item of Tensor object at remote 0x7f3bc0033e00>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/methodobject.c:463
#34 0x00007f41afdde140 in _PyObject_Vectorcall (callable=<built-in method item of Tensor object at remote 0x7f3bc0033e00>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>) at ./Include/cpython/abstract.h:127
#35 0x00007f41afc8e15d in trace_call_function (tstate=0x5629d3119e20, func=<built-in method item of Tensor object at remote 0x7f3bc0033e00>, args=0x7f3a983e3088, nargs=1, kwnames=0x0) at Python/ceval.c:4938
#36 0x00007f41afc8cdd2 in call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4960
#37 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3486
#38 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=6, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#39 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f4038618718, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#40 0x00007f41afcee0ed in PyVectorcall_Call (callable=<function at remote 0x7f40387aaca0>, tuple=<optimized out>, kwargs=<optimized out>) at Objects/call.c:200
#41 0x00007f41afd66fcb in do_call_core (kwdict=0x0, callargs=(<torch.device at remote 0x7f40385b60f0>, 20, <Tensor at remote 0x7f3bc0033e00>, <RateTracker(_smooth_factor=<float at remote 0x7f405e448930>, _start_time=<float at remote 0x7f3cf83095f0>, _partial_time=<float at remote 0x7f3ce808e0f0>, _partial_count=<float at remote 0x7f3ce808e170>, _partial_rate=<float at remote 0x7f3ce808e110>, _count=<float at remote 0x7f40385b65b0>) at remote 0x7f3ce806a220>, 1, None), func=<function at remote 0x7f40387aaca0>, tstate=<optimized out>) at Python/ceval.c:5010
#42 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3559
#43 0x00007f41afd64668 in _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=locals@entry=0x0, args=args@entry=0x7f3ce808a5c0, argcount=argcount@entry=0, kwnames=0x0, kwargs=0x7f3ce808a5c0, kwcount=<optimized out>, kwstep=1, defs=0x7f4038619e38, defcount=1, kwdefs=0x0, closure=(<cell at remote 0x7f3ce80898b0>,), name='<lambda>', qualname='add_step_closure.<locals>.<lambda>') at Python/ceval.c:4298
#44 0x00007f41afced062 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f3ce808a5c0, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:436
#45 0x00007f41afdde140 in _PyObject_Vectorcall (callable=<function at remote 0x7f40385b4f70>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>) at ./Include/cpython/abstract.h:127
#46 0x00007f41afc8c299 in call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4960
#47 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3500
#48 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=0, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#49 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f3a98f1c6a8, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#50 0x00007f41afdde140 in _PyObject_Vectorcall (callable=<function at remote 0x7f40397d0af0>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>) at ./Include/cpython/abstract.h:127
#51 0x00007f41afc8c299 in call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4960
#52 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3500
#53 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=0, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#54 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f37160f9798, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#55 0x00007f41afdde140 in _PyObject_Vectorcall (callable=<function at remote 0x7f40397d0b80>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>) at ./Include/cpython/abstract.h:127
#56 0x00007f41afc8c9ac in call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4960
#57 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3469
#58 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=1, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#59 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f3a90554068, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#60 0x00007f41afdde140 in _PyObject_Vectorcall (callable=<function at remote 0x7f40387aa310>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>) at ./Include/cpython/abstract.h:127
#61 0x00007f41afc8cdd2 in call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4960
#62 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3486
#63 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=1, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#64 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f3aa37fce20, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#65 0x00007f41afceef53 in _PyObject_Vectorcall (kwnames=0x0, nargsf=1, args=0x7f3aa37fce20, callable=<function at remote 0x7f40387aa1f0>) at ./Include/cpython/abstract.h:127
#66 _PyObject_FastCall (nargs=1, args=0x7f3aa37fce20, func=<function at remote 0x7f40387aa1f0>) at ./Include/cpython/abstract.h:147
#67 _PyObject_FastCall_Prepend (callable=callable@entry=<function at remote 0x7f40387aa1f0>, obj=obj@entry=<PerDeviceLoader(_loader=<ParallelLoader(_loader=<SampleGenerator(_data=(<Tensor at remote 0x7f40385bd090>, <Tensor at remote 0x7f40385bd180>), _sample_count=4687, _count=0) at remote 0x7f40385b39d0>, _devices=[<torch.device at remote 0x7f3cf817fcf0>], _batchdim=0, _batches_per_execution=1, _done=False, _queues={<torch.device at remote 0x7f3cf817fcf0>: <PerDeviceQueue(device=<torch.device at remote 0x7f3cf817fcf0>, loader_queue=<Queue(_maxsize=8, _lock=<_thread.lock at remote 0x7f3ce806a390>, _ready_cv=<Condition(_lock=<_thread.lock at remote 0x7f3ce806a390>, acquire=<built-in method acquire of _thread.lock object at remote 0x7f3ce806a390>, release=<built-in method release of _thread.lock object at remote 0x7f3ce806a390>, _waiters=<collections.deque at remote 0x7f40385c8280>) at remote 0x7f3ce806a3d0>, _space_available_cv=<Condition(_lock=<_thread.lock at remote 0x7f3ce806a390>, acquire=<built-in method acquire of _thread.lock object at remote 0x7f3ce806a390>, release=<built-in method release of _thread.lock ...(truncated), args=args@entry=0x0, nargs=1, nargs@entry=0) at Objects/call.c:850
#68 0x00007f41afd351d5 in call_unbound (nargs=0, args=0x0, self=<PerDeviceLoader(_loader=<ParallelLoader(_loader=<SampleGenerator(_data=(<Tensor at remote 0x7f40385bd090>, <Tensor at remote 0x7f40385bd180>), _sample_count=4687, _count=0) at remote 0x7f40385b39d0>, _devices=[<torch.device at remote 0x7f3cf817fcf0>], _batchdim=0, _batches_per_execution=1, _done=False, _queues={<torch.device at remote 0x7f3cf817fcf0>: <PerDeviceQueue(device=<torch.device at remote 0x7f3cf817fcf0>, loader_queue=<Queue(_maxsize=8, _lock=<_thread.lock at remote 0x7f3ce806a390>, _ready_cv=<Condition(_lock=<_thread.lock at remote 0x7f3ce806a390>, acquire=<built-in method acquire of _thread.lock object at remote 0x7f3ce806a390>, release=<built-in method release of _thread.lock object at remote 0x7f3ce806a390>, _waiters=<collections.deque at remote 0x7f40385c8280>) at remote 0x7f3ce806a3d0>, _space_available_cv=<Condition(_lock=<_thread.lock at remote 0x7f3ce806a390>, acquire=<built-in method acquire of _thread.lock object at remote 0x7f3ce806a390>, release=<built-in method release of _thread.lock ...(truncated), func=<function at remote 0x7f40387aa1f0>, unbound=1) at Objects/typeobject.c:1485
#69 call_method (nargs=0, args=0x0, name=0x7f41aff18020 <PyId___next__.16531>, obj=<PerDeviceLoader(_loader=<ParallelLoader(_loader=<SampleGenerator(_data=(<Tensor at remote 0x7f40385bd090>, <Tensor at remote 0x7f40385bd180>), _sample_count=4687, _count=0) at remote 0x7f40385b39d0>, _devices=[<torch.device at remote 0x7f3cf817fcf0>], _batchdim=0, _batches_per_execution=1, _done=False, _queues={<torch.device at remote 0x7f3cf817fcf0>: <PerDeviceQueue(device=<torch.device at remote 0x7f3cf817fcf0>, loader_queue=<Queue(_maxsize=8, _lock=<_thread.lock at remote 0x7f3ce806a390>, _ready_cv=<Condition(_lock=<_thread.lock at remote 0x7f3ce806a390>, acquire=<built-in method acquire of _thread.lock object at remote 0x7f3ce806a390>, release=<built-in method release of _thread.lock object at remote 0x7f3ce806a390>, _waiters=<collections.deque at remote 0x7f40385c8280>) at remote 0x7f3ce806a3d0>, _space_available_cv=<Condition(_lock=<_thread.lock at remote 0x7f3ce806a390>, acquire=<built-in method acquire of _thread.lock object at remote 0x7f3ce806a390>, release=<built-in method release of _thread.lock ...(truncated)) at Objects/typeobject.c:1485
#70 slot_tp_iternext (self=<PerDeviceLoader(_loader=<ParallelLoader(_loader=<SampleGenerator(_data=(<Tensor at remote 0x7f40385bd090>, <Tensor at remote 0x7f40385bd180>), _sample_count=4687, _count=0) at remote 0x7f40385b39d0>, _devices=[<torch.device at remote 0x7f3cf817fcf0>], _batchdim=0, _batches_per_execution=1, _done=False, _queues={<torch.device at remote 0x7f3cf817fcf0>: <PerDeviceQueue(device=<torch.device at remote 0x7f3cf817fcf0>, loader_queue=<Queue(_maxsize=8, _lock=<_thread.lock at remote 0x7f3ce806a390>, _ready_cv=<Condition(_lock=<_thread.lock at remote 0x7f3ce806a390>, acquire=<built-in method acquire of _thread.lock object at remote 0x7f3ce806a390>, release=<built-in method release of _thread.lock object at remote 0x7f3ce806a390>, _waiters=<collections.deque at remote 0x7f40385c8280>) at remote 0x7f3ce806a3d0>, _space_available_cv=<Condition(_lock=<_thread.lock at remote 0x7f3ce806a390>, acquire=<built-in method acquire of _thread.lock object at remote 0x7f3ce806a390>, release=<built-in method release of _thread.lock ...(truncated)) at Objects/typeobject.c:6732
#71 0x00007f41afcf346f in enum_next (en=0x7f3ce8068ac0) at Objects/enumobject.c:162
#72 0x00007f41afd6592e in _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3202
#73 0x00007f41afd647ed in _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=locals@entry=0x0, args=args@entry=0x7f3a90001298, argcount=argcount@entry=2, kwnames=0x0, kwargs=0x7f3a900012a8, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x0, closure=(<cell at remote 0x7f40385b3a30>, <cell at remote 0x7f40385b3a00>, <cell at remote 0x7f40385b39a0>, <cell at remote 0x7f40385b38e0>, <cell at remote 0x7f40385b3af0>, <cell at remote 0x7f40385b3940>, <cell at remote 0x7f40385b3910>), name='train_loop_fn', qualname='train_imagenet.<locals>.train_loop_fn') at Python/ceval.c:4298
#74 0x00007f41afced062 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f3a90001298, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:436
#75 0x00007f41afdde140 in _PyObject_Vectorcall (callable=<function at remote 0x7f40385b4940>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>) at ./Include/cpython/abstract.h:127
#76 0x00007f41afc8c299 in call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4960
#77 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3500
#78 0x00007f41afd647ed in _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=locals@entry=0x0, args=args@entry=0x7f40397c6fb8, argcount=argcount@entry=1, kwnames=kwnames@entry=0x0, kwargs=0x7f40397c6fc0, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name='train_imagenet', qualname='train_imagenet') at Python/ceval.c:4298
#79 0x00007f41afcec48f in _PyFunction_Vectorcall (kwnames=0x0, nargsf=<optimized out>, stack=0x7f40397c6fb8, func=<function at remote 0x7f403866d310>) at Objects/call.c:436
#80 _PyObject_FastCallDict (callable=<function at remote 0x7f403866d310>, args=args@entry=0x7f40397c6fb8, nargsf=<optimized out>, kwargs=kwargs@entry=0x0) at Objects/call.c:96
#81 0x00007f41afdff219 in partial_fastcall (pto=0x7f403860d4a0, pto=0x7f403860d4a0, kwargs=0x0, nargs=0, args=0x18) at ./Modules/_functoolsmodule.c:170
#82 partial_call (pto=pto@entry=0x7f403860d4a0, args=args@entry=(), kwargs=kwargs@entry=0x0) at ./Modules/_functoolsmodule.c:225
#83 0x00007f41afcec804 in _PyObject_MakeTpCall (callable=<functools.partial at remote 0x7f403860d4a0>, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at Objects/call.c:159
#84 0x00007f41afd69767 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7f40385bad90, callable=<optimized out>) at ./Include/cpython/abstract.h:125
#85 _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7f40385bad90, callable=<optimized out>) at ./Include/cpython/abstract.h:115
#86 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4963
#87 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3500
#88 0x00007f41afd64668 in _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=locals@entry=0x0, args=args@entry=0x7f40385b36b8, argcount=argcount@entry=1, kwnames=0x0, kwargs=0x7f40385b36c0, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x0, closure=(<cell at remote 0x7f403860a640>,), name='_thread_fn', qualname='_run_thread_per_device.<locals>._thread_fn') at Python/ceval.c:4298
#89 0x00007f41afced062 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f40385b36b8, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:436
#90 0x00007f41afcee0ed in PyVectorcall_Call (callable=<function at remote 0x7f405e420160>, tuple=<optimized out>, kwargs=<optimized out>) at Objects/call.c:200
#91 0x00007f41afd66fcb in do_call_core (kwdict={}, callargs=('xla:1',), func=<function at remote 0x7f405e420160>, tstate=<optimized out>) at Python/ceval.c:5010
#92 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3559
#93 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=1, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#94 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f40385c25c0, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#95 0x00007f41afd656f2 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7f40385c25c0, callable=<function at remote 0x7f40385b43a0>) at ./Include/cpython/abstract.h:127
#96 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4963
#97 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3486
#98 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=4, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#99 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f40385b2aa8, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#100 0x00007f41afcee0ed in PyVectorcall_Call (callable=<function at remote 0x7f40385b4280>, tuple=<optimized out>, kwargs=<optimized out>) at Objects/call.c:200
#101 0x00007f41afd66fcb in do_call_core (kwdict={}, callargs=(<weakref at remote 0x7f40385b2c20>, <_queue.SimpleQueue at remote 0x7f405ef0b040>, None, ()), func=<function at remote 0x7f40385b4280>, tstate=<optimized out>) at Python/ceval.c:5010
#102 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3559
#103 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=1, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#104 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f40385bf9b8, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#105 0x00007f41afd656f2 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7f40385bf9b8, callable=<function at remote 0x7f41af613a60>) at ./Include/cpython/abstract.h:127
#106 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4963
#107 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3486
#108 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=1, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#109 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f40385b5df8, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#110 0x00007f41afd656f2 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7f40385b5df8, callable=<function at remote 0x7f41af613d30>) at ./Include/cpython/abstract.h:127
#111 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5629d3119e20) at Python/ceval.c:4963
#112 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at Python/ceval.c:3486
#113 0x00007f41afcecfca in function_code_fastcall (globals=<optimized out>, nargs=1, args=<optimized out>, co=<optimized out>) at Objects/call.c:284
#114 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f3aa37fde08, nargsf=<optimized out>, kwnames=<optimized out>) at Objects/call.c:411
#115 0x00007f41afcef359 in _PyObject_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, args=<optimized out>, callable=<optimized out>) at ./Include/cpython/abstract.h:127
#116 method_vectorcall (method=method@entry=<method at remote 0x7f4038942080>, args=<optimized out>, nargsf=nargsf@entry=0, kwnames=<optimized out>) at Objects/classobject.c:67
#117 0x00007f41afcee0ed in PyVectorcall_Call (callable=<method at remote 0x7f4038942080>, tuple=<optimized out>, kwargs=<optimized out>) at Objects/call.c:200
#118 0x00007f41afe083c7 in t_bootstrap (boot_raw=boot_raw@entry=0x7f40385b37b0) at ./Modules/_threadmodule.c:1002
#119 0x00007f41afdf0314 in pythread_wrapper (arg=<optimized out>) at Python/thread_pthread.h:232
#120 0x00007f41afb70fa3 in start_thread (arg=<optimized out>) at pthread_create.c:486
#121 0x00007f41af91306f in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95

@alanwaketan (Collaborator) commented on Jan 25, 2023

Yeah, after eyeballing the trace, it looks like the GIL is held by the PyTorch layer, so I guess it's okay. Did you consider moving the NoGilSection out of the pybind layer into a shared header, and reusing that code here?

@will-cromar (Collaborator Author)

I thought about bringing in NoGilSection, but in this case we have to explicitly check that Python is initialized and that the GIL is not already released before we save the thread state. What do you think of modifying NoGilSection to perform the same checks, along the lines below?
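
In RAII form, those checks would look something like this (a hypothetical sketch with a made-up name; the existing NoGilSection in the pybind layer saves the thread state unconditionally):

    #include <Python.h>

    class MaybeNoGilSection {
     public:
      MaybeNoGilSection() {
        // Skip the GIL dance entirely when called from pure C++, or when
        // another section already released the GIL on this thread.
        if (Py_IsInitialized() && PyGILState_Check()) {
          state_ = PyEval_SaveThread();
        }
      }
      ~MaybeNoGilSection() {
        if (state_ != nullptr) {
          PyEval_RestoreThread(state_);
        }
      }

     private:
      PyThreadState* state_ = nullptr;
    };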

@alanwaketan (Collaborator)

It looks like upstream uses gil_scoped_release and gil_scoped_acquire. Maybe we can try those two as well?
https://pybind11.readthedocs.io/en/stable/advanced/misc.html

@will-cromar (Collaborator Author)

Good find! Skimming the code, I'm not sure how it would handle our case, where this function may be called from a non-Python context (mainly the C++ tests) or after the GIL has already been released. I found that PyEval_SaveThread crashes when either is true.

https://github.com/pybind/pybind11/blob/b07d08f6009a52050bb588a3e6fb0489c98af85e/include/pybind11/gil.h#L145-L146
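
For contrast, the idiomatic pybind11 pattern is below; it assumes the calling thread already holds the GIL, and with a NULL thread state (e.g. in the pure C++ tests) the underlying PyEval_SaveThread() aborts with exactly the "Fatal Python error: PyEval_SaveThread: NULL tstate" seen earlier:

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    void BlockingCall() {
      py::gil_scoped_release release;  // unconditionally saves the thread state
      // ... blocking work without the GIL ...
    }  // GIL re-acquired when `release` goes out of scope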

@alanwaketan (Collaborator)

Haha, not even the pybind folks have thought about our use case... Let's leave it as it is for now, but we should probably add a note saying this code is an exception. I just want to keep people from overusing this pattern in the C++ layer.

@alanwaketan (Collaborator) left a comment

LGTM.

@will-cromar (Collaborator Author) commented on Jan 26, 2023

Confirmed performance after rebasing and added a warning in the comments. We have other risky changes going into the next nightly build (especially the TF pin update), so I'm going to hold off on merging this until tests run again.

@will-cromar (Collaborator Author)

Implemented an offline request from @yeounoh to add an easy switch to disable this change.
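
The switch presumably takes the shape of an environment-variable guard like the sketch below (the flag name here is illustrative, not necessarily the one in the patch):

    #include <cstdlib>
    #include <string>

    // Illustrative escape hatch: decide once whether transfers should
    // release the GIL, defaulting to the new behavior.
    static bool ShouldReleaseGil() {
      const char* flag = std::getenv("XLA_RELEASE_GIL_DURING_TRANSFER");
      return flag == nullptr || std::string(flag) != "0";
    }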

@will-cromar removed the DO_NOT_MERGE_YET label on Jan 26, 2023
@will-cromar merged commit d41c5ba into master on Jan 27, 2023
ManfeiBai pushed a commit that referenced this pull request on Jan 30, 2023:
* Release the GIL during `TransferFromServer`

* Check for GIL before releasing it

* Add comment

* formatting

* Check for Python initialization before GIL status

* Add explanation and warning

* Add temporary switch to hold onto GIL

* Formatting