
core dump when copying tensor from cuda to xla gpu device #3466

Open
cicirori opened this issue Apr 1, 2022 · 3 comments

cicirori commented Apr 1, 2022

🐛 Bug

core dump when copying a tensor from cuda to xla gpu device

To Reproduce

import torch
import torch_xla.core.xla_model as xm

cuda_device = 'cuda:0'
xla_device = xm.xla_device()

t = torch.randn(1024, 1024, device=cuda_device)
# the copy below core dumps whether we synchronize here first or not.
# torch.cuda.synchronize()
t = t.to(xla_device)
xm.mark_step()

Expected behavior

exit successfully

Environment

  • Reproducible on XLA backend: GPU
  • torch_xla version: 1.10

Additional context

core dump call stack

#0  __memmove_avx_unaligned_erms () at ../sysdeps/x86_64/multiarch/memmove-vec-unaligned-erms.S:257
#1  0x00007f7d5588a226 in std::__copy_move<false, true, std::random_access_iterator_tag>::__copy_m<float> (__first=<optimized out>, __result=<optimized out>, __last=<optimized out>)
    at /usr/bin/../lib/gcc/x86_64-linux-gnu/7.5.0/../../../../include/c++/7.5.0/bits/stl_algobase.h:368
#2  std::__copy_move_a<false, float const*, float*> (__first=<optimized out>, __result=<optimized out>, __last=<optimized out>)
    at /usr/bin/../lib/gcc/x86_64-linux-gnu/7.5.0/../../../../include/c++/7.5.0/bits/stl_algobase.h:385
#3  std::__copy_move_a2<false, float const*, float*> (__first=<optimized out>, __result=<optimized out>, __last=<optimized out>)
    at /usr/bin/../lib/gcc/x86_64-linux-gnu/7.5.0/../../../../include/c++/7.5.0/bits/stl_algobase.h:422
#4  std::copy<float const*, float*> (__first=<optimized out>, __result=<optimized out>, __last=<optimized out>) at /usr/bin/../lib/gcc/x86_64-linux-gnu/7.5.0/../../../../include/c++/7.5.0/bits/stl_algobase.h:454
#5  torch_xla::(anonymous namespace)::CopyData<float, float> (dest=<optimized out>, source=<optimized out>, n=<optimized out>) at /opt/pytorch/xla/torch_xla/csrc/tensor_util.cpp:308
#6  torch_xla::(anonymous namespace)::CopyTensors<float, float> (src_buffer=<optimized out>, src_shape=..., dest_buffer=<optimized out>, dest_buffer_size=<optimized out>, dest_shape=...)
    at /opt/pytorch/xla/torch_xla/csrc/tensor_util.cpp:441
#7  0x00007f7d5581f331 in torch_xla::(anonymous namespace)::TensorToBuffer<float, float> (tensor=..., dest_shape=..., dest_buffer=0x7f7530000d00, dest_buffer_size=4194304, device=...)
    at /opt/pytorch/xla/torch_xla/csrc/tensor_util.cpp:474
#8  torch_xla::(anonymous namespace)::TensorToBufferSType<float> (tensor=..., dest_shape=..., dest_buffer=0x7f7530000d00, dest_buffer_size=4194304, device=...)
    at /opt/pytorch/xla/torch_xla/csrc/tensor_util.cpp:492
#9  0x00007f7d3e29a6de in std::_Function_handler<void (), xla::XrtComputationClient::TransferToServerInternal(absl::lts_20210324::Span<xla::ComputationClient::TensorSource const>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) () from /opt/conda/lib/python3.7/site-packages/torch_xla/lib/libxla_computation_client.so
#10 0x00007f7d3e271a3d in xla::util::MultiWait::Complete(std::function<void ()> const&) () from /opt/conda/lib/python3.7/site-packages/torch_xla/lib/libxla_computation_client.so
#11 0x00007f7d3e277440 in xla::env::(anonymous namespace)::ThreadPool::Worker() () from /opt/conda/lib/python3.7/site-packages/torch_xla/lib/libxla_computation_client.so
#12 0x00007f7daf75a6df in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#13 0x00007f7dd25ca6db in start_thread (arg=0x7f75377ae700) at pthread_create.c:463
#14 0x00007f7dd22f361f in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95
JackCaoG commented Apr 6, 2022

I think eventually it will call

CopyTensors<SType, DType>(contiguous_tensor.data_ptr<SType>(), src_shape,

I'm not exactly sure what this will output for an at::Tensor on a GPU device. @bdhirsh any idea?
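
For reference, a minimal standalone sketch (not torch_xla code; assumes libtorch built with CUDA and a visible GPU) of why a host-side copy over that pointer faults: data_ptr<SType>() on a CUDA tensor returns a device address, so the std::copy inside CopyData dereferences GPU memory from a CPU thread, which matches the memmove frame at the top of the stack.

#include <torch/torch.h>
#include <algorithm>
#include <vector>

int main() {
  at::Tensor t = torch::randn({4}, torch::device(torch::kCUDA));
  const float* src = t.data_ptr<float>();  // device pointer, not host-accessible
  std::vector<float> dst(4);
  // std::copy(src, src + 4, dst.data());  // faults, same failure mode as CopyData
  at::Tensor host = t.cpu();               // explicit device-to-host copy first
  std::copy(host.data_ptr<float>(), host.data_ptr<float>() + 4, dst.data());
  return 0;
}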

bdhirsh commented Apr 6, 2022

Hmmm, I think the .data_ptr() call on a cuda tensor is fine, but it looks like CopyData eventually calls either std::copy() or std::memcpy(), both of which would probably break if the source buffer isn't on cpu (here)

An easy fix would probably be to move the source to cpu first :) but then you're twice as slow. I'm not 100% familiar with the APIs you'd need to do a cuda device-to-host copy into an xla memory buffer, but you'd probably need to include some cuda headers?
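
A hedged sketch of that "move to cpu first" option as a small helper (the name EnsureCpuContiguous and the idea of calling it at the top of TensorToBuffer are assumptions, not existing torch_xla code):

#include <torch/torch.h>

// Hypothetical helper: hand the existing CopyData/std::copy path a tensor
// whose storage lives in host memory. For CUDA inputs this is a blocking
// device-to-host copy, i.e. correct but with the extra copy mentioned above.
at::Tensor EnsureCpuContiguous(const at::Tensor& src) {
  return src.is_cuda() ? src.cpu().contiguous() : src.contiguous();
}

TensorToBufferSType could then call this on the incoming tensor before taking data_ptr, leaving the rest of the copy path unchanged.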

bdhirsh commented Apr 6, 2022

Although I wonder if we can do something inside of that function like:

void TensorToBuffer(const at::Tensor& src, const xla::Shape& dest_shape,
                    void* dest_buffer, size_t dest_buffer_size,
                    const Device& device) {
    if (src.is_cuda()) {
      // where does `dest_buffer` live - is it cpu memory? or tpu memory..
      // if it's cpu memory, maybe we can just wrap it in a tensor
      const at::Tensor dest_as_tensor = at::from_blob(dest_buffer, {dest_buffer_size});
      // ...and copy the cuda tensor directly into the dest buffer's memory
      dest_as_tensor.copy_(src);
    }
    ...
}
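
One sizing detail if that direction is pursued: at::from_blob takes sizes in elements rather than bytes (and defaults to float), so the wrapping would look more like the hedged variant below, which also assumes dest_buffer is ordinary host memory, a float destination dtype, and matching source/destination layouts:

at::Tensor dest_as_tensor = at::from_blob(
    dest_buffer,
    {static_cast<int64_t>(dest_buffer_size / sizeof(float))},
    at::TensorOptions().dtype(at::kFloat));
// flatten so the element counts line up; copy_ performs the device-to-host transfer
dest_as_tensor.copy_(src.reshape({-1}));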
