Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix scalar transpose for non-power of two inputs (PROOF-925) #215

Closed
wants to merge 17 commits into from
25 changes: 24 additions & 1 deletion sxt/execution/device/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,33 @@ sxt_cc_component(
)

sxt_cc_component(
name = "copy",
name = "generate",
impl_deps = [
],
test_deps = [
"//sxt/base/device:pinned_buffer",
"//sxt/base/device:stream",
"//sxt/base/device:synchronization",
"//sxt/base/test:unit_test",
"//sxt/execution/schedule:scheduler",
"//sxt/memory/resource:managed_device_resource",
],
deps = [
":synchronization",
"//sxt/base/container:span",
"//sxt/base/device:memory_utility",
"//sxt/base/device:pinned_buffer",
"//sxt/base/device:stream",
"//sxt/base/error:assert",
"//sxt/execution/async:coroutine",
"//sxt/execution/async:future",
],
)

sxt_cc_component(
name = "copy",
impl_deps = [
":generate",
"//sxt/base/device:memory_utility",
"//sxt/base/device:stream",
"//sxt/execution/async:coroutine",
Expand Down
77 changes: 18 additions & 59 deletions sxt/execution/device/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,12 @@
#include <cstring>

#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/pinned_buffer.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/error/assert.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/device/synchronization.h"
#include "sxt/execution/device/generate.h"

namespace sxt::xendv {
//--------------------------------------------------------------------------------------------------
// strided_copy_host_to_device_one_sweep
//--------------------------------------------------------------------------------------------------
static xena::future<> strided_copy_host_to_device_one_sweep(std::byte* dst,
const basdv::stream& stream,
const std::byte* src, size_t n,
size_t count, size_t stride) noexcept {
auto num_bytes = n * count;
if (num_bytes == 0) {
co_return;
}
basdv::pinned_buffer buffer;
auto data = static_cast<std::byte*>(buffer.data());
for (size_t i = 0; i < count; ++i) {
std::memcpy(data, src, n);
data += n;
src += stride;
}
basdv::async_memcpy_host_to_device(static_cast<void*>(dst), buffer.data(), num_bytes, stream);
co_await await_stream(stream);
}

//--------------------------------------------------------------------------------------------------
// strided_copy_host_to_device
//--------------------------------------------------------------------------------------------------
Expand All @@ -63,44 +40,26 @@ xena::future<> strided_copy_host_to_device(std::byte* dst, const basdv::stream&
// clang-format on
);
auto num_bytes = n * count;
if (num_bytes <= basdv::pinned_buffer::size()) {
co_return co_await strided_copy_host_to_device_one_sweep(dst, stream, src, n, count, stride);
}
auto cur_n = n;
auto f = [&](basct::span<std::byte> buffer, size_t index) noexcept {
auto remaining_size = buffer.size();
auto out = buffer.data();
auto slice_index = index / n;
auto slice_pos = index - n * slice_index;

auto data = src + n * slice_index + slice_pos;

auto fill_buffer = [&](basdv::pinned_buffer& buffer) noexcept {
size_t remaining_size = buffer.size();
auto data = static_cast<std::byte*>(buffer.data());
while (remaining_size > 0 && count > 0) {
auto chunk_size = std::min(remaining_size, cur_n);
std::memcpy(data, src, chunk_size);
src += chunk_size;
data += chunk_size;
auto chunk_size = std::min(remaining_size, n - slice_pos);
out = std::copy_n(data, chunk_size, out);
data += stride;
remaining_size -= chunk_size;

while (remaining_size > 0) {
chunk_size = std::min(remaining_size, n);
out = std::copy_n(data, chunk_size, out);
data += stride;
remaining_size -= chunk_size;
cur_n -= chunk_size;
if (cur_n == 0) {
--count;
cur_n = n;
src += stride - n;
}
}
return buffer.size() - remaining_size;
};

// copy
basdv::pinned_buffer cur_buffer, alt_buffer;
auto chunk_size = fill_buffer(cur_buffer);
SXT_DEBUG_ASSERT(count > 0, "copy can't be done in a single sweep");
while (count > 0) {
basdv::async_memcpy_host_to_device(static_cast<void*>(dst), cur_buffer.data(), chunk_size,
stream);
dst += chunk_size;
chunk_size = fill_buffer(alt_buffer);
co_await await_stream(stream);
std::swap(cur_buffer, alt_buffer);
}
basdv::async_memcpy_host_to_device(static_cast<void*>(dst), cur_buffer.data(), chunk_size,
stream);
co_await await_stream(stream);
co_await generate_to_device(basct::span<std::byte>{dst, num_bytes}, stream, f);
}
} // namespace sxt::xendv
17 changes: 17 additions & 0 deletions sxt/execution/device/generate.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/execution/device/generate.h"
112 changes: 112 additions & 0 deletions sxt/execution/device/generate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/pinned_buffer.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/error/assert.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/device/synchronization.h"

namespace sxt::xendv {
//--------------------------------------------------------------------------------------------------
// generate_to_device_one_sweep
//--------------------------------------------------------------------------------------------------
template <class T, class F>
requires requires(basct::span<T> buf, F f, size_t i) {
{ f(buf, i) } noexcept;
}
xena::future<> generate_to_device_one_sweep(basct::span<T> dst, const basdv::stream& stream,
F f) noexcept {
if (dst.empty()) {
co_return;
}
auto n = dst.size();
auto num_bytes = n * sizeof(T);
SXT_RELEASE_ASSERT(
// clang-format off
basdv::is_active_device_pointer(dst.data()) &&
num_bytes <= basdv::pinned_buffer::size()
// clang-format on
);
basdv::pinned_buffer buffer;
auto data = static_cast<T*>(buffer.data());
f(basct::span<T>{data, n}, 0u);
basdv::async_memcpy_host_to_device(static_cast<void*>(dst.data()), buffer.data(), num_bytes,
stream);
co_await await_stream(stream);
}

//--------------------------------------------------------------------------------------------------
// generate_to_device
//--------------------------------------------------------------------------------------------------
template <class T, class F>
requires requires(basct::span<T> buffer, F f, size_t i) {
{ f(buffer, i) } noexcept;
}
xena::future<> generate_to_device(basct::span<T> dst, const basdv::stream& stream, F f) noexcept {
if (dst.empty()) {
co_return;
}
auto n = dst.size();
SXT_RELEASE_ASSERT(
// clang-format off
basdv::is_active_device_pointer(dst.data()) &&
sizeof(T) < basdv::pinned_buffer::size()
// clang-format on
);
auto num_bytes = n * sizeof(T);
if (num_bytes <= basdv::pinned_buffer::size()) {
co_return co_await generate_to_device_one_sweep(dst, stream, f);
}
std::byte* out = reinterpret_cast<std::byte*>(dst.data());
size_t pos = 0;

auto fill_buffer = [&](basdv::pinned_buffer& buffer) noexcept {
size_t remaining_size = buffer.size();
auto data = static_cast<T*>(buffer.data());
while (remaining_size > 0 && pos < n) {
auto chunk_size = std::min(remaining_size / sizeof(T), n - pos);
if (chunk_size == 0) {
break;
}
f(basct::span<T>{data, chunk_size}, pos);
data += chunk_size;
remaining_size -= chunk_size * sizeof(T);
pos += chunk_size;
}
return buffer.size() - remaining_size;
};

// copy
basdv::pinned_buffer cur_buffer, alt_buffer;
auto chunk_size = fill_buffer(cur_buffer);
SXT_DEBUG_ASSERT(pos < n, "copy can't be done in a single sweep");
while (pos < n) {
basdv::async_memcpy_host_to_device(static_cast<void*>(out), cur_buffer.data(), chunk_size,
stream);
out += chunk_size;
chunk_size = fill_buffer(alt_buffer);
co_await await_stream(stream);
std::swap(cur_buffer, alt_buffer);
}
basdv::async_memcpy_host_to_device(static_cast<void*>(out), cur_buffer.data(), chunk_size,
stream);
co_await await_stream(stream);
}
} // namespace sxt::xendv
89 changes: 89 additions & 0 deletions sxt/execution/device/generate.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/execution/device/generate.h"

#include <cstddef>
#include <numeric>
#include <vector>

#include "sxt/base/device/pinned_buffer.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/device/synchronization.h"
#include "sxt/base/test/unit_test.h"
#include "sxt/execution/schedule/scheduler.h"
#include "sxt/memory/resource/managed_device_resource.h"

using namespace sxt;
using namespace sxt::xendv;

TEST_CASE("we can generate an array into device memory") {
const auto bufsize = basdv::pinned_buffer::size();
std::pmr::vector<uint8_t> dst{memr::get_managed_device_resource()};

basdv::stream stream;

auto f = []<class T>(basct::span<T> buffer, size_t index) noexcept {
for (auto& x : buffer) {
x = static_cast<T>(index++);
}
};

SECTION("we can generate an empty array") {
auto fut = generate_to_device<uint8_t>(dst, stream, f);
REQUIRE(fut.ready());
}

SECTION("we can generate a single element") {
dst.resize(1);
auto fut = generate_to_device<uint8_t>(dst, stream, f);
xens::get_scheduler().run();
REQUIRE(fut.ready());
basdv::synchronize_device();
REQUIRE(dst[0] == 0);
}

SECTION("we can generate two elements") {
dst.resize(2);
auto fut = generate_to_device<uint8_t>(dst, stream, f);
xens::get_scheduler().run();
REQUIRE(fut.ready());
basdv::synchronize_device();
REQUIRE(dst[0] == 0);
REQUIRE(dst[1] == 1);
}

SECTION("we can generate ints") {
std::pmr::vector<int> dst_p{2, memr::get_managed_device_resource()};
auto fut = generate_to_device<int>(dst_p, stream, f);
xens::get_scheduler().run();
REQUIRE(fut.ready());
basdv::synchronize_device();
REQUIRE(dst_p[0] == 0);
REQUIRE(dst_p[1] == 1);
}

SECTION("we can generate elements larger than the buffersize") {
std::pmr::vector<int> dst_p{bufsize, memr::get_managed_device_resource()};
auto fut = generate_to_device<int>(dst_p, stream, f);
xens::get_scheduler().run();
REQUIRE(fut.ready());
basdv::synchronize_device();
for (int i = 0; i < bufsize; ++i) {
REQUIRE(dst_p[i] == i);
}
}
}
6 changes: 1 addition & 5 deletions sxt/multiexp/base/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,15 @@ sxt_cc_component(
sxt_cc_component(
name = "scalar_array",
impl_deps = [
"@local_cuda//:cub",
"//sxt/base/container:span_utility",
"//sxt/base/device:memory_utility",
"//sxt/base/device:stream",
"//sxt/base/error:assert",
"//sxt/base/num:ceil_log2",
"//sxt/base/num:constexpr_switch",
"//sxt/base/num:divide_up",
"//sxt/execution/async:future",
"//sxt/execution/async:coroutine",
"//sxt/execution/device:generate",
"//sxt/execution/device:synchronization",
"//sxt/memory/management:managed_array",
"//sxt/memory/resource:async_device_resource",
],
test_deps = [
"//sxt/base/device:stream",
Expand Down
Loading