Skip to content

Commit

Permalink
roll back copy
Browse files Browse the repository at this point in the history
  • Loading branch information
rnburn committed Jan 28, 2025
1 parent a680b70 commit 10bbcd9
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 66 deletions.
77 changes: 59 additions & 18 deletions sxt/execution/device/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,35 @@
#include <cstring>

#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/pinned_buffer.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/error/assert.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/device/generate.h"
#include "sxt/execution/device/synchronization.h"

namespace sxt::xendv {
//--------------------------------------------------------------------------------------------------
// strided_copy_host_to_device_one_sweep
//--------------------------------------------------------------------------------------------------
// Packs `count` slices of `n` bytes each from `src` -- consecutive slices begin `stride` bytes
// apart -- into one staging buffer, then transfers the packed `n * count` bytes to `dst` with a
// single asynchronous host-to-device copy issued on `stream`.
//
// The coroutine finishes once `await_stream(stream)` has completed. When `n * count` is zero it
// returns immediately without touching the stream.
//
// NOTE(review): nothing here checks that `n * count` fits in the staging buffer; the visible
// caller only takes this path when `n * count <= basdv::pinned_buffer::size()`.
static xena::future<> strided_copy_host_to_device_one_sweep(std::byte* dst,
                                                            const basdv::stream& stream,
                                                            const std::byte* src, size_t n,
                                                            size_t count, size_t stride) noexcept {
  const auto total_bytes = n * count;
  if (total_bytes == 0) {
    co_return;
  }

  // `staging` is a local that stays alive across the co_await below, so the source of the
  // asynchronous copy remains valid until the stream has been awaited.
  basdv::pinned_buffer staging;
  auto* const packed = static_cast<std::byte*>(staging.data());

  // Gather: slice i is read from src + i * stride and written contiguously at packed + i * n.
  for (size_t slice = 0; slice != count; ++slice) {
    std::memcpy(packed + slice * n, src + slice * stride, n);
  }

  basdv::async_memcpy_host_to_device(static_cast<void*>(dst), staging.data(), total_bytes,
                                     stream);
  co_await await_stream(stream);
}

//--------------------------------------------------------------------------------------------------
// strided_copy_host_to_device
//--------------------------------------------------------------------------------------------------
Expand All @@ -40,26 +63,44 @@ xena::future<> strided_copy_host_to_device(std::byte* dst, const basdv::stream&
// clang-format on
);
auto num_bytes = n * count;
auto f = [&](basct::span<std::byte> buffer, size_t index) noexcept {
auto remaining_size = buffer.size();
auto out = buffer.data();
auto slice_index = index / n;
auto slice_pos = index - n * slice_index;

auto data = src + n * slice_index + slice_pos;
if (num_bytes <= basdv::pinned_buffer::size()) {
co_return co_await strided_copy_host_to_device_one_sweep(dst, stream, src, n, count, stride);
}
auto cur_n = n;

auto chunk_size = std::min(remaining_size, n - slice_pos);
out = std::copy_n(data, chunk_size, out);
data += stride;
remaining_size -= chunk_size;

while (remaining_size > 0) {
chunk_size = std::min(remaining_size, n);
out = std::copy_n(data, chunk_size, out);
data += stride;
auto fill_buffer = [&](basdv::pinned_buffer& buffer) noexcept {
size_t remaining_size = buffer.size();
auto data = static_cast<std::byte*>(buffer.data());
while (remaining_size > 0 && count > 0) {
auto chunk_size = std::min(remaining_size, cur_n);
std::memcpy(data, src, chunk_size);
src += chunk_size;
data += chunk_size;
remaining_size -= chunk_size;
cur_n -= chunk_size;
if (cur_n == 0) {
--count;
cur_n = n;
src += stride - n;
}
}
return buffer.size() - remaining_size;
};
co_await generate_to_device(basct::span<std::byte>{dst, num_bytes}, stream, f);

// copy
basdv::pinned_buffer cur_buffer, alt_buffer;
auto chunk_size = fill_buffer(cur_buffer);
SXT_DEBUG_ASSERT(count > 0, "copy can't be done in a single sweep");
while (count > 0) {
basdv::async_memcpy_host_to_device(static_cast<void*>(dst), cur_buffer.data(), chunk_size,
stream);
dst += chunk_size;
chunk_size = fill_buffer(alt_buffer);
co_await await_stream(stream);
std::swap(cur_buffer, alt_buffer);
}
basdv::async_memcpy_host_to_device(static_cast<void*>(dst), cur_buffer.data(), chunk_size,
stream);
co_await await_stream(stream);
}
} // namespace sxt::xendv
48 changes: 0 additions & 48 deletions sxt/execution/device/copy.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <cstddef>
#include <numeric>
#include <vector>
#include <random>

#include "sxt/base/device/pinned_buffer.h"
#include "sxt/base/device/stream.h"
Expand All @@ -38,7 +37,6 @@ TEST_CASE("we can copy strided memory from host to device") {

basdv::stream stream;

#if 0
SECTION("we can copy empty data") {
auto fut = strided_copy_host_to_device<uint8_t>(dst, stream, src, 1, 0, 0);
REQUIRE(fut.ready());
Expand Down Expand Up @@ -118,50 +116,4 @@ TEST_CASE("we can copy strided memory from host to device") {
basdv::synchronize_device();
REQUIRE(std::vector<uint8_t>(dst.begin(), dst.end()) == src);
}
#endif

#if 0
SECTION("we can perform random copies") {
std::mt19937 rng{0};
for (int i = 0; i < 10; ++i) {
auto n = std::uniform_int_distribution<size_t>{0, bufsize * 10}(rng);
src.resize(4 * n);
auto data = reinterpret_cast<unsigned*>(src.data());
std::iota(data, data + n, 0u);
dst.resize(src.size() - 1);
auto fut = strided_copy_host_to_device<uint8_t>(dst, stream, src, 4 * n - 1, 4 * n - 1, 1);
xens::get_scheduler().run();
REQUIRE(fut.ready());
basdv::synchronize_device();
REQUIRE(std::vector<uint8_t>(dst.begin(), dst.end()) ==
std::vector<uint8_t>(src.begin() + 1, src.end()));
}
}
#endif

SECTION("we can perform random copies") {
std::mt19937 rng{0};
for (int i = 0; i < 10; ++i) {
auto n = std::uniform_int_distribution<size_t>{0, bufsize * 10}(rng);
src.resize(4 * n);
std::iota(src.begin(), src.end(), 0);
/* auto data = reinterpret_cast<unsigned*>(src.data()); */
/* std::iota(data, data + n, 0u); */
dst.resize(2 * n);
auto fut = strided_copy_host_to_device<uint8_t>(dst, stream, src, 2 * n, n, 0);
xens::get_scheduler().run();
REQUIRE(fut.ready());
basdv::synchronize_device();
/* for (size_t i=0; i<dst.size(); ++i) { */
for (size_t i = 0; i < 2 * n - 1; ++i) {
auto expected = 2 * n * (i / n) + (i % n);
/* auto expected = n * (i / n) + (i % n); */
/* std::println("dst[{}]", i); */
REQUIRE(dst[i] == static_cast<uint8_t>(expected));
}
std::println("dst[last] = {}", dst[dst.size() - 4]);
/* REQUIRE(std::vector<uint8_t>(dst.begin(), dst.end()) == */
/* std::vector<uint8_t>(src.begin() + 1, src.end())); */
}
}
}

0 comments on commit 10bbcd9

Please sign in to comment.